407 lines
17 KiB
Python
407 lines
17 KiB
Python
import threading
|
|
from dataclasses import replace
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
from sqlalchemy import inspect
|
|
from sqlalchemy.engine import Connection, Engine
|
|
|
|
from ._sql import quote
|
|
from .cache import CacheManager, TableError
|
|
from .config import (
|
|
BACKUP_INTERVAL_SECONDS,
|
|
CACHE_DB_PATH,
|
|
FETCH_BATCH_SIZE,
|
|
IN_MEMORY,
|
|
REFRESH_INTERVAL_SECONDS,
|
|
SQL_DIALECT,
|
|
)
|
|
from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
|
|
from .exceptions import UndeclaredError
|
|
from .executor import QueryExecutor
|
|
from .parser import Params, ParsedQuery, parse
|
|
from .registry import ColumnRegistry
|
|
from .spec import TTL, TableSpec
|
|
from .stats import Stats, StatsCollector, TableState, TableStats
|
|
|
|
|
|
def _specs_to_config(
    tables: list[TableSpec],
) -> tuple[
    dict[str, DeltaConfig],
    dict[str, int],
    dict[str, list[str | list[str]]],
    dict[str, list[str]],
    dict[str, list[str] | None],
]:
    """Translate declarative ``TableSpec``s into the engine's internal config dicts.

    The result is ``(delta, ttl, indexes, datetime_columns, declared)``. The first
    four have the same shape as the legacy keyword arguments. ``declared`` maps
    every table name to the list of columns it allows, or to ``None`` when the
    whole table (any column) is declared; it drives fail-fast query checking.

    Raises:
        ValueError: if two specs share the same table name.
    """
    delta_by_table: dict[str, DeltaConfig] = {}
    ttl_by_table: dict[str, int] = {}
    indexes_by_table: dict[str, list[str | list[str]]] = {}
    datetime_by_table: dict[str, list[str]] = {}
    allowed_columns: dict[str, list[str] | None] = {}

    for table_spec in tables:
        name = table_spec.name
        if name in allowed_columns:
            raise ValueError(f"Duplicate TableSpec for table {name!r}.")

        # None means "whole table"; otherwise keep a private copy of the column list.
        cols = table_spec.columns
        allowed_columns[name] = None if cols is None else list(cols)

        # Empty / missing index and datetime declarations produce no entry at all.
        if table_spec.indexes:
            indexes_by_table[name] = list(table_spec.indexes)
        if table_spec.datetime_columns:
            datetime_by_table[name] = list(table_spec.datetime_columns)

        # A spec is TTL-refreshed, delta-refreshed, or neither (no entry in either dict).
        policy = table_spec.refresh
        if isinstance(policy, TTL):
            ttl_by_table[name] = policy.seconds
            continue
        if isinstance(policy, DeltaConfig):
            delta_by_table[name] = policy

    return (
        delta_by_table,
        ttl_by_table,
        indexes_by_table,
        datetime_by_table,
        allowed_columns,
    )
|
|
|
|
|
|
class _LazySource:
    """A source connection opened on first ``execute`` and shared across one query.

    Most queries are cache hits that never touch the source, so opening it (and
    occupying a connection-pool slot) eagerly is wasteful. This proxy forwards
    ``execute`` to a real connection opened on demand, then released by ``close``.

    The proxy is reusable: after ``close`` the next ``execute`` opens a fresh
    connection.
    """

    def __init__(self, source_engine: Engine) -> None:
        """Remember the engine; no connection is opened yet."""
        self._source_engine = source_engine
        # SQLAlchemy connection holding the checkout; None while nothing is open.
        self._sa_conn: Connection | None = None
        # Raw driver-level connection that ``execute`` calls are forwarded to.
        self._raw: Any = None

    def execute(self, *args: Any, **kwargs: Any) -> Any:
        """Forward to the raw connection's ``execute``, connecting on first use.

        Arguments and the return value are passed through unchanged.
        """
        if self._raw is None:
            # An earlier attempt may have checked out a connection but failed to
            # yield a usable raw handle; release it instead of overwriting (and
            # leaking) it with the new checkout.
            self.close()
            self._sa_conn = self._source_engine.connect()
            self._raw = self._sa_conn.connection.dbapi_connection
        return self._raw.execute(*args, **kwargs)

    def close(self) -> None:
        """Release the underlying connection, if one was opened.

        Safe to call repeatedly. Internal state is cleared *before* the real
        ``close`` runs, so even if that call raises the proxy never keeps a
        reference to a connection that is (or may be) already closed.
        """
        sa_conn, self._sa_conn, self._raw = self._sa_conn, None, None
        if sa_conn is not None:
            sa_conn.close()
|
|
|
|
|
|
class CachingEngine:
    """Transparent SQLAlchemy-compatible cache layer.

    Queries passed to :meth:`execute` are parsed, checked against the declared
    table/column allowlist (declarative mode only) and run through a
    ``QueryExecutor`` backed by the local cache; the source database is only
    connected to lazily. When any table is delta-tracked, TTL-tracked or marked
    for preload, a daemon thread runs periodic refresh cycles until
    :meth:`close` is called.
    """

    def __init__(
        self,
        source_engine: Engine,
        delta: dict[str, DeltaConfig] | None = None,
        ttl: dict[str, int] | None = None,
        indexes: dict[str, list[str | list[str]]] | None = None,
        in_memory: bool | None = None,
        cache_db_path: str | Path | None = None,
        backup_interval: int | None = None,
        refresh_interval: int | None = None,
        fetch_batch: int | None = None,
        dialect: str | None = None,
        pragmas: dict[str, str | int] | None = None,
        datetime_columns: dict[str, list[str]] | None = None,
        return_datetime: bool = True,
        tables: list[TableSpec] | None = None,
        blocking_startup_refresh: bool = False,
    ) -> None:
        """Build the cache, resolve refresh configuration and start background refresh.

        Args:
            source_engine: Engine of the database being cached.
            delta: Legacy mode — per-table delta-refresh configuration.
            ttl: Legacy mode — per-table TTL in seconds (full reload on expiry).
            indexes: Legacy mode — per-table secondary indexes; each entry is a
                column name or a list of column names (composite index).
            in_memory: Overrides the configured ``IN_MEMORY`` default when given.
            cache_db_path: Overrides the configured ``CACHE_DB_PATH`` when given.
            backup_interval: Overrides ``BACKUP_INTERVAL_SECONDS`` when given.
            refresh_interval: Seconds between background refresh cycles;
                overrides ``REFRESH_INTERVAL_SECONDS`` when given.
            fetch_batch: Overrides ``FETCH_BATCH_SIZE`` when given.
            dialect: SQL dialect used for parsing; overrides ``SQL_DIALECT``.
            pragmas: Passed through to the ``CacheManager``.
            datetime_columns: Legacy mode — per-table datetime column names,
                passed through to the ``CacheManager``.
            return_datetime: Passed through to the ``CacheManager``.
            tables: Declarative mode — list of ``TableSpec``; mutually exclusive
                with ``delta``/``ttl``/``indexes``/``datetime_columns``.
            blocking_startup_refresh: Run preload and the first refresh cycle
                synchronously in the constructor instead of on the background
                thread.

        Raises:
            ValueError: if both declarative and legacy configuration are given,
                if a table is configured for both delta and TTL refresh, or if
                a delta table has no discoverable primary key.
        """
        self._source_engine = source_engine
        # Set by close() so the background refresh loop can terminate.
        self._stop_event = threading.Event()

        # Declarative mode: a list of TableSpecs is converted to the same internal
        # config the legacy delta=/ttl=/indexes=/datetime_columns= kwargs produce,
        # plus a declared-columns allowlist (for fail-fast) and preload set.
        self._declared: dict[str, list[str] | None] | None = None
        self._preload_specs: list[TableSpec] = []
        if tables is not None:
            if any(x is not None for x in (delta, ttl, indexes, datetime_columns)):
                raise ValueError(
                    "Pass either tables=[TableSpec(...)] or the legacy "
                    "delta=/ttl=/indexes=/datetime_columns= kwargs, not both."
                )
            # Materialise once: the specs are walked twice below, which would
            # silently drop the preload set for a one-shot iterable.
            tables = list(tables)
            delta, ttl, indexes, datetime_columns, self._declared = _specs_to_config(tables)
            self._preload_specs = [s for s in tables if s.preload]

        # Validate before any resource is created so a configuration error does
        # not leave an opened cache behind.
        overlap = set(delta or {}) & set(ttl or {})
        if overlap:
            raise ValueError(
                f"Tables {sorted(overlap)} are in both delta and ttl — a table is "
                "either delta-refreshed (has a change column) or TTL-refreshed (full "
                "reload), not both."
            )

        use_memory = IN_MEMORY if in_memory is None else in_memory
        self._dialect = dialect if dialect is not None else SQL_DIALECT
        self._refresh_interval = (
            refresh_interval if refresh_interval is not None else REFRESH_INTERVAL_SECONDS
        )
        self._cache = CacheManager(
            Path(cache_db_path) if cache_db_path is not None else CACHE_DB_PATH,
            backup_interval if backup_interval is not None else BACKUP_INTERVAL_SECONDS,
            in_memory=use_memory,
            dialect=self._dialect,
            fetch_batch=fetch_batch if fetch_batch is not None else FETCH_BATCH_SIZE,
            pragmas=pragmas,
            datetime_columns=datetime_columns,
            return_datetime=return_datetime,
        )
        self._registry = ColumnRegistry(self._cache.connection)
        self._stats = StatsCollector()
        self._delta = self._resolve_delta(delta or {})
        self._ttl = dict(ttl or {})
        self._index_columns = self._register_indexes(indexes or {})
        self._refresher = DeltaRefresher(self._cache, self._delta)

        if self._delta or self._ttl or self._preload_specs:
            # Startup work (preload of declared tables + delta/TTL catch-up for
            # tables restored from disk) can take a while on a cold start. By
            # default it runs on the background thread so it never blocks
            # application startup; callers who need the cache fully warm before
            # serving can opt back in.
            if blocking_startup_refresh:
                self._preload()
                self._run_refresh()
            self._start_refresh_thread(initial_catch_up=not blocking_startup_refresh)

        logger.info("CachingEngine initialized.")

    def _register_indexes(
        self, indexes: dict[str, list[str | list[str]]]
    ) -> dict[str, list[str]]:
        """Register secondary indexes on the cache; return columns to load per table.

        Each index spec is either a single column name or a list of names. The
        returned per-table list holds every indexed column once, in first-seen
        order.
        """
        index_columns: dict[str, list[str]] = {}
        for table, specs in indexes.items():
            wanted: list[str] = []
            for spec in specs:
                columns = [spec] if isinstance(spec, str) else list(spec)
                self._cache.add_index(table, columns)
                for col in columns:
                    if col not in wanted:
                        wanted.append(col)
            index_columns[table] = wanted
        return index_columns

    def _resolve_delta(self, delta: dict[str, DeltaConfig]) -> dict[str, ResolvedDelta]:
        """Resolve each DeltaConfig, auto-discovering the primary key when omitted.

        Raises:
            ValueError: if a table has no ``key_columns`` configured and the
                source reports no primary key for it.
        """
        resolved: dict[str, ResolvedDelta] = {}
        inspector = None  # created lazily: only needed when some key must be discovered
        for table, cfg in delta.items():
            keys = list(cfg.key_columns)
            if not keys:
                inspector = inspector or inspect(self._source_engine)
                pk = inspector.get_pk_constraint(table)
                keys = list(pk.get("constrained_columns") or [])
                if not keys:
                    raise ValueError(
                        f"No primary key found for {table!r} in the source DB "
                        "(views have none) — set key_columns in its DeltaConfig."
                    )
                logger.info(f"Delta {table!r}: auto-discovered key columns {keys}")
            resolved[table] = ResolvedDelta(change_column=cfg.change_column, key_columns=keys)
        return resolved

    @property
    def stats(self) -> Stats:
        """Return a stats snapshot enriched with cache totals and per-table refresh info."""
        states = self._cache.get_states()
        errors = self._cache.get_errors()
        last_runs = self._cache.get_last_runs()
        # The snapshot reads through the cache connection, so it runs under the
        # cache lock.
        with self._cache._lock:
            base = self._stats.snapshot(self._cache.connection, states)
        # NOTE(review): reconstructed outside the lock, like the getters above —
        # confirm db_size_bytes() does not need the caller to hold it.
        base = replace(
            base,
            errors=self._cache.error_total,
            db_size_bytes=self._cache.db_size_bytes(),
        )
        return replace(
            base,
            tables={n: self._enrich(n, t, errors, last_runs) for n, t in base.tables.items()},
        )

    def _enrich(
        self,
        name: str,
        table_stats: TableStats,
        errors: dict[str, TableError],
        last_runs: dict[str, str],
    ) -> TableStats:
        """Annotate a TableStats with refresh tracking, TTL staleness, errors and run time."""
        if name in self._delta:
            tracking = "delta"
        elif name in self._ttl:
            tracking = "ttl"
        else:
            tracking = "static"

        # A READY TTL table whose age exceeds its TTL is reported as STALE.
        state = table_stats.state
        if state == TableState.READY and name in self._ttl:
            age = self._cache.seconds_since_refresh(name)
            if age is not None and age > self._ttl[name]:
                state = TableState.STALE

        last_refresh = last_runs.get(name)
        err = errors.get(name)
        if err is not None:
            return replace(
                table_stats,
                tracking=tracking,
                state=state,
                last_refresh=last_refresh,
                last_error=err.message,
                last_error_at=err.at,
                consecutive_failures=err.consecutive,
            )
        return replace(table_stats, tracking=tracking, state=state, last_refresh=last_refresh)

    def _make_executor(self, source: Any) -> QueryExecutor:
        """Build a ``QueryExecutor`` bound to ``source`` and this engine's shared state."""
        return QueryExecutor(
            self._cache,
            self._registry,
            source,
            self._stats,
            self._delta,
            self._ttl,
            self._index_columns,
        )

    def _check_declared(self, parsed: ParsedQuery) -> None:
        """In declarative mode, reject any table/column not declared up front.

        Raises:
            UndeclaredError: for an undeclared table, a ``SELECT *`` on a table
                with an explicit column list, or an undeclared column.
        """
        if self._declared is None:
            return
        for table in parsed.tables:
            if table not in self._declared:
                raise UndeclaredError(
                    f"Table {table!r} is not declared in tables=[TableSpec(...)]. "
                    "Add a TableSpec for it (declarative mode is a strict allowlist)."
                )
            allowed = self._declared[table]
            if allowed is None:
                continue  # whole table declared — any column is fine
            if table in parsed.wildcard_tables:
                raise UndeclaredError(
                    f"SELECT * on {table!r} is not allowed: only columns {allowed} "
                    "are declared. List the columns explicitly or declare "
                    "columns=None for the whole table."
                )
            unknown = [c for c in parsed.columns_by_table.get(table, []) if c not in allowed]
            if unknown:
                raise UndeclaredError(
                    f"Column(s) {unknown} of {table!r} are not declared "
                    f"(declared: {allowed})."
                )

    def execute(self, sql: str, params: Params = None) -> list[dict]:
        """Parse ``sql``, enforce the declared allowlist and run it through the cache."""
        parsed = parse(sql, params, dialect=self._dialect)
        self._check_declared(parsed)
        # The source connection is opened lazily — a pure cache hit never touches
        # the source and never occupies a pool slot.
        source = _LazySource(self._source_engine)
        try:
            return self._make_executor(source).execute(parsed)
        finally:
            source.close()

    def _preload(self) -> None:
        """Load declared ``preload=True`` tables into the cache (skipping fresh copies).

        A failure for one table is logged and does not stop the remaining tables.
        """
        if not self._preload_specs:
            return
        source = _LazySource(self._source_engine)
        try:
            executor = self._make_executor(source)
            for spec in self._preload_specs:
                try:
                    logger.info(f"Preloading {spec.name!r}…")
                    executor.ensure_loaded(spec.name, spec.columns)
                except Exception as e:
                    logger.error(f"Preload failed for {spec.name!r}: {e}")
        finally:
            source.close()

    def refresh(self) -> None:
        """Run one refresh cycle now (also runs on a timer).

        Pulls deltas for all delta-tracked tables and reloads cached TTL tables
        whose TTL has expired. Failures are logged, not raised.
        """
        self._run_refresh()

    def _run_refresh(self) -> None:
        """Run delta refresh then TTL refresh on one source connection; log any failure."""
        try:
            with self._source_engine.connect() as sa_conn:
                raw_conn = sa_conn.connection.dbapi_connection
                self._refresher.refresh(raw_conn)
                self._refresh_ttl(raw_conn)
        except Exception as e:
            logger.error(f"Refresh cycle failed: {e}")

    def _refresh_ttl(self, source_conn: Any) -> None:
        """Proactively full-reload TTL-tracked tables whose cache has expired.

        Tables that are not cached yet, or whose age is unknown or within the
        TTL, are skipped. A failure for one table is logged and does not stop
        the others.
        """
        for table, ttl in self._ttl.items():
            if not self._cache.is_table_cached(table):
                continue
            age = self._cache.seconds_since_refresh(table)
            if age is None or age <= ttl:
                continue
            try:
                # Reload exactly what is cached today (same columns, same fullness).
                columns = self._cache.get_table_columns(table)
                full = self._cache.is_table_full(table)
                self._cache.load_table(table, columns, source_conn, full=full)
                logger.info(f"TTL refresh {table!r}: reloaded (age {age:.0f}s > {ttl}s)")
            except Exception as e:
                logger.error(f"TTL refresh failed for {table!r}: {e}")

    def _start_refresh_thread(self, initial_catch_up: bool = True) -> None:
        """Start the daemon thread that runs a refresh cycle every refresh interval.

        Args:
            initial_catch_up: When true, the thread first performs the preload
                and one refresh cycle before entering the timed loop.
        """
        stop = self._stop_event

        def loop() -> None:
            if initial_catch_up:
                self._preload()  # off-main-thread declared-table preload
                self._run_refresh()  # off-main-thread startup catch-up
            # wait() returns True as soon as close() sets the event, ending the
            # loop; otherwise it times out after the interval and a cycle runs.
            while not stop.wait(self._refresh_interval):
                self._run_refresh()

        t = threading.Thread(target=loop, daemon=True, name="sqlmem-delta")
        t.start()
        logger.debug(f"Delta refresh thread started (interval={self._refresh_interval}s)")

    def invalidate(self, table: str) -> None:
        """Drop ``table`` from the cache together with its bookkeeping rows and state."""
        logger.info(f"Manually invalidating cache for table {table!r}")
        with self._cache._lock:
            self._cache.connection.execute(f"DROP TABLE IF EXISTS {quote(table)}")
            self._cache.connection.execute(
                "DELETE FROM _sqlmem_tables WHERE table_name = ?", (table,)
            )
            self._cache.connection.execute(
                "DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
            )
            self._cache.connection.commit()
        # NOTE(review): reconstructed outside the lock — confirm clear_state()
        # does its own locking.
        self._cache.clear_state(table)

    def reset(self) -> None:
        """Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
        self._cache.reset()
        logger.info("Cache reset — all tables will be reloaded on next use.")

    def hard_reset(self) -> None:
        """Delete the on-disk cache file and reopen with current pragmas/page_size.

        Disk mode only (falls back to :meth:`reset` in memory mode). Use when a
        layout pragma — ``page_size`` or ``auto_vacuum`` — must change, since
        those are baked into the file at creation and :meth:`reset` keeps it.
        All tables reload on next use.
        """
        self._cache.hard_reset()
        # hard_reset swaps the cache connection — re-point the registry at it.
        self._registry.rebind(self._cache.connection)
        logger.info("Cache hard reset — file recreated; all tables reload on next use.")

    def vacuum(self, incremental: bool = True, pages: int = 10_000) -> None:
        """Run maintenance VACUUM on the on-disk cache (incremental by default).

        Incremental reclaims free pages left by delta ``INSERT OR REPLACE`` churn
        cheaply (requires ``auto_vacuum=INCREMENTAL``); a full VACUUM rewrites the
        whole file and should run only in a maintenance window.
        """
        self._cache.vacuum(incremental=incremental, pages=pages)

    def close(self) -> None:
        """Stop the background refresh loop and close the cache.

        The stop signal is sent first so no further refresh cycle is scheduled
        against the closed cache. A cycle already in flight is not waited for.
        """
        self._stop_event.set()
        self._cache.close()
        logger.info("CachingEngine closed.")
|