import threading
from dataclasses import replace
from pathlib import Path
from typing import Any

from loguru import logger
from sqlalchemy import inspect
from sqlalchemy.engine import Connection, Engine

from ._sql import quote
from .cache import CacheManager, TableError
from .config import (
    BACKUP_INTERVAL_SECONDS,
    CACHE_DB_PATH,
    FETCH_BATCH_SIZE,
    IN_MEMORY,
    REFRESH_INTERVAL_SECONDS,
    SQL_DIALECT,
)
from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
from .exceptions import UndeclaredError
from .executor import QueryExecutor
from .parser import Params, ParsedQuery, parse
from .registry import ColumnRegistry
from .spec import TTL, TableSpec
from .stats import Stats, StatsCollector, TableState, TableStats


def _specs_to_config(
    tables: list[TableSpec],
) -> tuple[
    dict[str, DeltaConfig],
    dict[str, int],
    dict[str, list[str | list[str]]],
    dict[str, list[str]],
    dict[str, list[str] | None],
]:
    """Convert declarative ``TableSpec``s into the engine's internal config dicts.

    Returns ``(delta, ttl, indexes, datetime_columns, declared)`` — the first
    four mirror the legacy kwargs; ``declared`` maps each table to its allowed
    columns (``None`` = whole table / any column) for fail-fast query checking.

    A spec whose ``refresh`` is a ``TTL`` goes into ``ttl`` (its ``seconds``),
    one whose ``refresh`` is a ``DeltaConfig`` goes into ``delta``; any other
    ``refresh`` value leaves the table in neither (untracked/static).

    Raises:
        ValueError: if two specs name the same table.
    """
    delta: dict[str, DeltaConfig] = {}
    ttl: dict[str, int] = {}
    indexes: dict[str, list[str | list[str]]] = {}
    datetime_columns: dict[str, list[str]] = {}
    declared: dict[str, list[str] | None] = {}
    for spec in tables:
        if spec.name in declared:
            raise ValueError(f"Duplicate TableSpec for table {spec.name!r}.")
        declared[spec.name] = list(spec.columns) if spec.columns is not None else None
        if spec.indexes:
            indexes[spec.name] = list(spec.indexes)
        if spec.datetime_columns:
            datetime_columns[spec.name] = list(spec.datetime_columns)
        refresh = spec.refresh
        if isinstance(refresh, TTL):
            ttl[spec.name] = refresh.seconds
        elif isinstance(refresh, DeltaConfig):
            delta[spec.name] = refresh
    return delta, ttl, indexes, datetime_columns, declared


class _LazySource:
    """A source connection opened on first ``execute`` and shared across one query.

    Most queries are cache hits that never touch the source, so opening it (and
    occupying a connection-pool slot) eagerly is wasteful. This proxy forwards
    ``execute`` to a real connection opened on demand, then released by ``close``.
    """

    def __init__(self, source_engine: Engine) -> None:
        self._source_engine = source_engine
        self._sa_conn: Connection | None = None
        # The DBAPI connection underneath ``_sa_conn``; ``None`` until first use.
        self._raw: Any = None

    def execute(self, *args: Any, **kwargs: Any) -> Any:
        """Forward to the DBAPI connection's ``execute``, connecting on first use."""
        if self._raw is None:
            sa_conn = self._source_engine.connect()
            try:
                raw = sa_conn.connection.dbapi_connection
            except Exception:
                # Don't strand a checked-out pool connection if unwrapping fails.
                sa_conn.close()
                raise
            self._sa_conn, self._raw = sa_conn, raw
        return self._raw.execute(*args, **kwargs)

    def close(self) -> None:
        """Release the connection if one was opened; a no-op otherwise."""
        if self._sa_conn is not None:
            self._sa_conn.close()
            self._sa_conn = None
            self._raw = None


class CachingEngine:
    """Transparent SQLAlchemy-compatible cache layer.

    Queries given to :meth:`execute` are parsed and answered by a
    ``QueryExecutor`` over a ``CacheManager``; the source engine is connected to
    only when the executor actually issues a source query. Delta- and
    TTL-tracked tables are refreshed by a background daemon thread (and on
    demand via :meth:`refresh`).
    """

    def __init__(
        self,
        source_engine: Engine,
        delta: dict[str, DeltaConfig] | None = None,
        ttl: dict[str, int] | None = None,
        indexes: dict[str, list[str | list[str]]] | None = None,
        in_memory: bool | None = None,
        cache_db_path: str | Path | None = None,
        backup_interval: int | None = None,
        refresh_interval: int | None = None,
        fetch_batch: int | None = None,
        dialect: str | None = None,
        pragmas: dict[str, str | int] | None = None,
        datetime_columns: dict[str, list[str]] | None = None,
        return_datetime: bool = True,
        tables: list[TableSpec] | None = None,
        blocking_startup_refresh: bool = False,
    ) -> None:
        """Validate configuration, open the cache and kick off startup work.

        Tables are configured either declaratively with ``tables=[TableSpec(...)]``
        or with the legacy ``delta=``/``ttl=``/``indexes=``/``datetime_columns=``
        kwargs — not both. ``None`` for ``in_memory``, ``cache_db_path``,
        ``backup_interval``, ``refresh_interval``, ``fetch_batch`` and ``dialect``
        falls back to the matching ``.config`` default.

        Raises:
            ValueError: if both configuration styles are mixed, a TableSpec is
                duplicated, a table is in both ``delta`` and ``ttl``, or a delta
                table's key columns can't be auto-discovered.
        """
        self._source_engine = source_engine
        # Declarative mode: a list of TableSpecs is converted to the same internal
        # config the legacy delta=/ttl=/indexes=/datetime_columns= kwargs produce,
        # plus a declared-columns allowlist (for fail-fast) and preload set.
        self._declared: dict[str, list[str] | None] | None = None
        self._preload_specs: list[TableSpec] = []
        if tables is not None:
            if any(x is not None for x in (delta, ttl, indexes, datetime_columns)):
                raise ValueError(
                    "Pass either tables=[TableSpec(...)] or the legacy "
                    "delta=/ttl=/indexes=/datetime_columns= kwargs, not both."
                )
            delta, ttl, indexes, datetime_columns, self._declared = _specs_to_config(tables)
            self._preload_specs = [s for s in tables if s.preload]

        # Validate and resolve everything that can fail *before* opening the
        # cache, so a bad configuration never leaves an opened CacheManager behind.
        overlap = set(delta or {}) & set(ttl or {})
        if overlap:
            raise ValueError(
                f"Tables {sorted(overlap)} are in both delta and ttl — a table is "
                "either delta-refreshed (has a change column) or TTL-refreshed (full "
                "reload), not both."
            )
        # May inspect the source DB to discover primary keys; needs only the engine.
        self._delta = self._resolve_delta(delta or {})
        self._ttl = dict(ttl or {})

        use_memory = IN_MEMORY if in_memory is None else in_memory
        self._dialect = dialect if dialect is not None else SQL_DIALECT
        self._refresh_interval = (
            refresh_interval if refresh_interval is not None else REFRESH_INTERVAL_SECONDS
        )
        self._cache = CacheManager(
            Path(cache_db_path) if cache_db_path is not None else CACHE_DB_PATH,
            backup_interval if backup_interval is not None else BACKUP_INTERVAL_SECONDS,
            in_memory=use_memory,
            dialect=self._dialect,
            fetch_batch=fetch_batch if fetch_batch is not None else FETCH_BATCH_SIZE,
            pragmas=pragmas,
            datetime_columns=datetime_columns,
            return_datetime=return_datetime,
        )
        self._registry = ColumnRegistry(self._cache.connection)
        self._stats = StatsCollector()
        self._index_columns = self._register_indexes(indexes or {})
        self._refresher = DeltaRefresher(self._cache, self._delta)

        # Background-refresh coordination: the lock serializes refresh cycles
        # (manual refresh() vs the timer thread); the event lets close() stop
        # the thread.
        self._refresh_lock = threading.Lock()
        self._stop_event = threading.Event()
        self._refresh_thread: threading.Thread | None = None

        if self._delta or self._ttl or self._preload_specs:
            # Startup work (preload of declared tables + delta/TTL catch-up for
            # tables restored from disk) can take a while on a cold start. By
            # default it runs on the background thread so it never blocks
            # application startup; callers who need the cache fully warm before
            # serving can opt back in.
            if blocking_startup_refresh:
                self._preload()
                self._run_refresh()
            self._start_refresh_thread(initial_catch_up=not blocking_startup_refresh)
        logger.info("CachingEngine initialized.")

    def _register_indexes(
        self, indexes: dict[str, list[str | list[str]]]
    ) -> dict[str, list[str]]:
        """Register secondary indexes on the cache; return columns to load per table.

        Each index spec is a single column name or a list of columns (composite).
        The returned per-table column list is de-duplicated, in first-seen order.
        """
        index_columns: dict[str, list[str]] = {}
        for table, specs in indexes.items():
            wanted: list[str] = []
            for spec in specs:
                columns = [spec] if isinstance(spec, str) else list(spec)
                self._cache.add_index(table, columns)
                for col in columns:
                    if col not in wanted:
                        wanted.append(col)
            index_columns[table] = wanted
        return index_columns

    def _resolve_delta(self, delta: dict[str, DeltaConfig]) -> dict[str, ResolvedDelta]:
        """Resolve each DeltaConfig, auto-discovering the primary key when omitted.

        A source inspector is created only if some config has no ``key_columns``.

        Raises:
            ValueError: if such a table has no primary key in the source DB.
        """
        resolved: dict[str, ResolvedDelta] = {}
        inspector = None
        for table, cfg in delta.items():
            keys = list(cfg.key_columns)
            if not keys:
                inspector = inspector or inspect(self._source_engine)
                pk = inspector.get_pk_constraint(table)
                keys = list(pk.get("constrained_columns") or [])
                if not keys:
                    raise ValueError(
                        f"No primary key found for {table!r} in the source DB "
                        "(views have none) — set key_columns in its DeltaConfig."
                    )
                logger.info(f"Delta {table!r}: auto-discovered key columns {keys}")
            resolved[table] = ResolvedDelta(change_column=cfg.change_column, key_columns=keys)
        return resolved

    @property
    def stats(self) -> Stats:
        """Current cache statistics, annotated per table with refresh/error info."""
        states = self._cache.get_states()
        errors = self._cache.get_errors()
        last_runs = self._cache.get_last_runs()
        with self._cache._lock:
            # snapshot() is handed the raw cache connection, so hold the cache lock.
            base = self._stats.snapshot(self._cache.connection, states)
        base = replace(
            base,
            errors=self._cache.error_total,
            db_size_bytes=self._cache.db_size_bytes(),
        )
        return replace(
            base,
            tables={n: self._enrich(n, t, errors, last_runs) for n, t in base.tables.items()},
        )

    def _enrich(
        self,
        name: str,
        table_stats: TableStats,
        errors: dict[str, TableError],
        last_runs: dict[str, str],
    ) -> TableStats:
        """Annotate a TableStats with refresh tracking, TTL staleness, errors and run time.

        A READY TTL table whose cache age exceeds its TTL is reported as STALE.
        """
        if name in self._delta:
            tracking = "delta"
        elif name in self._ttl:
            tracking = "ttl"
        else:
            tracking = "static"
        state = table_stats.state
        if state == TableState.READY and name in self._ttl:
            age = self._cache.seconds_since_refresh(name)
            if age is not None and age > self._ttl[name]:
                state = TableState.STALE
        last_refresh = last_runs.get(name)
        err = errors.get(name)
        if err is not None:
            return replace(
                table_stats,
                tracking=tracking,
                state=state,
                last_refresh=last_refresh,
                last_error=err.message,
                last_error_at=err.at,
                consecutive_failures=err.consecutive,
            )
        return replace(table_stats, tracking=tracking, state=state, last_refresh=last_refresh)

    def _make_executor(self, source: Any) -> QueryExecutor:
        """Build a QueryExecutor over this engine's cache, reading from ``source``."""
        return QueryExecutor(
            self._cache,
            self._registry,
            source,
            self._stats,
            self._delta,
            self._ttl,
            self._index_columns,
        )

    def _check_declared(self, parsed: ParsedQuery) -> None:
        """In declarative mode, reject any table/column not declared up front.

        Raises:
            UndeclaredError: for an undeclared table, a ``SELECT *`` on a
                column-restricted table, or an undeclared column.
        """
        if self._declared is None:
            return
        for table in parsed.tables:
            if table not in self._declared:
                raise UndeclaredError(
                    f"Table {table!r} is not declared in tables=[TableSpec(...)]. "
                    "Add a TableSpec for it (declarative mode is a strict allowlist)."
                )
            allowed = self._declared[table]
            if allowed is None:
                continue  # whole table declared — any column is fine
            if table in parsed.wildcard_tables:
                raise UndeclaredError(
                    f"SELECT * on {table!r} is not allowed: only columns {allowed} "
                    "are declared. List the columns explicitly or declare "
                    "columns=None for the whole table."
                )
            allowed_set = set(allowed)
            unknown = [c for c in parsed.columns_by_table.get(table, []) if c not in allowed_set]
            if unknown:
                raise UndeclaredError(
                    f"Column(s) {unknown} of {table!r} are not declared "
                    f"(declared: {allowed})."
                )

    def execute(self, sql: str, params: Params = None) -> list[dict]:
        """Run ``sql`` (with optional ``params``) through the cache and return its rows.

        Raises:
            UndeclaredError: in declarative mode, if the query touches an
                undeclared table or column.
        """
        parsed = parse(sql, params, dialect=self._dialect)
        self._check_declared(parsed)
        # The source connection is opened lazily — a pure cache hit never touches
        # the source and never occupies a pool slot.
        source = _LazySource(self._source_engine)
        try:
            return self._make_executor(source).execute(parsed)
        finally:
            source.close()

    def _preload(self) -> None:
        """Load declared ``preload=True`` tables into the cache (skipping fresh copies).

        A failure on one table is logged and does not stop the others.
        """
        if not self._preload_specs:
            return
        source = _LazySource(self._source_engine)
        try:
            executor = self._make_executor(source)
            for spec in self._preload_specs:
                try:
                    logger.info(f"Preloading {spec.name!r}…")
                    executor.ensure_loaded(spec.name, spec.columns)
                except Exception as e:
                    logger.error(f"Preload failed for {spec.name!r}: {e}")
        finally:
            source.close()

    def refresh(self) -> None:
        """Run one refresh cycle now: pull deltas and reload expired TTL tables.

        The same cycle also runs on a timer in the background thread; cycles
        are serialized, so a call made during a timer cycle waits for it.
        """
        self._run_refresh()

    def _has_refresh_work(self) -> bool:
        """True when at least one table is delta- or TTL-tracked."""
        return bool(self._delta or self._ttl)

    def _run_refresh(self) -> None:
        """One refresh cycle (delta pull + TTL reload); failures are logged, not raised."""
        if not self._has_refresh_work():
            # Nothing to refresh — don't check out a source connection for no work.
            return
        with self._refresh_lock:
            try:
                with self._source_engine.connect() as sa_conn:
                    raw_conn = sa_conn.connection.dbapi_connection
                    self._refresher.refresh(raw_conn)
                    self._refresh_ttl(raw_conn)
            except Exception as e:
                logger.error(f"Refresh cycle failed: {e}")

    def _refresh_ttl(self, source_conn: Any) -> None:
        """Proactively full-reload TTL-tracked tables whose cache has expired.

        Only tables already cached are considered; each is reloaded with the
        columns (and full/partial mode) it currently has in the cache.
        """
        for table, ttl in self._ttl.items():
            if not self._cache.is_table_cached(table):
                continue
            age = self._cache.seconds_since_refresh(table)
            if age is None or age <= ttl:
                continue
            try:
                columns = self._cache.get_table_columns(table)
                full = self._cache.is_table_full(table)
                self._cache.load_table(table, columns, source_conn, full=full)
                logger.info(f"TTL refresh {table!r}: reloaded (age {age:.0f}s > {ttl}s)")
            except Exception as e:
                logger.error(f"TTL refresh failed for {table!r}: {e}")

    def _start_refresh_thread(self, initial_catch_up: bool = True) -> None:
        """Start the daemon thread running startup work and periodic refreshes.

        The thread stops once :meth:`close` sets the stop event. With no delta/TTL
        tables it exits right after the (optional) startup catch-up, since there
        is nothing to refresh periodically.
        """
        stop = self._stop_event

        def loop() -> None:
            if initial_catch_up:
                if stop.is_set():
                    return
                self._preload()  # off-main-thread declared-table preload
                if stop.is_set():
                    return
                self._run_refresh()  # off-main-thread startup catch-up
            if not self._has_refresh_work():
                return
            # Event.wait returns True once close() sets the event, ending the loop.
            while not stop.wait(self._refresh_interval):
                self._run_refresh()

        self._refresh_thread = threading.Thread(target=loop, daemon=True, name="sqlmem-delta")
        self._refresh_thread.start()
        logger.debug(f"Delta refresh thread started (interval={self._refresh_interval}s)")

    def invalidate(self, table: str) -> None:
        """Drop ``table`` from the cache along with its cache bookkeeping rows.

        Removes the cached table, its rows in ``_sqlmem_tables`` and
        ``_sqlmem_columns``, and its in-memory state.
        """
        logger.info(f"Manually invalidating cache for table {table!r}")
        with self._cache._lock:
            self._cache.connection.execute(f"DROP TABLE IF EXISTS {quote(table)}")
            self._cache.connection.execute(
                "DELETE FROM _sqlmem_tables WHERE table_name = ?", (table,)
            )
            self._cache.connection.execute(
                "DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
            )
            self._cache.connection.commit()
        self._cache.clear_state(table)

    def reset(self) -> None:
        """Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
        self._cache.reset()
        logger.info("Cache reset — all tables will be reloaded on next use.")

    def hard_reset(self) -> None:
        """Delete the on-disk cache file and reopen with current pragmas/page_size.

        Disk mode only (falls back to :meth:`reset` in memory mode). Use when a
        layout pragma — ``page_size`` or ``auto_vacuum`` — must change, since those
        are baked into the file at creation and :meth:`reset` keeps it. All tables
        reload on next use.
        """
        self._cache.hard_reset()
        # hard_reset swaps the cache connection — re-point the registry at it.
        self._registry.rebind(self._cache.connection)
        logger.info("Cache hard reset — file recreated; all tables reload on next use.")

    def vacuum(self, incremental: bool = True, pages: int = 10_000) -> None:
        """Run maintenance VACUUM on the on-disk cache (incremental by default).

        Incremental reclaims free pages left by delta ``INSERT OR REPLACE`` churn
        cheaply (requires ``auto_vacuum=INCREMENTAL``); a full VACUUM rewrites the
        whole file and should run only in a maintenance window.
        """
        self._cache.vacuum(incremental=incremental, pages=pages)

    def close(self) -> None:
        """Stop background refreshing and close the cache.

        Sets the stop event so the refresh thread starts no new cycle after the
        cache is closed. A cycle already in flight is not awaited; it may log an
        error if it touches the closed cache.
        """
        self._stop_event.set()
        self._cache.close()
        logger.info("CachingEngine closed.")