Files
SQLmem/src/sqlmem/engine.py
T

407 lines
17 KiB
Python

import threading
from dataclasses import replace
from pathlib import Path
from typing import Any
from loguru import logger
from sqlalchemy import inspect
from sqlalchemy.engine import Connection, Engine
from ._sql import quote
from .cache import CacheManager, TableError
from .config import (
BACKUP_INTERVAL_SECONDS,
CACHE_DB_PATH,
FETCH_BATCH_SIZE,
IN_MEMORY,
REFRESH_INTERVAL_SECONDS,
SQL_DIALECT,
)
from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
from .exceptions import UndeclaredError
from .executor import QueryExecutor
from .parser import Params, ParsedQuery, parse
from .registry import ColumnRegistry
from .spec import TTL, TableSpec
from .stats import Stats, StatsCollector, TableState, TableStats
def _specs_to_config(
tables: list[TableSpec],
) -> tuple[
dict[str, DeltaConfig],
dict[str, int],
dict[str, list[str | list[str]]],
dict[str, list[str]],
dict[str, list[str] | None],
]:
"""Convert declarative ``TableSpec``s into the engine's internal config dicts.
Returns ``(delta, ttl, indexes, datetime_columns, declared)`` — the first four
mirror the legacy kwargs; ``declared`` maps each table to its allowed columns
(``None`` = whole table / any column) for fail-fast query checking.
"""
delta: dict[str, DeltaConfig] = {}
ttl: dict[str, int] = {}
indexes: dict[str, list[str | list[str]]] = {}
datetime_columns: dict[str, list[str]] = {}
declared: dict[str, list[str] | None] = {}
for spec in tables:
if spec.name in declared:
raise ValueError(f"Duplicate TableSpec for table {spec.name!r}.")
declared[spec.name] = list(spec.columns) if spec.columns is not None else None
if spec.indexes:
indexes[spec.name] = list(spec.indexes)
if spec.datetime_columns:
datetime_columns[spec.name] = list(spec.datetime_columns)
refresh = spec.refresh
if isinstance(refresh, TTL):
ttl[spec.name] = refresh.seconds
elif isinstance(refresh, DeltaConfig):
delta[spec.name] = refresh
return delta, ttl, indexes, datetime_columns, declared
class _LazySource:
    """A source connection opened on first ``execute`` and shared across one query.

    Most queries are cache hits that never touch the source, so opening it (and
    occupying a connection-pool slot) eagerly is wasteful. This proxy forwards
    ``execute`` to the raw DB-API connection behind a SQLAlchemy connection
    opened on demand, and releases it in ``close``. It can also be used as a
    context manager; ``close`` runs on exit.
    """

    def __init__(self, source_engine: "Engine") -> None:
        self._source_engine = source_engine
        self._sa_conn: Connection | None = None
        self._raw: Any = None

    def _ensure_open(self) -> Any:
        """Open the source connection if needed and return its raw DB-API connection."""
        # Keyed on the SQLAlchemy connection, not the raw one. A raw handle that
        # comes back None must not cause a second open that orphans the first.
        if self._sa_conn is None:
            sa_conn = self._source_engine.connect()
            try:
                raw = sa_conn.connection.dbapi_connection
            except BaseException:
                # Don't leave a half-opened connection holding a pool slot.
                sa_conn.close()
                raise
            self._sa_conn, self._raw = sa_conn, raw
        return self._raw

    def execute(self, *args: Any, **kwargs: Any) -> Any:
        """Forward ``execute`` to the raw connection, opening it on first use."""
        return self._ensure_open().execute(*args, **kwargs)

    def close(self) -> None:
        """Release the source connection if one was opened; safe to call repeatedly."""
        # Clear state first so a failing close() can't leave a stale handle behind.
        sa_conn, self._sa_conn, self._raw = self._sa_conn, None, None
        if sa_conn is not None:
            sa_conn.close()

    def __enter__(self) -> "_LazySource":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()
class CachingEngine:
    """Transparent SQLAlchemy-compatible cache layer.

    Queries passed to :meth:`execute` are served by a ``QueryExecutor`` over the
    local cache; the source DB is opened only when the executor needs it.
    Tables can be kept fresh by delta pulls (``DeltaConfig``) or full TTL
    reloads. When any are configured, a daemon thread runs the refresh cycle
    every ``refresh_interval`` seconds until :meth:`close` stops it.
    """

    # Upper bound on how long close() waits for an in-flight refresh cycle.
    _CLOSE_JOIN_TIMEOUT_SECONDS = 10.0

    def __init__(
        self,
        source_engine: Engine,
        delta: dict[str, DeltaConfig] | None = None,
        ttl: dict[str, int] | None = None,
        indexes: dict[str, list[str | list[str]]] | None = None,
        in_memory: bool | None = None,
        cache_db_path: str | Path | None = None,
        backup_interval: int | None = None,
        refresh_interval: int | None = None,
        fetch_batch: int | None = None,
        dialect: str | None = None,
        pragmas: dict[str, str | int] | None = None,
        datetime_columns: dict[str, list[str]] | None = None,
        return_datetime: bool = True,
        tables: list[TableSpec] | None = None,
        blocking_startup_refresh: bool = False,
    ) -> None:
        """Build the cache and start background refreshing if anything needs it.

        Args:
            source_engine: SQLAlchemy engine for the source database.
            delta: Legacy config -- table -> ``DeltaConfig`` (delta refresh).
            ttl: Legacy config -- table -> seconds before a full reload.
            indexes: Legacy config -- table -> index specs, each a column name
                or a list of column names.
            in_memory, cache_db_path, backup_interval, refresh_interval,
            fetch_batch, dialect: When not ``None``, override the defaults
                imported from ``config``.
            pragmas: Passed through to ``CacheManager``.
            datetime_columns: Legacy config -- table -> datetime columns.
            return_datetime: Passed through to ``CacheManager``.
            tables: Declarative config, mutually exclusive with the four legacy
                table kwargs. Also enables the strict declared-column allowlist
                and ``preload``.
            blocking_startup_refresh: Run preload and the first refresh cycle
                synchronously here instead of on the background thread.

        Raises:
            ValueError: on mixing ``tables=`` with the legacy kwargs, duplicate
                ``TableSpec`` names, a table in both delta and ttl, or a delta
                table whose key columns cannot be discovered.
        """
        self._source_engine = source_engine
        # Created up front so close() is always safe. Setting the event stops
        # the background refresh loop.
        self._stop_event = threading.Event()
        self._refresh_thread: threading.Thread | None = None

        # Declarative mode: a list of TableSpecs is converted to the same internal
        # config the legacy delta=/ttl=/indexes=/datetime_columns= kwargs produce,
        # plus a declared-columns allowlist (for fail-fast) and preload set.
        self._declared: dict[str, list[str] | None] | None = None
        self._preload_specs: list[TableSpec] = []
        if tables is not None:
            if any(x is not None for x in (delta, ttl, indexes, datetime_columns)):
                raise ValueError(
                    "Pass either tables=[TableSpec(...)] or the legacy "
                    "delta=/ttl=/indexes=/datetime_columns= kwargs, not both."
                )
            delta, ttl, indexes, datetime_columns, self._declared = _specs_to_config(tables)
            self._preload_specs = [s for s in tables if s.preload]

        # Validate before opening the cache or inspecting the source DB. A bad
        # config then neither leaks a CacheManager nor costs a source round-trip.
        overlap = set(delta or {}) & set(ttl or {})
        if overlap:
            raise ValueError(
                f"Tables {sorted(overlap)} are in both delta and ttl — a table is "
                "either delta-refreshed (has a change column) or TTL-refreshed (full "
                "reload), not both."
            )

        use_memory = IN_MEMORY if in_memory is None else in_memory
        self._dialect = dialect if dialect is not None else SQL_DIALECT
        self._refresh_interval = (
            refresh_interval if refresh_interval is not None else REFRESH_INTERVAL_SECONDS
        )
        self._cache = CacheManager(
            Path(cache_db_path) if cache_db_path is not None else CACHE_DB_PATH,
            backup_interval if backup_interval is not None else BACKUP_INTERVAL_SECONDS,
            in_memory=use_memory,
            dialect=self._dialect,
            fetch_batch=fetch_batch if fetch_batch is not None else FETCH_BATCH_SIZE,
            pragmas=pragmas,
            datetime_columns=datetime_columns,
            return_datetime=return_datetime,
        )
        try:
            self._registry = ColumnRegistry(self._cache.connection)
            self._stats = StatsCollector()
            self._delta = self._resolve_delta(delta or {})
            self._ttl = dict(ttl or {})
            self._index_columns = self._register_indexes(indexes or {})
            self._refresher = DeltaRefresher(self._cache, self._delta)
        except Exception:
            # For example, _resolve_delta found no primary key. Release the
            # cache opened above instead of leaking it.
            self._cache.close()
            raise

        if self._delta or self._ttl or self._preload_specs:
            # Startup work (preload of declared tables + delta/TTL catch-up for
            # tables restored from disk) can take a while on a cold start. By
            # default it runs on the background thread so it never blocks
            # application startup; callers who need the cache fully warm before
            # serving can opt back in.
            if blocking_startup_refresh:
                self._preload()
                self._run_refresh()
            self._start_refresh_thread(initial_catch_up=not blocking_startup_refresh)
        logger.info("CachingEngine initialized.")

    def _register_indexes(
        self, indexes: dict[str, list[str | list[str]]]
    ) -> dict[str, list[str]]:
        """Register secondary indexes on the cache; return columns to load per table.

        Each spec is a single column name or a list of columns. The returned
        per-table column list is de-duplicated in first-seen order.
        """
        index_columns: dict[str, list[str]] = {}
        for table, specs in indexes.items():
            wanted: dict[str, None] = {}  # insertion-ordered set
            for spec in specs:
                columns = [spec] if isinstance(spec, str) else list(spec)
                self._cache.add_index(table, columns)
                wanted.update(dict.fromkeys(columns))
            index_columns[table] = list(wanted)
        return index_columns

    def _resolve_delta(self, delta: dict[str, DeltaConfig]) -> dict[str, ResolvedDelta]:
        """Resolve each DeltaConfig, auto-discovering the primary key when omitted.

        Raises:
            ValueError: if a table has no ``key_columns`` and the source reports
                no primary-key constraint for it.
        """
        resolved: dict[str, ResolvedDelta] = {}
        inspector = None  # created lazily; only needed when keys are omitted
        for table, cfg in delta.items():
            keys = list(cfg.key_columns)
            if not keys:
                if inspector is None:
                    inspector = inspect(self._source_engine)
                pk = inspector.get_pk_constraint(table)
                keys = list(pk.get("constrained_columns") or [])
                if not keys:
                    raise ValueError(
                        f"No primary key found for {table!r} in the source DB "
                        "(views have none) — set key_columns in its DeltaConfig."
                    )
                logger.info(f"Delta {table!r}: auto-discovered key columns {keys}")
            resolved[table] = ResolvedDelta(change_column=cfg.change_column, key_columns=keys)
        return resolved

    @property
    def stats(self) -> Stats:
        """Snapshot of cache statistics, with each table enriched by :meth:`_enrich`."""
        states = self._cache.get_states()
        errors = self._cache.get_errors()
        last_runs = self._cache.get_last_runs()
        # Hold the cache lock while reading through the shared cache connection.
        with self._cache._lock:
            base = self._stats.snapshot(self._cache.connection, states)
        base = replace(
            base,
            errors=self._cache.error_total,
            db_size_bytes=self._cache.db_size_bytes(),
        )
        return replace(
            base,
            tables={n: self._enrich(n, t, errors, last_runs) for n, t in base.tables.items()},
        )

    def _enrich(
        self,
        name: str,
        table_stats: TableStats,
        errors: dict[str, TableError],
        last_runs: dict[str, str],
    ) -> TableStats:
        """Annotate a TableStats with refresh tracking, TTL staleness, errors and run time.

        A READY TTL table whose cache is older than its TTL is reported as STALE.
        """
        if name in self._delta:
            tracking = "delta"
        elif name in self._ttl:
            tracking = "ttl"
        else:
            tracking = "static"

        state = table_stats.state
        if state == TableState.READY and name in self._ttl:
            age = self._cache.seconds_since_refresh(name)
            if age is not None and age > self._ttl[name]:
                state = TableState.STALE

        error_fields: dict[str, Any] = {}
        err = errors.get(name)
        if err is not None:
            error_fields = {
                "last_error": err.message,
                "last_error_at": err.at,
                "consecutive_failures": err.consecutive,
            }
        return replace(
            table_stats,
            tracking=tracking,
            state=state,
            last_refresh=last_runs.get(name),
            **error_fields,
        )

    def _make_executor(self, source: Any) -> QueryExecutor:
        """Build a QueryExecutor over the cache that uses *source* for misses."""
        return QueryExecutor(
            self._cache,
            self._registry,
            source,
            self._stats,
            self._delta,
            self._ttl,
            self._index_columns,
        )

    def _check_declared(self, parsed: ParsedQuery) -> None:
        """In declarative mode, reject any table/column not declared up front.

        Raises:
            UndeclaredError: for an undeclared table, ``SELECT *`` on a table
                declared with an explicit column list, or an undeclared column.
        """
        if self._declared is None:
            return
        for table in parsed.tables:
            if table not in self._declared:
                raise UndeclaredError(
                    f"Table {table!r} is not declared in tables=[TableSpec(...)]. "
                    "Add a TableSpec for it (declarative mode is a strict allowlist)."
                )
            allowed = self._declared[table]
            if allowed is None:
                continue  # whole table declared — any column is fine
            if table in parsed.wildcard_tables:
                raise UndeclaredError(
                    f"SELECT * on {table!r} is not allowed: only columns {allowed} "
                    "are declared. List the columns explicitly or declare "
                    "columns=None for the whole table."
                )
            unknown = [c for c in parsed.columns_by_table.get(table, []) if c not in allowed]
            if unknown:
                raise UndeclaredError(
                    f"Column(s) {unknown} of {table!r} are not declared "
                    f"(declared: {allowed})."
                )

    def execute(self, sql: str, params: Params = None) -> list[dict]:
        """Parse *sql*, check it against declared tables, and run it through the cache.

        Raises:
            UndeclaredError: in declarative mode, for undeclared tables/columns.
        """
        parsed = parse(sql, params, dialect=self._dialect)
        self._check_declared(parsed)
        # The source connection is opened lazily — a pure cache hit never touches
        # the source and never occupies a pool slot.
        source = _LazySource(self._source_engine)
        try:
            return self._make_executor(source).execute(parsed)
        finally:
            source.close()

    def _preload(self) -> None:
        """Load declared ``preload=True`` tables into the cache (skipping fresh copies).

        Per-table failures are logged and do not stop the remaining tables.
        Stops early once :meth:`close` has been called.
        """
        if not self._preload_specs:
            return
        source = _LazySource(self._source_engine)
        try:
            executor = self._make_executor(source)
            for spec in self._preload_specs:
                if self._stop_event.is_set():
                    break  # engine is closing; don't keep loading
                try:
                    logger.info(f"Preloading {spec.name!r}…")
                    executor.ensure_loaded(spec.name, spec.columns)
                except Exception as e:
                    logger.error(f"Preload failed for {spec.name!r}: {e}")
        finally:
            source.close()

    def refresh(self) -> None:
        """Run one refresh cycle now (also runs on a timer).

        Pulls deltas for delta-tracked tables and reloads expired TTL tables.
        Failures are logged rather than raised.
        """
        self._run_refresh()

    def _run_refresh(self) -> None:
        """One delta + TTL refresh cycle on a single source connection; errors are logged."""
        try:
            with self._source_engine.connect() as sa_conn:
                raw_conn = sa_conn.connection.dbapi_connection
                self._refresher.refresh(raw_conn)
                self._refresh_ttl(raw_conn)
        except Exception as e:
            logger.error(f"Refresh cycle failed: {e}")

    def _refresh_ttl(self, source_conn: Any) -> None:
        """Proactively full-reload TTL-tracked tables whose cache has expired.

        Tables not currently cached are skipped. A reload keeps the currently
        cached column set and the table's ``full`` flag.
        """
        for table, ttl in self._ttl.items():
            if not self._cache.is_table_cached(table):
                continue
            age = self._cache.seconds_since_refresh(table)
            if age is None or age <= ttl:
                continue
            try:
                columns = self._cache.get_table_columns(table)
                full = self._cache.is_table_full(table)
                self._cache.load_table(table, columns, source_conn, full=full)
                logger.info(f"TTL refresh {table!r}: reloaded (age {age:.0f}s > {ttl}s)")
            except Exception as e:
                logger.error(f"TTL refresh failed for {table!r}: {e}")

    def _start_refresh_thread(self, initial_catch_up: bool = True) -> None:
        """Start the daemon thread that runs a refresh cycle every interval.

        With *initial_catch_up* the thread first runs preload and one refresh
        cycle, so that work doesn't block the constructor. The loop exits once
        :meth:`close` sets the stop event.
        """
        stop = self._stop_event

        def loop() -> None:
            if initial_catch_up and not stop.is_set():
                self._preload()  # off-main-thread declared-table preload
                if not stop.is_set():
                    self._run_refresh()  # off-main-thread startup catch-up
            # wait() returns True as soon as close() sets the event.
            while not stop.wait(self._refresh_interval):
                self._run_refresh()

        t = threading.Thread(target=loop, daemon=True, name="sqlmem-delta")
        self._refresh_thread = t
        t.start()
        logger.debug(f"Delta refresh thread started (interval={self._refresh_interval}s)")

    def invalidate(self, table: str) -> None:
        """Drop *table* from the cache along with its rows in the bookkeeping tables.

        Removes the cached table and its ``_sqlmem_tables`` / ``_sqlmem_columns``
        entries under the cache lock, then clears its tracked state.
        """
        logger.info(f"Manually invalidating cache for table {table!r}")
        with self._cache._lock:
            self._cache.connection.execute(f"DROP TABLE IF EXISTS {quote(table)}")
            self._cache.connection.execute(
                "DELETE FROM _sqlmem_tables WHERE table_name = ?", (table,)
            )
            self._cache.connection.execute(
                "DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
            )
            self._cache.connection.commit()
        self._cache.clear_state(table)

    def reset(self) -> None:
        """Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
        self._cache.reset()
        logger.info("Cache reset — all tables will be reloaded on next use.")

    def hard_reset(self) -> None:
        """Delete the on-disk cache file and reopen with current pragmas/page_size.

        Disk mode only (falls back to :meth:`reset` in memory mode). Use when a
        layout pragma — ``page_size`` or ``auto_vacuum`` — must change, since
        those are baked into the file at creation and :meth:`reset` keeps it.
        All tables reload on next use.
        """
        self._cache.hard_reset()
        # hard_reset swaps the cache connection — re-point the registry at it.
        self._registry.rebind(self._cache.connection)
        logger.info("Cache hard reset — file recreated; all tables reload on next use.")

    def vacuum(self, incremental: bool = True, pages: int = 10_000) -> None:
        """Run maintenance VACUUM on the on-disk cache (incremental by default).

        Incremental reclaims free pages left by delta ``INSERT OR REPLACE`` churn
        cheaply (requires ``auto_vacuum=INCREMENTAL``); a full VACUUM rewrites the
        whole file and should run only in a maintenance window.
        """
        self._cache.vacuum(incremental=incremental, pages=pages)

    def close(self) -> None:
        """Stop the background refresh thread, then close the cache.

        Waits up to ``_CLOSE_JOIN_TIMEOUT_SECONDS`` for an in-flight refresh
        cycle to finish, so it does not keep running against a closed cache.
        """
        self._stop_event.set()
        thread = self._refresh_thread
        # Joining from the refresh thread itself would deadlock.
        if thread is not None and thread is not threading.current_thread():
            thread.join(self._CLOSE_JOIN_TIMEOUT_SECONDS)
            if thread.is_alive():
                logger.warning(
                    "Refresh thread still running after close(); closing cache anyway."
                )
        self._cache.close()
        logger.info("CachingEngine closed.")