Add incremental delta refresh and fix Decimal/datetime cache binding
This commit is contained in:
+10
-1
@@ -3,6 +3,7 @@ from typing import Any
|
||||
from loguru import logger
|
||||
|
||||
from .config import DEBUG
|
||||
from .delta import DeltaConfig
|
||||
from .engine import CachingEngine
|
||||
from .exceptions import ReadOnlyError, UnsupportedQueryError
|
||||
from .stats import Stats, TableStats
|
||||
@@ -35,4 +36,12 @@ def add_sink(sink: Any, *, level: str | None = None, **kwargs: Any) -> None:
|
||||
logger.add(sink, level=level or ("DEBUG" if DEBUG else "INFO"), filter="sqlmem", **kwargs)
|
||||
|
||||
|
||||
__all__ = ["CachingEngine", "ReadOnlyError", "UnsupportedQueryError", "Stats", "TableStats", "add_sink"]
|
||||
__all__ = [
|
||||
"CachingEngine",
|
||||
"DeltaConfig",
|
||||
"ReadOnlyError",
|
||||
"UnsupportedQueryError",
|
||||
"Stats",
|
||||
"TableStats",
|
||||
"add_sink",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Coerce source-DB values into types ``sqlite3`` can bind.
|
||||
|
||||
pyodbc returns ``NUMERIC``/``DECIMAL``/``MONEY`` as :class:`decimal.Decimal` and
|
||||
date/time columns as :mod:`datetime` objects, none of which ``sqlite3`` binds
|
||||
natively. Cache columns are ``TEXT``, so stringifying is lossless and consistent
|
||||
with how the data is stored. This is done **locally** — never via a global
|
||||
``sqlite3.register_adapter`` — so the host application's ``sqlite3`` behaviour is
|
||||
left untouched.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
Params = tuple | list | dict | None
|
||||
|
||||
|
||||
def to_sqlite(value: Any) -> Any:
|
||||
if isinstance(value, decimal.Decimal):
|
||||
return str(value)
|
||||
if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, uuid.UUID):
|
||||
return str(value)
|
||||
if isinstance(value, bytearray):
|
||||
return bytes(value)
|
||||
return value
|
||||
|
||||
|
||||
def coerce_row(row: tuple) -> tuple:
|
||||
return tuple(to_sqlite(v) for v in row)
|
||||
|
||||
|
||||
def coerce_params(params: Params) -> tuple | dict | None:
|
||||
if params is None:
|
||||
return None
|
||||
if isinstance(params, dict):
|
||||
return {key: to_sqlite(val) for key, val in params.items()}
|
||||
return tuple(to_sqlite(val) for val in params)
|
||||
+90
-3
@@ -8,8 +8,9 @@ from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
import sqlmem._meta as _meta
|
||||
from ._coerce import coerce_params, coerce_row
|
||||
|
||||
SCHEMA_VERSION = 2
|
||||
SCHEMA_VERSION = 3
|
||||
|
||||
|
||||
class CacheManager:
|
||||
@@ -41,7 +42,8 @@ class CacheManager:
|
||||
table_name TEXT PRIMARY KEY,
|
||||
last_refresh_at TEXT NOT NULL,
|
||||
row_count INTEGER,
|
||||
is_full INTEGER NOT NULL DEFAULT 0
|
||||
is_full INTEGER NOT NULL DEFAULT 0,
|
||||
last_synced_at TEXT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_columns (
|
||||
table_name TEXT NOT NULL,
|
||||
@@ -159,18 +161,103 @@ class CacheManager:
|
||||
cols = ", ".join(columns)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
|
||||
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
|
||||
clean_rows = [coerce_row(row) for row in rows]
|
||||
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
||||
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
|
||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows)
|
||||
self._mem_conn.commit()
|
||||
|
||||
self.mark_table_refreshed(table, len(rows), full)
|
||||
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
|
||||
|
||||
def execute_in_memory(
|
||||
self, sql: str, params: tuple | list | dict | None = None
|
||||
) -> tuple[list[str], list[tuple]]:
|
||||
"""Run a read query against the in-memory cache, serialized with writers."""
|
||||
bound = coerce_params(params)
|
||||
with self._lock:
|
||||
cursor = self._mem_conn.execute(sql) if bound is None else self._mem_conn.execute(sql, bound)
|
||||
col_names = [desc[0] for desc in cursor.description]
|
||||
rows = cursor.fetchall()
|
||||
return col_names, rows
|
||||
|
||||
# --- delta refresh support ---------------------------------------------
|
||||
|
||||
def get_table_columns(self, table: str) -> list[str]:
|
||||
"""Authoritative ordered column list of a cached table (via PRAGMA)."""
|
||||
rows = self._mem_conn.execute(f"PRAGMA table_info({table})").fetchall()
|
||||
return [r[1] for r in rows]
|
||||
|
||||
def create_unique_index(self, table: str, key_columns: list[str]) -> None:
|
||||
"""Create the unique index on *key_columns* that makes upsert-by-key work."""
|
||||
cols = ", ".join(key_columns)
|
||||
index = f"idx_{table}_pk"
|
||||
with self._lock:
|
||||
self._mem_conn.execute(
|
||||
f"CREATE UNIQUE INDEX IF NOT EXISTS {index} ON {table} ({cols})"
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
|
||||
def get_last_synced_at(self, table: str) -> str | None:
|
||||
row = self._mem_conn.execute(
|
||||
"SELECT last_synced_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
|
||||
).fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
def set_last_synced_at(self, table: str, value: str | None) -> None:
|
||||
with self._lock:
|
||||
self._mem_conn.execute(
|
||||
"UPDATE _sqlmem_tables SET last_synced_at = ? WHERE table_name = ?",
|
||||
(value, table),
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
|
||||
def max_value(self, table: str, column: str) -> str | None:
|
||||
"""Maximum value of *column* across cached rows (the delta watermark)."""
|
||||
row = self._mem_conn.execute(f"SELECT MAX({column}) FROM {table}").fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
|
||||
"""Insert-or-replace *rows* by the table's unique key, then refresh row_count."""
|
||||
col_list = ", ".join(columns)
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
clean_rows = [coerce_row(row) for row in rows]
|
||||
with self._lock:
|
||||
self._mem_conn.executemany(
|
||||
f"INSERT OR REPLACE INTO {table} ({col_list}) VALUES ({placeholders})",
|
||||
clean_rows,
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
|
||||
self.mark_table_refreshed(table, count, self.is_table_full(table))
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Wipe the entire cache — every cached table plus the on-disk file."""
|
||||
logger.info("Resetting cache — dropping all cached tables.")
|
||||
with self._lock:
|
||||
user_tables = [
|
||||
r[0]
|
||||
for r in self._mem_conn.execute(
|
||||
"SELECT name FROM sqlite_master "
|
||||
r"WHERE type = 'table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\' "
|
||||
r"AND name NOT LIKE '\_sqlmem\_%' ESCAPE '\'"
|
||||
).fetchall()
|
||||
]
|
||||
for name in user_tables:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {name}")
|
||||
self._mem_conn.execute("DELETE FROM _sqlmem_tables")
|
||||
self._mem_conn.execute("DELETE FROM _sqlmem_columns")
|
||||
self._mem_conn.commit()
|
||||
try:
|
||||
if self._db_path.exists():
|
||||
self._db_path.unlink()
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to delete cache file {self._db_path}: {e}")
|
||||
|
||||
def close(self) -> None:
|
||||
self._backup_to_disk()
|
||||
self._closed = True
|
||||
|
||||
@@ -9,6 +9,8 @@ load_dotenv()
|
||||
DEBUG = os.getenv("SQLMEM_DEBUG", "false").lower() == "true"
|
||||
CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
|
||||
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
|
||||
# How often (seconds) the background thread pulls deltas for delta-tracked tables.
|
||||
REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300"))
|
||||
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
|
||||
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
|
||||
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import sqlite3
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .cache import CacheManager
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeltaConfig:
|
||||
"""Per-table configuration for incremental (delta) refresh.
|
||||
|
||||
*change_column* is the column the source DB updates on every insert/update
|
||||
(a non-decreasing timestamp / rowversion). *key_columns* uniquely identify a
|
||||
row and are used to upsert changed rows in place; leave them empty to let the
|
||||
engine auto-discover the primary key from the source DB (works for real
|
||||
tables, not views).
|
||||
"""
|
||||
|
||||
change_column: str
|
||||
key_columns: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedDelta:
|
||||
"""A :class:`DeltaConfig` with ``key_columns`` resolved to concrete columns."""
|
||||
|
||||
change_column: str
|
||||
key_columns: list[str]
|
||||
|
||||
|
||||
class DeltaRefresher:
|
||||
"""Pulls only changed rows for delta-tracked tables and upserts them.
|
||||
|
||||
Uses a data-driven high-watermark (``max(change_column)`` actually cached)
|
||||
with a ``>=`` overlap and idempotent upsert by key, so no row is ever missed
|
||||
and boundary rows are harmlessly re-read.
|
||||
"""
|
||||
|
||||
def __init__(self, cache: CacheManager, delta: dict[str, ResolvedDelta]) -> None:
|
||||
self._cache = cache
|
||||
self._delta = delta
|
||||
|
||||
def refresh(self, source_conn: sqlite3.Connection) -> None:
|
||||
for table, cfg in self._delta.items():
|
||||
if not self._cache.is_table_cached(table):
|
||||
continue
|
||||
try:
|
||||
self._refresh_table(table, cfg, source_conn)
|
||||
except Exception as e: # one bad table must not stop the others
|
||||
logger.error(f"Delta refresh failed for {table!r}: {e}")
|
||||
|
||||
def _refresh_table(
|
||||
self, table: str, cfg: ResolvedDelta, source_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
columns = self._cache.get_table_columns(table)
|
||||
watermark = self._cache.get_last_synced_at(table)
|
||||
col_list = ", ".join(columns)
|
||||
|
||||
if watermark is None:
|
||||
rows = source_conn.execute(f"SELECT {col_list} FROM {table}").fetchall()
|
||||
else:
|
||||
rows = source_conn.execute(
|
||||
f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?",
|
||||
(watermark,),
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}")
|
||||
return
|
||||
|
||||
self._cache.upsert_rows(table, columns, rows)
|
||||
new_watermark = self._cache.max_value(table, cfg.change_column)
|
||||
self._cache.set_last_synced_at(table, new_watermark)
|
||||
logger.info(
|
||||
f"Delta refresh {table!r}: {len(rows)} row(s) upserted, "
|
||||
f"watermark {watermark!r} → {new_watermark!r}"
|
||||
)
|
||||
+68
-4
@@ -1,11 +1,14 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import cast
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from .cache import CacheManager
|
||||
from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH
|
||||
from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH, REFRESH_INTERVAL_SECONDS
|
||||
from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
|
||||
from .executor import QueryExecutor
|
||||
from .parser import Params, parse
|
||||
from .registry import ColumnRegistry
|
||||
@@ -15,24 +18,80 @@ from .stats import Stats, StatsCollector
|
||||
class CachingEngine:
|
||||
"""Transparent SQLAlchemy-compatible cache layer."""
|
||||
|
||||
def __init__(self, source_engine: Engine) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
source_engine: Engine,
|
||||
delta: dict[str, DeltaConfig] | None = None,
|
||||
) -> None:
|
||||
self._source_engine = source_engine
|
||||
self._cache = CacheManager(CACHE_DB_PATH, BACKUP_INTERVAL_SECONDS)
|
||||
self._registry = ColumnRegistry(self._cache.connection)
|
||||
self._stats = StatsCollector()
|
||||
self._refresh_interval = REFRESH_INTERVAL_SECONDS
|
||||
self._delta = self._resolve_delta(delta or {})
|
||||
self._refresher = DeltaRefresher(self._cache, self._delta)
|
||||
|
||||
if self._delta:
|
||||
self._run_refresh() # catch up tables restored from disk
|
||||
self._start_refresh_thread()
|
||||
|
||||
logger.info("CachingEngine initialized.")
|
||||
|
||||
def _resolve_delta(self, delta: dict[str, DeltaConfig]) -> dict[str, ResolvedDelta]:
|
||||
"""Resolve each DeltaConfig, auto-discovering the primary key when omitted."""
|
||||
resolved: dict[str, ResolvedDelta] = {}
|
||||
inspector = None
|
||||
for table, cfg in delta.items():
|
||||
keys = list(cfg.key_columns)
|
||||
if not keys:
|
||||
inspector = inspector or inspect(self._source_engine)
|
||||
pk = inspector.get_pk_constraint(table)
|
||||
keys = list(pk.get("constrained_columns") or [])
|
||||
if not keys:
|
||||
raise ValueError(
|
||||
f"No primary key found for {table!r} in the source DB "
|
||||
"(views have none) — set key_columns in its DeltaConfig."
|
||||
)
|
||||
logger.info(f"Delta {table!r}: auto-discovered key columns {keys}")
|
||||
resolved[table] = ResolvedDelta(change_column=cfg.change_column, key_columns=keys)
|
||||
return resolved
|
||||
|
||||
@property
|
||||
def stats(self) -> Stats:
|
||||
return self._stats.snapshot(self._cache.connection)
|
||||
with self._cache._lock:
|
||||
return self._stats.snapshot(self._cache.connection)
|
||||
|
||||
def execute(self, sql: str, params: Params = None) -> list[dict]:
|
||||
parsed = parse(sql, params)
|
||||
with self._source_engine.connect() as sa_conn:
|
||||
raw_conn = cast(sqlite3.Connection, sa_conn.connection.dbapi_connection)
|
||||
executor = QueryExecutor(self._cache, self._registry, raw_conn, self._stats)
|
||||
executor = QueryExecutor(
|
||||
self._cache, self._registry, raw_conn, self._stats, self._delta
|
||||
)
|
||||
return executor.execute(parsed)
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""Pull deltas for all delta-tracked tables now (also runs on a timer)."""
|
||||
self._run_refresh()
|
||||
|
||||
def _run_refresh(self) -> None:
|
||||
try:
|
||||
with self._source_engine.connect() as sa_conn:
|
||||
raw_conn = cast(sqlite3.Connection, sa_conn.connection.dbapi_connection)
|
||||
self._refresher.refresh(raw_conn)
|
||||
except Exception as e:
|
||||
logger.error(f"Delta refresh cycle failed: {e}")
|
||||
|
||||
def _start_refresh_thread(self) -> None:
|
||||
def loop() -> None:
|
||||
event = threading.Event()
|
||||
while not event.wait(self._refresh_interval):
|
||||
self._run_refresh()
|
||||
|
||||
t = threading.Thread(target=loop, daemon=True, name="sqlmem-delta")
|
||||
t.start()
|
||||
logger.debug(f"Delta refresh thread started (interval={self._refresh_interval}s)")
|
||||
|
||||
def invalidate(self, table: str) -> None:
|
||||
logger.info(f"Manually invalidating cache for table {table!r}")
|
||||
with self._cache._lock:
|
||||
@@ -45,6 +104,11 @@ class CachingEngine:
|
||||
)
|
||||
self._cache.connection.commit()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
|
||||
self._cache.reset()
|
||||
logger.info("Cache reset — all tables will be reloaded on next use.")
|
||||
|
||||
def close(self) -> None:
|
||||
self._cache.close()
|
||||
logger.info("CachingEngine closed.")
|
||||
|
||||
+22
-11
@@ -3,6 +3,7 @@ import sqlite3
|
||||
from loguru import logger
|
||||
|
||||
from .cache import CacheManager
|
||||
from .delta import ResolvedDelta
|
||||
from .parser import ParsedQuery
|
||||
from .registry import ColumnRegistry
|
||||
from .stats import StatsCollector
|
||||
@@ -15,11 +16,13 @@ class QueryExecutor:
|
||||
registry: ColumnRegistry,
|
||||
source_conn: sqlite3.Connection,
|
||||
stats: StatsCollector,
|
||||
delta: dict[str, ResolvedDelta] | None = None,
|
||||
) -> None:
|
||||
self._cache = cache
|
||||
self._registry = registry
|
||||
self._source_conn = source_conn
|
||||
self._stats = stats
|
||||
self._delta = delta or {}
|
||||
|
||||
def execute(self, parsed: ParsedQuery) -> list[dict]:
|
||||
for table in parsed.tables:
|
||||
@@ -46,8 +49,7 @@ class QueryExecutor:
|
||||
self._stats.record_miss()
|
||||
|
||||
columns = self._cache.discover_columns(table, self._source_conn)
|
||||
self._cache.load_table(table, columns, self._source_conn, full=True)
|
||||
self._registry.update(table, columns)
|
||||
self._load(table, columns, full=True)
|
||||
|
||||
def _ensure_columns(self, table: str, columns: list[str]) -> None:
|
||||
"""Load *table* with at least *columns*, refetching only when columns are missing."""
|
||||
@@ -69,16 +71,25 @@ class QueryExecutor:
|
||||
self._stats.record_miss()
|
||||
|
||||
all_columns = list(self._registry.get_columns(table)) + missing
|
||||
self._cache.load_table(table, all_columns, self._source_conn)
|
||||
self._registry.update(table, all_columns)
|
||||
self._load(table, all_columns, full=False)
|
||||
|
||||
def _load(self, table: str, columns: list[str], full: bool) -> None:
|
||||
"""Fetch *table* into cache, adding delta key/timestamp columns when tracked."""
|
||||
cfg = self._delta.get(table)
|
||||
if cfg:
|
||||
# The cache must always hold the key (to upsert) and the change column
|
||||
# (to compute the watermark), even if no query referenced them.
|
||||
columns = list(dict.fromkeys([*columns, *cfg.key_columns, cfg.change_column]))
|
||||
|
||||
self._cache.load_table(table, columns, self._source_conn, full=full)
|
||||
self._registry.update(table, columns)
|
||||
|
||||
if cfg:
|
||||
self._cache.create_unique_index(table, cfg.key_columns)
|
||||
watermark = self._cache.max_value(table, cfg.change_column)
|
||||
self._cache.set_last_synced_at(table, watermark)
|
||||
|
||||
def _run_in_memory(self, parsed: ParsedQuery) -> list[dict]:
|
||||
logger.debug(f"Executing in SQLite RAM: {parsed.sqlite_sql!r} params={parsed.params!r}")
|
||||
conn = self._cache.connection
|
||||
if parsed.params is None:
|
||||
cursor = conn.execute(parsed.sqlite_sql)
|
||||
else:
|
||||
cursor = conn.execute(parsed.sqlite_sql, parsed.params)
|
||||
col_names = [desc[0] for desc in cursor.description]
|
||||
rows = cursor.fetchall()
|
||||
col_names, rows = self._cache.execute_in_memory(parsed.sqlite_sql, parsed.params)
|
||||
return [dict(zip(col_names, row)) for row in rows]
|
||||
|
||||
Reference in New Issue
Block a user