Add declarative TableSpec API with preload and fail-fast; fix shared-connection race

This commit is contained in:
Jan Doubravský
2026-06-11 13:39:56 +02:00
parent 46370fe651
commit 4a86b2282f
11 changed files with 500 additions and 37 deletions
+118 -16
View File
@@ -18,12 +18,50 @@ from .config import (
SQL_DIALECT,
)
from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
from .exceptions import UndeclaredError
from .executor import QueryExecutor
from .parser import Params, parse
from .parser import Params, ParsedQuery, parse
from .registry import ColumnRegistry
from .spec import TTL, TableSpec
from .stats import Stats, StatsCollector, TableState, TableStats
def _specs_to_config(
tables: list[TableSpec],
) -> tuple[
dict[str, DeltaConfig],
dict[str, int],
dict[str, list[str | list[str]]],
dict[str, list[str]],
dict[str, list[str] | None],
]:
"""Convert declarative ``TableSpec``s into the engine's internal config dicts.
Returns ``(delta, ttl, indexes, datetime_columns, declared)`` — the first four
mirror the legacy kwargs; ``declared`` maps each table to its allowed columns
(``None`` = whole table / any column) for fail-fast query checking.
"""
delta: dict[str, DeltaConfig] = {}
ttl: dict[str, int] = {}
indexes: dict[str, list[str | list[str]]] = {}
datetime_columns: dict[str, list[str]] = {}
declared: dict[str, list[str] | None] = {}
for spec in tables:
if spec.name in declared:
raise ValueError(f"Duplicate TableSpec for table {spec.name!r}.")
declared[spec.name] = list(spec.columns) if spec.columns is not None else None
if spec.indexes:
indexes[spec.name] = list(spec.indexes)
if spec.datetime_columns:
datetime_columns[spec.name] = list(spec.datetime_columns)
refresh = spec.refresh
if isinstance(refresh, TTL):
ttl[spec.name] = refresh.seconds
elif isinstance(refresh, DeltaConfig):
delta[spec.name] = refresh
return delta, ttl, indexes, datetime_columns, declared
class _LazySource:
"""A source connection opened on first ``execute`` and shared across one query.
@@ -68,9 +106,25 @@ class CachingEngine:
pragmas: dict[str, str | int] | None = None,
datetime_columns: dict[str, list[str]] | None = None,
return_datetime: bool = True,
tables: list[TableSpec] | None = None,
blocking_startup_refresh: bool = False,
) -> None:
self._source_engine = source_engine
# Declarative mode: a list of TableSpecs is converted to the same internal
# config the legacy delta=/ttl=/indexes=/datetime_columns= kwargs produce,
# plus a declared-columns allowlist (for fail-fast) and preload set.
self._declared: dict[str, list[str] | None] | None = None
self._preload_specs: list[TableSpec] = []
if tables is not None:
if any(x is not None for x in (delta, ttl, indexes, datetime_columns)):
raise ValueError(
"Pass either tables=[TableSpec(...)] or the legacy "
"delta=/ttl=/indexes=/datetime_columns= kwargs, not both."
)
delta, ttl, indexes, datetime_columns, self._declared = _specs_to_config(tables)
self._preload_specs = [s for s in tables if s.preload]
use_memory = IN_MEMORY if in_memory is None else in_memory
self._dialect = dialect if dialect is not None else SQL_DIALECT
self._refresh_interval = (
@@ -101,12 +155,14 @@ class CachingEngine:
"reload), not both."
)
if self._delta or self._ttl:
# The startup catch-up (deltas/TTL reloads for tables restored from
# disk) can take a while on a cold start. By default it runs on the
# background thread so it never blocks application startup; callers
# who need the cache fully fresh before serving can opt back in.
if self._delta or self._ttl or self._preload_specs:
# Startup work (preload of declared tables + delta/TTL catch-up for
# tables restored from disk) can take a while on a cold start. By
# default it runs on the background thread so it never blocks
# application startup; callers who need the cache fully warm before
# serving can opt back in.
if blocking_startup_refresh:
self._preload()
self._run_refresh()
self._start_refresh_thread(initial_catch_up=not blocking_startup_refresh)
@@ -199,22 +255,67 @@ class CachingEngine:
)
return replace(table_stats, tracking=tracking, state=state, last_refresh=last_refresh)
def _make_executor(self, source: Any) -> QueryExecutor:
    """Create a ``QueryExecutor`` wired to this engine's state and ``source``.

    ``source`` is passed through unchanged as the executor's third positional
    argument; the remaining arguments are the engine's cache, registry, stats
    collector, delta config, TTL config and index columns.
    """
    executor_args = (
        self._cache,
        self._registry,
        source,
        self._stats,
        self._delta,
        self._ttl,
        self._index_columns,
    )
    return QueryExecutor(*executor_args)
def _check_declared(self, parsed: ParsedQuery) -> None:
"""In declarative mode, reject any table/column not declared up front."""
if self._declared is None:
return
for table in parsed.tables:
if table not in self._declared:
raise UndeclaredError(
f"Table {table!r} is not declared in tables=[TableSpec(...)]. "
"Add a TableSpec for it (declarative mode is a strict allowlist)."
)
allowed = self._declared[table]
if allowed is None:
continue # whole table declared — any column is fine
if table in parsed.wildcard_tables:
raise UndeclaredError(
f"SELECT * on {table!r} is not allowed: only columns {allowed} "
"are declared. List the columns explicitly or declare "
"columns=None for the whole table."
)
unknown = [c for c in parsed.columns_by_table.get(table, []) if c not in allowed]
if unknown:
raise UndeclaredError(
f"Column(s) {unknown} of {table!r} are not declared "
f"(declared: {allowed})."
)
def execute(self, sql: str, params: Params = None) -> list[dict]:
    """Parse ``sql``, enforce the declared allowlist, and run the query.

    Args:
        sql: The query text, parsed with the engine's configured dialect.
        params: Optional query parameters handed to the parser.

    Returns:
        The rows produced by the query executor.

    Raises:
        UndeclaredError: in declarative mode, when the query touches a table
            or column that was not declared (raised before any source
            connection wrapper is created).
    """
    parsed = parse(sql, params, dialect=self._dialect)
    self._check_declared(parsed)
    # The source connection is opened lazily — a pure cache hit never touches
    # the source and never occupies a pool slot.
    source = _LazySource(self._source_engine)
    try:
        # Single construction path: the executor is always built by
        # _make_executor so execute() and _preload() cannot drift apart.
        return self._make_executor(source).execute(parsed)
    finally:
        # Always release the (possibly never-opened) source wrapper.
        source.close()
def _preload(self) -> None:
"""Load declared ``preload=True`` tables into the cache (skipping fresh copies)."""
if not self._preload_specs:
return
source = _LazySource(self._source_engine)
try:
executor = self._make_executor(source)
for spec in self._preload_specs:
try:
logger.info(f"Preloading {spec.name!r}")
executor.ensure_loaded(spec.name, spec.columns)
except Exception as e:
logger.error(f"Preload failed for {spec.name!r}: {e}")
finally:
source.close()
@@ -250,6 +351,7 @@ class CachingEngine:
def _start_refresh_thread(self, initial_catch_up: bool = True) -> None:
def loop() -> None:
if initial_catch_up:
self._preload() # off-main-thread declared-table preload
self._run_refresh() # off-main-thread startup catch-up
event = threading.Event()
while not event.wait(self._refresh_interval):