Add incremental delta refresh and fix Decimal/datetime cache binding

This commit is contained in:
Jan Doubravský
2026-06-05 11:09:16 +02:00
parent 530c2618cf
commit 33aa126ff6
13 changed files with 798 additions and 53 deletions
+10 -1
View File
@@ -3,6 +3,7 @@ from typing import Any
from loguru import logger
from .config import DEBUG
from .delta import DeltaConfig
from .engine import CachingEngine
from .exceptions import ReadOnlyError, UnsupportedQueryError
from .stats import Stats, TableStats
@@ -35,4 +36,12 @@ def add_sink(sink: Any, *, level: str | None = None, **kwargs: Any) -> None:
logger.add(sink, level=level or ("DEBUG" if DEBUG else "INFO"), filter="sqlmem", **kwargs)
__all__ = ["CachingEngine", "ReadOnlyError", "UnsupportedQueryError", "Stats", "TableStats", "add_sink"]
__all__ = [
"CachingEngine",
"DeltaConfig",
"ReadOnlyError",
"UnsupportedQueryError",
"Stats",
"TableStats",
"add_sink",
]
+40
View File
@@ -0,0 +1,40 @@
"""Coerce source-DB values into types ``sqlite3`` can bind.
pyodbc returns ``NUMERIC``/``DECIMAL``/``MONEY`` as :class:`decimal.Decimal` and
date/time columns as :mod:`datetime` objects, none of which ``sqlite3`` binds
natively. Cache columns are ``TEXT``, so stringifying is lossless and consistent
with how the data is stored. This is done **locally** — never via a global
``sqlite3.register_adapter`` — so the host application's ``sqlite3`` behaviour is
left untouched.
"""
import datetime
import decimal
import uuid
from typing import Any
Params = tuple | list | dict | None
def to_sqlite(value: Any) -> Any:
    """Return *value* in a form ``sqlite3`` can bind as a parameter.

    ``Decimal`` and ``UUID`` values become their ``str()`` form, date/time
    objects become ISO-8601 text, and a ``bytearray`` becomes an immutable
    ``bytes`` copy. Every other value is handed back unchanged.
    """
    if isinstance(value, (decimal.Decimal, uuid.UUID)):
        return str(value)
    if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
        return value.isoformat()
    return bytes(value) if isinstance(value, bytearray) else value
def coerce_row(row: tuple) -> tuple:
    """Return a new tuple with every value of *row* passed through :func:`to_sqlite`."""
    return tuple(map(to_sqlite, row))
def coerce_params(params: Params) -> tuple | dict | None:
    """Coerce bound query parameters so ``sqlite3`` can bind them.

    A mapping keeps its keys and gets coerced values; any other sequence is
    returned as a tuple of coerced values; ``None`` stays ``None``.
    """
    if isinstance(params, dict):
        return {name: to_sqlite(item) for name, item in params.items()}
    if params is not None:
        return tuple(map(to_sqlite, params))
    return None
+90 -3
View File
@@ -8,8 +8,9 @@ from pathlib import Path
from loguru import logger
import sqlmem._meta as _meta
from ._coerce import coerce_params, coerce_row
SCHEMA_VERSION = 2
SCHEMA_VERSION = 3
class CacheManager:
@@ -41,7 +42,8 @@ class CacheManager:
table_name TEXT PRIMARY KEY,
last_refresh_at TEXT NOT NULL,
row_count INTEGER,
is_full INTEGER NOT NULL DEFAULT 0
is_full INTEGER NOT NULL DEFAULT 0,
last_synced_at TEXT
);
CREATE TABLE IF NOT EXISTS _sqlmem_columns (
table_name TEXT NOT NULL,
@@ -159,18 +161,103 @@ class CacheManager:
cols = ", ".join(columns)
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
clean_rows = [coerce_row(row) for row in rows]
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
col_defs = ", ".join(f"{c} TEXT" for c in columns)
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
placeholders = ", ".join("?" * len(columns))
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows)
self._mem_conn.commit()
self.mark_table_refreshed(table, len(rows), full)
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
def execute_in_memory(
    self, sql: str, params: tuple | list | dict | None = None
) -> tuple[list[str], list[tuple]]:
    """Run a read query against the in-memory cache while holding the cache lock.

    *params* are coerced with ``coerce_params`` before binding. Returns the
    result column names and the fetched rows.
    """
    bound = coerce_params(params)
    # Call execute() without a parameter argument when none was supplied.
    call_args = (sql,) if bound is None else (sql, bound)
    with self._lock:
        cursor = self._mem_conn.execute(*call_args)
        names = [column[0] for column in cursor.description]
        fetched = cursor.fetchall()
    return names, fetched
# --- delta refresh support ---------------------------------------------
def get_table_columns(self, table: str) -> list[str]:
"""Authoritative ordered column list of a cached table (via PRAGMA)."""
rows = self._mem_conn.execute(f"PRAGMA table_info({table})").fetchall()
return [r[1] for r in rows]
def create_unique_index(self, table: str, key_columns: list[str]) -> None:
"""Create the unique index on *key_columns* that makes upsert-by-key work."""
cols = ", ".join(key_columns)
index = f"idx_{table}_pk"
with self._lock:
self._mem_conn.execute(
f"CREATE UNIQUE INDEX IF NOT EXISTS {index} ON {table} ({cols})"
)
self._mem_conn.commit()
def get_last_synced_at(self, table: str) -> str | None:
row = self._mem_conn.execute(
"SELECT last_synced_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
).fetchone()
return row[0] if row else None
def set_last_synced_at(self, table: str, value: str | None) -> None:
with self._lock:
self._mem_conn.execute(
"UPDATE _sqlmem_tables SET last_synced_at = ? WHERE table_name = ?",
(value, table),
)
self._mem_conn.commit()
def max_value(self, table: str, column: str) -> str | None:
"""Maximum value of *column* across cached rows (the delta watermark)."""
row = self._mem_conn.execute(f"SELECT MAX({column}) FROM {table}").fetchone()
return row[0] if row else None
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
    """Insert-or-replace *rows* into *table*, then refresh its stored row count.

    Row values are coerced with ``coerce_row`` before binding. Replacement by
    key depends on a unique index already existing on the table. The table's
    current ``is_full`` flag is carried over to ``mark_table_refreshed``.
    """
    prepared = [coerce_row(record) for record in rows]
    names = ", ".join(columns)
    marks = ", ".join("?" * len(columns))
    statement = f"INSERT OR REPLACE INTO {table} ({names}) VALUES ({marks})"
    with self._lock:
        self._mem_conn.executemany(statement, prepared)
        self._mem_conn.commit()
        total = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
    self.mark_table_refreshed(table, total, self.is_table_full(table))
def reset(self) -> None:
    """Wipe the entire cache — every cached table plus the on-disk file.

    Drops every table in ``sqlite_master`` whose name starts with neither
    ``sqlite_`` nor ``_sqlmem_``, empties the ``_sqlmem_tables`` and
    ``_sqlmem_columns`` bookkeeping tables, commits, and finally deletes the
    file at ``self._db_path`` if it exists. A failure to delete the file is
    logged rather than raised.
    """
    logger.info("Resetting cache — dropping all cached tables.")
    with self._lock:
        # ``_`` is a single-character wildcard in LIKE, so it is escaped with
        # ``\`` to match the literal prefixes ``sqlite_`` and ``_sqlmem_``.
        user_tables = [
            r[0]
            for r in self._mem_conn.execute(
                "SELECT name FROM sqlite_master "
                r"WHERE type = 'table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\' "
                r"AND name NOT LIKE '\_sqlmem\_%' ESCAPE '\'"
            ).fetchall()
        ]
        for name in user_tables:
            self._mem_conn.execute(f"DROP TABLE IF EXISTS {name}")
        # The bookkeeping tables are kept but emptied.
        self._mem_conn.execute("DELETE FROM _sqlmem_tables")
        self._mem_conn.execute("DELETE FROM _sqlmem_columns")
        self._mem_conn.commit()
        # Best effort: a file that cannot be removed must not fail the reset.
        try:
            if self._db_path.exists():
                self._db_path.unlink()
        except OSError as e:
            logger.error(f"Failed to delete cache file {self._db_path}: {e}")
def close(self) -> None:
self._backup_to_disk()
self._closed = True
+2
View File
@@ -9,6 +9,8 @@ load_dotenv()
DEBUG = os.getenv("SQLMEM_DEBUG", "false").lower() == "true"
CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
# How often (seconds) the background thread pulls deltas for delta-tracked tables.
REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300"))
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
+78
View File
@@ -0,0 +1,78 @@
import sqlite3
from dataclasses import dataclass, field
from loguru import logger
from .cache import CacheManager
@dataclass(frozen=True)
class DeltaConfig:
    """User-facing settings that enable incremental (delta) refresh for one table.

    ``change_column`` names the column the source DB bumps on every insert or
    update (a non-decreasing timestamp / rowversion). ``key_columns`` uniquely
    identify a row so changed rows can be upserted in place. When left empty,
    the engine looks up the primary key in the source DB instead, which works
    for real tables but not for views.
    """

    change_column: str
    key_columns: list[str] = field(default_factory=list)
@dataclass(frozen=True)
class ResolvedDelta:
    """Delta settings for one table with ``key_columns`` resolved to concrete columns.

    Unlike :class:`DeltaConfig`, ``key_columns`` has no default here and must
    be given explicitly.
    """

    change_column: str
    key_columns: list[str]
class DeltaRefresher:
    """Pulls only changed rows for delta-tracked tables and upserts them.

    Uses a data-driven high-watermark (``max(change_column)`` actually cached)
    with a ``>=`` overlap and idempotent upsert by key, so no row is ever missed
    and boundary rows are harmlessly re-read.
    """

    def __init__(self, cache: CacheManager, delta: dict[str, ResolvedDelta]) -> None:
        """Remember the cache to write into and the per-table delta settings."""
        self._cache = cache
        self._delta = delta

    def refresh(self, source_conn: sqlite3.Connection) -> None:
        """Refresh every delta-tracked table that is currently cached.

        Tables that are not cached yet are skipped. A failure on one table is
        logged and does not prevent the remaining tables from being refreshed.
        """
        for table, cfg in self._delta.items():
            if not self._cache.is_table_cached(table):
                continue
            try:
                self._refresh_table(table, cfg, source_conn)
            except Exception as e:  # one bad table must not stop the others
                logger.error(f"Delta refresh failed for {table!r}: {e}")

    def _refresh_table(
        self, table: str, cfg: ResolvedDelta, source_conn: sqlite3.Connection
    ) -> None:
        """Fetch rows changed since the stored watermark and upsert them.

        With no stored watermark the whole table is fetched. Afterwards the
        watermark is set to the maximum ``change_column`` value now cached.
        """
        columns = self._cache.get_table_columns(table)
        watermark = self._cache.get_last_synced_at(table)
        col_list = ", ".join(columns)

        if watermark is None:
            rows = source_conn.execute(f"SELECT {col_list} FROM {table}").fetchall()
        else:
            # ``>=`` deliberately re-reads rows sitting exactly on the watermark.
            rows = source_conn.execute(
                f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?",
                (watermark,),
            ).fetchall()

        if not rows:
            logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}")
            return

        self._cache.upsert_rows(table, columns, rows)
        new_watermark = self._cache.max_value(table, cfg.change_column)
        self._cache.set_last_synced_at(table, new_watermark)
        # Keep old and new watermark visibly separated in the log line.
        logger.info(
            f"Delta refresh {table!r}: {len(rows)} row(s) upserted, "
            f"watermark {watermark!r} -> {new_watermark!r}"
        )
+68 -4
View File
@@ -1,11 +1,14 @@
import sqlite3
import threading
from typing import cast
from loguru import logger
from sqlalchemy import inspect
from sqlalchemy.engine import Engine
from .cache import CacheManager
from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH
from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH, REFRESH_INTERVAL_SECONDS
from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
from .executor import QueryExecutor
from .parser import Params, parse
from .registry import ColumnRegistry
@@ -15,24 +18,80 @@ from .stats import Stats, StatsCollector
class CachingEngine:
"""Transparent SQLAlchemy-compatible cache layer."""
def __init__(self, source_engine: Engine) -> None:
def __init__(
self,
source_engine: Engine,
delta: dict[str, DeltaConfig] | None = None,
) -> None:
self._source_engine = source_engine
self._cache = CacheManager(CACHE_DB_PATH, BACKUP_INTERVAL_SECONDS)
self._registry = ColumnRegistry(self._cache.connection)
self._stats = StatsCollector()
self._refresh_interval = REFRESH_INTERVAL_SECONDS
self._delta = self._resolve_delta(delta or {})
self._refresher = DeltaRefresher(self._cache, self._delta)
if self._delta:
self._run_refresh() # catch up tables restored from disk
self._start_refresh_thread()
logger.info("CachingEngine initialized.")
def _resolve_delta(self, delta: dict[str, DeltaConfig]) -> dict[str, ResolvedDelta]:
    """Turn each ``DeltaConfig`` into a ``ResolvedDelta`` with concrete key columns.

    Explicit ``key_columns`` are used as given. When they are empty, the
    primary key is read from the source DB through a SQLAlchemy inspector,
    which is created only when first needed.

    Raises:
        ValueError: if no key columns were given and the source DB reports no
            primary key for the table.
    """
    inspector = None
    resolved: dict[str, ResolvedDelta] = {}
    for table, cfg in delta.items():
        keys = list(cfg.key_columns)
        if not keys:
            if inspector is None:
                inspector = inspect(self._source_engine)
            constraint = inspector.get_pk_constraint(table)
            keys = list(constraint.get("constrained_columns") or [])
            if not keys:
                raise ValueError(
                    f"No primary key found for {table!r} in the source DB "
                    "(views have none) — set key_columns in its DeltaConfig."
                )
            logger.info(f"Delta {table!r}: auto-discovered key columns {keys}")
        resolved[table] = ResolvedDelta(change_column=cfg.change_column, key_columns=keys)
    return resolved
@property
def stats(self) -> Stats:
return self._stats.snapshot(self._cache.connection)
with self._cache._lock:
return self._stats.snapshot(self._cache.connection)
def execute(self, sql: str, params: Params = None) -> list[dict]:
parsed = parse(sql, params)
with self._source_engine.connect() as sa_conn:
raw_conn = cast(sqlite3.Connection, sa_conn.connection.dbapi_connection)
executor = QueryExecutor(self._cache, self._registry, raw_conn, self._stats)
executor = QueryExecutor(
self._cache, self._registry, raw_conn, self._stats, self._delta
)
return executor.execute(parsed)
def refresh(self) -> None:
    """Trigger one delta-refresh cycle immediately.

    The same cycle also runs periodically on the background timer thread.
    """
    self._run_refresh()
def _run_refresh(self) -> None:
    """Open a source connection and run one delta-refresh cycle on it.

    Any exception is logged and swallowed, so neither the caller nor the
    timer thread is interrupted by a failed cycle.
    """
    try:
        with self._source_engine.connect() as sa_conn:
            dbapi_conn = sa_conn.connection.dbapi_connection
            self._refresher.refresh(cast(sqlite3.Connection, dbapi_conn))
    except Exception as exc:
        logger.error(f"Delta refresh cycle failed: {exc}")
def _start_refresh_thread(self) -> None:
    """Start the daemon thread that runs a delta refresh every interval.

    The thread sleeps ``self._refresh_interval`` seconds between cycles and
    exits on the first wake-up after the cache manager has been closed.
    """

    def loop() -> None:
        # The event is never set; wait() serves only as the interval sleep.
        event = threading.Event()
        while not event.wait(self._refresh_interval):
            # The cache manager sets ``_closed`` in close(). Without this
            # check the thread would keep querying the source DB forever.
            if getattr(self._cache, "_closed", False):
                logger.debug("Cache closed — delta refresh thread exiting.")
                return
            self._run_refresh()

    t = threading.Thread(target=loop, daemon=True, name="sqlmem-delta")
    t.start()
    logger.debug(f"Delta refresh thread started (interval={self._refresh_interval}s)")
def invalidate(self, table: str) -> None:
logger.info(f"Manually invalidating cache for table {table!r}")
with self._cache._lock:
@@ -45,6 +104,11 @@ class CachingEngine:
)
self._cache.connection.commit()
def reset(self) -> None:
    """Discard the whole cache, both in RAM and in cache.db.

    Intended for use after structural changes in the source DB.
    """
    self._cache.reset()
    logger.info("Cache reset — all tables will be reloaded on next use.")
def close(self) -> None:
    """Close the underlying cache manager and log that the engine is closed."""
    self._cache.close()
    logger.info("CachingEngine closed.")
+22 -11
View File
@@ -3,6 +3,7 @@ import sqlite3
from loguru import logger
from .cache import CacheManager
from .delta import ResolvedDelta
from .parser import ParsedQuery
from .registry import ColumnRegistry
from .stats import StatsCollector
@@ -15,11 +16,13 @@ class QueryExecutor:
registry: ColumnRegistry,
source_conn: sqlite3.Connection,
stats: StatsCollector,
delta: dict[str, ResolvedDelta] | None = None,
) -> None:
self._cache = cache
self._registry = registry
self._source_conn = source_conn
self._stats = stats
self._delta = delta or {}
def execute(self, parsed: ParsedQuery) -> list[dict]:
for table in parsed.tables:
@@ -46,8 +49,7 @@ class QueryExecutor:
self._stats.record_miss()
columns = self._cache.discover_columns(table, self._source_conn)
self._cache.load_table(table, columns, self._source_conn, full=True)
self._registry.update(table, columns)
self._load(table, columns, full=True)
def _ensure_columns(self, table: str, columns: list[str]) -> None:
"""Load *table* with at least *columns*, refetching only when columns are missing."""
@@ -69,16 +71,25 @@ class QueryExecutor:
self._stats.record_miss()
all_columns = list(self._registry.get_columns(table)) + missing
self._cache.load_table(table, all_columns, self._source_conn)
self._registry.update(table, all_columns)
self._load(table, all_columns, full=False)
def _load(self, table: str, columns: list[str], full: bool) -> None:
    """Load *table* into the cache and record its columns in the registry.

    For a delta-tracked table the key columns and the change column are
    appended to *columns* (duplicates removed, order kept), because the cache
    needs the key to upsert and the change column to compute the watermark
    even when no query referenced them. After loading, the unique key index is
    created and the watermark is set from the freshly cached data.
    """
    cfg = self._delta.get(table)
    if cfg:
        wanted = [*columns, *cfg.key_columns, cfg.change_column]
        columns = list(dict.fromkeys(wanted))

    self._cache.load_table(table, columns, self._source_conn, full=full)
    self._registry.update(table, columns)
    if not cfg:
        return

    self._cache.create_unique_index(table, cfg.key_columns)
    newest = self._cache.max_value(table, cfg.change_column)
    self._cache.set_last_synced_at(table, newest)
def _run_in_memory(self, parsed: ParsedQuery) -> list[dict]:
logger.debug(f"Executing in SQLite RAM: {parsed.sqlite_sql!r} params={parsed.params!r}")
conn = self._cache.connection
if parsed.params is None:
cursor = conn.execute(parsed.sqlite_sql)
else:
cursor = conn.execute(parsed.sqlite_sql, parsed.params)
col_names = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
col_names, rows = self._cache.execute_in_memory(parsed.sqlite_sql, parsed.params)
return [dict(zip(col_names, row)) for row in rows]