"""SQLite-backed table cache for sqlmem.

:class:`CacheManager` copies tables from a source database into SQLite, either
into an in-memory database that is periodically backed up to a file, or
directly into that file (disk mode, WAL journal).
"""

import atexit
import contextlib
import signal
import sqlite3
import threading
from dataclasses import dataclass, replace
from datetime import datetime, timezone
from pathlib import Path

from loguru import logger

import sqlmem._meta as _meta

from ._coerce import coerce_params, coerce_row
from ._sql import quote, quote_list, quote_source
from .config import FETCH_BATCH_SIZE, SQL_DIALECT
from .stats import TableState

# Layout version of the cache's internal tables. A cache file stamped with any
# other value is discarded at startup instead of being reused.
SCHEMA_VERSION = 3

# Seconds the SIGTERM handler waits for the connection lock before skipping its
# own flush. Handlers run on the main thread, so if the main thread itself was
# interrupted while holding the lock, waiting longer could never succeed.
_SIGNAL_FLUSH_TIMEOUT = 10.0


@dataclass(frozen=True)
class _Index:
    """A secondary index registered with :meth:`CacheManager.add_index`."""

    name: str  # SQLite index name: "sqlmem_idx_<table>_<col>_..."
    columns: tuple[str, ...]  # indexed columns, in index order


@dataclass(frozen=True)
class TableError:
    """Most recent load/refresh failure for a table (see ``CacheManager.get_errors``)."""

    message: str  # failure text; load_table records "<ExcType>: <message>"
    at: str  # ISO-8601 UTC timestamp of the failure
    consecutive: int  # current failure streak; reset to 0 by record_success


class CacheManager:
    """Cache of source-database tables held in SQLite.

    In memory mode (the default) the cache lives in a ``:memory:`` database that
    is restored from ``db_path`` at startup and backed up to it every
    ``backup_interval`` seconds, at interpreter exit, and on SIGTERM. In disk
    mode the cache *is* the ``db_path`` file (WAL journal), so every committed
    write is already persistent, and reads use per-thread read connections.

    Inside this class, use of the shared connection is serialized by ``_lock``;
    full table loads are additionally serialized by ``_load_lock``.
    """

    def __init__(
        self,
        db_path: Path,
        backup_interval: int,
        in_memory: bool = True,
        dialect: str = SQL_DIALECT,
        fetch_batch: int = FETCH_BATCH_SIZE,
    ) -> None:
        """Open (or create) the cache.

        Args:
            db_path: Cache file. Memory mode loads it at startup and backs up to
                it; disk mode uses it as the live database.
            backup_interval: Seconds between background backups (memory mode only).
            in_memory: Keep the cache in RAM (True) or query the file directly.
            dialect: Source-DB dialect, used for identifier quoting.
            fetch_batch: Rows fetched per source batch in :meth:`load_table`.
        """
        self._db_path = db_path
        self._backup_interval = backup_interval
        self._in_memory = in_memory
        self._dialect = dialect  # source-DB dialect, for identifier quoting
        self._fetch_batch = fetch_batch  # rows fetched per source batch
        self._lock = threading.Lock()  # serializes connection access
        self._load_lock = threading.Lock()  # serializes full table loads
        self._states: dict[str, str] = {}  # table → live processing state
        self._errors: dict[str, TableError] = {}  # table → last load/refresh failure
        self._error_total = 0  # process-wide failure counter
        self._index_defs: dict[str, list[_Index]] = {}  # table → secondary indexes
        self._read_local = threading.local()  # per-thread read conn (disk mode)
        self._read_conns: list[sqlite3.Connection] = []  # read conns, for cleanup
        self._closed = False
        self._stop_backup = threading.Event()  # set by close() to end the backup thread
        self._prev_sigterm: object = None  # SIGTERM handler we displaced (chained)

        if in_memory:
            self._conn = sqlite3.connect(":memory:", check_same_thread=False)
        else:
            # Disk-backed: query the on-disk file directly — no RAM copy, every
            # write persists immediately, and the cache can exceed available RAM.
            self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute("PRAGMA synchronous=NORMAL")
            self._discard_if_schema_mismatch()
        self._ensure_meta_tables()
        if in_memory:
            self._load_from_disk()
        self._drop_orphan_staging()
        if in_memory:
            self._start_backup_thread()
            atexit.register(self._backup_to_disk)
            self._install_sigterm_handler()
        else:
            atexit.register(self.close)

    @property
    def connection(self) -> sqlite3.Connection:
        """The shared cache (write) connection.

        NOTE(review): callers using it directly bypass ``_lock``.
        """
        return self._conn

    def _ensure_meta_tables(self) -> None:
        """Create the bookkeeping tables and seed the meta rows that are missing."""
        self._conn.executescript("""
            CREATE TABLE IF NOT EXISTS _sqlmem_meta (
                key   TEXT PRIMARY KEY,
                value TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS _sqlmem_tables (
                table_name      TEXT PRIMARY KEY,
                last_refresh_at TEXT NOT NULL,
                row_count       INTEGER,
                is_full         INTEGER NOT NULL DEFAULT 0,
                last_synced_at  TEXT
            );
            CREATE TABLE IF NOT EXISTS _sqlmem_columns (
                table_name  TEXT NOT NULL,
                column_name TEXT NOT NULL,
                PRIMARY KEY (table_name, column_name)
            );
        """)
        # INSERT OR IGNORE: existing values (e.g. the original created_at) win.
        self._conn.executemany(
            "INSERT OR IGNORE INTO _sqlmem_meta (key, value) VALUES (?, ?)",
            [
                ("app_version", _meta.__version__),
                ("schema_version", str(SCHEMA_VERSION)),
                ("created_at", _now()),
            ],
        )
        self._conn.commit()

    def _discard_if_schema_mismatch(self) -> None:
        """Disk mode: wipe an existing cache file written by an incompatible schema.

        In memory mode the equivalent check lives in :meth:`_load_from_disk`;
        here we operate on the live on-disk connection, dropping every table so
        the meta tables are recreated fresh by :meth:`_ensure_meta_tables`. A
        missing or non-numeric stored version counts as a mismatch.
        """
        meta_exists = self._conn.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = '_sqlmem_meta'"
        ).fetchone()
        if not meta_exists:
            return  # fresh file — nothing to validate
        row = self._conn.execute(
            "SELECT value FROM _sqlmem_meta WHERE key = 'schema_version'"
        ).fetchone()
        if row is not None and _schema_matches(row[0]):
            return
        logger.warning(
            "Cache schema version mismatch — wiping on-disk cache, starting fresh."
        )
        names = [
            r[0]
            for r in self._conn.execute(
                r"SELECT name FROM sqlite_master WHERE type = 'table' "
                r"AND name NOT LIKE 'sqlite\_%' ESCAPE '\'"
            ).fetchall()
        ]
        for name in names:
            self._conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
        self._conn.commit()

    def _load_from_disk(self) -> None:
        """Memory mode: restore the in-memory cache from ``db_path``, if compatible.

        A missing file, a schema mismatch, or any read error leaves the
        in-memory cache empty ("starting fresh"); failures are logged, not raised.
        """
        if not self._db_path.exists():
            logger.info(f"No cache file found at {self._db_path}, starting fresh.")
            return
        logger.info(f"Loading cache from {self._db_path}")
        disk_conn = sqlite3.connect(self._db_path)
        try:
            schema_version = disk_conn.execute(
                "SELECT value FROM _sqlmem_meta WHERE key = 'schema_version'"
            ).fetchone()
            if schema_version is None or not _schema_matches(schema_version[0]):
                logger.warning("Cache schema version mismatch — discarding cache file, starting fresh.")
                return
            disk_conn.backup(self._conn)
            logger.info("Cache loaded from disk successfully.")
        except Exception as e:
            logger.error(f"Failed to load cache from disk: {e} — starting fresh.")
        finally:
            disk_conn.close()

    def _drop_orphan_staging(self) -> None:
        """Drop staging tables left by a load that was interrupted (e.g. crash mid-load)."""
        orphans = [
            r[0]
            for r in self._conn.execute(
                r"SELECT name FROM sqlite_master "
                r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
            ).fetchall()
        ]
        for name in orphans:
            logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
            self._conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
        if orphans:
            self._conn.commit()

    def _backup_to_disk(self, timeout: float = -1) -> None:
        """Persist the cache to ``db_path``.

        Memory mode copies the whole in-memory database over the file with the
        SQLite online-backup API; disk mode only commits, since every write
        already lands in the file. No-op once the manager is closed.

        Args:
            timeout: Seconds to wait for the connection lock; ``-1`` (default)
                waits indefinitely. On timeout the backup is skipped with a warning.
        """
        if self._closed:
            return
        if not self._lock.acquire(timeout=timeout):
            logger.warning(f"Cache backup skipped: connection lock still busy after {timeout}s.")
            return
        try:
            if self._closed:
                # close() completed while we were waiting for the lock.
                return
            if not self._in_memory:
                # Disk-backed: every write already lands in the file; just commit
                # anything still pending on the write connection.
                self._conn.commit()
                return
            logger.info(f"Backing up cache to {self._db_path}")
            try:
                # closing(): sqlite3.Connection's own context manager only ends the
                # transaction — it never closes the connection, which would leak
                # the file handle whenever backup() raises.
                with contextlib.closing(sqlite3.connect(self._db_path)) as disk_conn:
                    self._conn.backup(disk_conn)
            except Exception as e:
                logger.error(f"Cache backup failed: {e}")
                return
            logger.info("Cache backup complete.")
        finally:
            self._lock.release()

    def _start_backup_thread(self) -> None:
        """Start the daemon thread that backs up every ``backup_interval`` seconds until ``close()``."""

        def loop() -> None:
            while not self._stop_backup.wait(self._backup_interval):
                self._backup_to_disk()

        t = threading.Thread(target=loop, daemon=True, name="sqlmem-backup")
        t.start()
        logger.debug(f"Backup thread started (interval={self._backup_interval}s)")

    def _install_sigterm_handler(self) -> None:
        """Install :meth:`_on_sigterm`, remembering the handler it replaces.

        ``signal.signal`` raises ``ValueError`` outside the main thread; in that
        case the cache is still flushed by the periodic backup and the atexit hook.
        """
        try:
            self._prev_sigterm = signal.signal(signal.SIGTERM, self._on_sigterm)
        except ValueError:
            logger.warning("Not on the main thread — SIGTERM flush handler not installed.")

    def _on_sigterm(self, signum, frame) -> None:
        """Flush the cache on SIGTERM, then let the signal take its usual effect."""
        logger.info("SIGTERM received — flushing cache to disk.")
        # Bounded wait: handlers run on the main thread, so if the main thread
        # was interrupted while holding _lock a blocking acquire would deadlock.
        self._backup_to_disk(timeout=_SIGNAL_FLUSH_TIMEOUT)
        prev = self._prev_sigterm
        if callable(prev):
            prev(signum, frame)  # chain to the handler we displaced
        elif prev != signal.SIG_IGN:
            # Default disposition is to terminate. SystemExit unwinds the main
            # thread (releasing any lock it held) and the atexit hook flushes
            # again, covering a flush skipped above.
            raise SystemExit(128 + signum)

    def _fetchone(self, sql: str, params: tuple = ()) -> tuple | None:
        """Run *sql* on the shared connection under ``_lock`` and return its first row."""
        with self._lock:
            return self._conn.execute(sql, params).fetchone()

    def mark_table_refreshed(self, table: str, row_count: int, full: bool = False) -> None:
        """Upsert *table*'s bookkeeping row with a fresh refresh timestamp."""
        with self._lock:
            self._conn.execute(
                """
                INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count, is_full)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(table_name) DO UPDATE SET
                    last_refresh_at = excluded.last_refresh_at,
                    row_count       = excluded.row_count,
                    is_full         = excluded.is_full
                """,
                (table, _now(), row_count, int(full)),
            )
            self._conn.commit()

    def is_table_cached(self, table: str) -> bool:
        """True if *table* has a bookkeeping row (it was loaded at least once)."""
        row = self._fetchone("SELECT 1 FROM _sqlmem_tables WHERE table_name = ?", (table,))
        return row is not None

    def is_table_full(self, table: str) -> bool:
        """True if the whole table (all columns) is cached — a SELECT * cache hit."""
        row = self._fetchone(
            "SELECT is_full FROM _sqlmem_tables WHERE table_name = ?", (table,)
        )
        return bool(row and row[0])

    def seconds_since_refresh(self, table: str) -> float | None:
        """Age of a cached table in seconds, or None if it is not cached."""
        row = self._fetchone(
            "SELECT last_refresh_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
        )
        if not row or not row[0]:
            return None
        last = datetime.fromisoformat(row[0])
        return (datetime.now(timezone.utc) - last).total_seconds()

    def discover_columns(self, table: str, source_conn: sqlite3.Connection) -> list[str]:
        """Return all column names of *table* from the source DB without fetching rows."""
        logger.debug(f"Discovering columns of {table!r} from source DB")
        cursor = source_conn.execute(
            f"SELECT * FROM {quote_source(table, self._dialect)} WHERE 1 = 0"
        )
        columns = [desc[0] for desc in cursor.description]
        logger.debug(f"{table!r} has columns: {columns}")
        return columns

    def set_state(self, table: str, state: str) -> None:
        self._states[table] = state

    def get_states(self) -> dict[str, str]:
        """Snapshot copy of table → processing state."""
        return dict(self._states)

    def clear_state(self, table: str) -> None:
        """Forget *table*'s processing state and its last recorded error."""
        self._states.pop(table, None)
        self._errors.pop(table, None)

    def record_error(self, table: str, message: str) -> None:
        """Record a load/refresh failure for *table* (increments its failure streak)."""
        prev = self._errors.get(table)
        streak = (prev.consecutive if prev else 0) + 1
        self._errors[table] = TableError(message=message, at=_now(), consecutive=streak)
        self._error_total += 1
        logger.debug(f"Recorded error for {table!r} (streak {streak}): {message}")

    def record_success(self, table: str) -> None:
        """Reset *table*'s failure streak to 0 after a successful load/refresh.

        The last error's message and timestamp are kept for inspection.
        """
        prev = self._errors.get(table)
        if prev and prev.consecutive:
            self._errors[table] = replace(prev, consecutive=0)

    def get_errors(self) -> dict[str, TableError]:
        """Snapshot copy of table → most recent failure."""
        return dict(self._errors)

    @property
    def error_total(self) -> int:
        """Number of failures recorded since this manager was created."""
        return self._error_total

    def add_index(self, table: str, columns: list[str]) -> None:
        """Register a secondary index to (re)create on *columns* after each load."""
        name = "sqlmem_idx_" + "_".join([table, *columns])
        defs = self._index_defs.setdefault(table, [])
        if all(d.name != name for d in defs):
            defs.append(_Index(name=name, columns=tuple(columns)))

    def _create_indexes(self, table: str, available: list[str]) -> None:
        """Create the registered secondary indexes whose columns are all cached."""
        available_set = set(available)
        for idx in self._index_defs.get(table, []):
            if not set(idx.columns) <= available_set:
                logger.warning(
                    f"Skipping index {idx.name!r}: columns {idx.columns} not all cached."
                )
                continue
            cols = quote_list(idx.columns)
            with self._lock:
                self._conn.execute(
                    f"CREATE INDEX IF NOT EXISTS {quote(idx.name)} ON {quote(table)} ({cols})"
                )
                self._conn.commit()
            logger.debug(f"Index {idx.name!r} ready on {table} ({cols})")

    def load_table(
        self,
        table: str,
        columns: list[str],
        source_conn: sqlite3.Connection,
        full: bool = False,
    ) -> None:
        """Stream the source table into the cache in batches.

        Rows are fetched ``fetch_batch`` at a time into a private staging table
        and swapped in within a single transaction, so peak memory stays bounded
        (no ``fetchall`` of a huge table) and readers keep seeing the previous
        copy until the swap commits. Concurrent loads are serialized by
        ``_load_lock``; the connection lock is only held for the brief per-batch
        inserts and the swap.

        Args:
            table: Source table name; also the name of the cache table.
            columns: Columns to copy; each becomes a TEXT column in the cache.
            source_conn: Connection to the source database.
            full: Whether *columns* covers the whole table (see :meth:`is_table_full`).

        Raises:
            Whatever the source fetch or the cache writes raise. Before
            re-raising, the staging table is dropped, the table's state is set
            to ``TableState.ERROR`` and the failure is recorded.
        """
        src_cols = ", ".join(quote_source(c, self._dialect) for c in columns)
        col_defs = ", ".join(f"{quote(c)} TEXT" for c in columns)
        placeholders = ", ".join("?" * len(columns))
        staging = f"{table}__sqlmem_load"
        q_staging = quote(staging)
        q_table = quote(table)
        insert_sql = f"INSERT INTO {q_staging} VALUES ({placeholders})"

        with self._load_lock:
            self.set_state(table, TableState.LOADING)
            logger.info(f"Fetching {table!r} columns {columns} from source DB (batch={self._fetch_batch})")
            try:
                cursor = source_conn.execute(
                    f"SELECT {src_cols} FROM {quote_source(table, self._dialect)}"
                )
                with self._lock:
                    self._conn.execute(f"DROP TABLE IF EXISTS {q_staging}")
                    self._conn.execute(f"CREATE TABLE {q_staging} ({col_defs})")
                    self._conn.commit()

                total = 0
                # fetchmany runs outside _lock so slow network I/O never blocks the cache.
                while batch := cursor.fetchmany(self._fetch_batch):
                    clean = [coerce_row(row) for row in batch]
                    with self._lock:
                        self._conn.executemany(insert_sql, clean)
                        self._conn.commit()
                    total += len(batch)

                self._swap_in(q_staging, q_table)
            except BaseException as exc:
                self._drop_staging_after_failure(q_staging, staging)
                self.set_state(table, TableState.ERROR)
                self.record_error(table, f"{type(exc).__name__}: {exc}")
                raise

            self._create_indexes(table, columns)
            self.mark_table_refreshed(table, total, full)
            self.set_state(table, TableState.READY)
            self.record_success(table)
            logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")

    def _swap_in(self, q_staging: str, q_table: str) -> None:
        """Atomically replace *q_table* with *q_staging* (both already quoted)."""
        with self._lock:
            if self._conn.in_transaction:
                self._conn.commit()  # BEGIN cannot nest inside an open transaction
            # Python's sqlite3 never opens a transaction implicitly before DDL, so
            # without an explicit BEGIN the DROP and the RENAME would each commit
            # on their own and disk-mode readers could briefly find no table.
            self._conn.execute("BEGIN IMMEDIATE")
            try:
                self._conn.execute(f"DROP TABLE IF EXISTS {q_table}")
                self._conn.execute(f"ALTER TABLE {q_staging} RENAME TO {q_table}")
            except BaseException:
                self._conn.rollback()  # readers keep the previous copy
                raise
            self._conn.commit()

    def _drop_staging_after_failure(self, q_staging: str, staging: str) -> None:
        """Best-effort removal of a failed load's staging table.

        Errors are logged, not raised, so they never mask the load failure that
        the caller is about to re-raise.
        """
        try:
            with self._lock:
                self._conn.rollback()  # discard a half-written batch, if any
                self._conn.execute(f"DROP TABLE IF EXISTS {q_staging}")
                self._conn.commit()
        except sqlite3.Error as cleanup_exc:
            logger.error(f"Failed to drop staging table {staging!r}: {cleanup_exc}")

    def _read_conn(self) -> sqlite3.Connection:
        """A per-thread, read-only connection used for cache reads in disk mode.

        Disk mode runs in WAL, which allows many concurrent readers alongside one
        writer. Giving each thread its own read connection (rather than sharing
        the single write connection under ``_lock``) means a slow ``SELECT`` no
        longer blocks writers (loads/upserts) or other readers. In-memory mode
        can't do this — each ``:memory:`` connection is a separate database — so
        it keeps using the single locked connection.
        """
        conn = getattr(self._read_local, "conn", None)
        if conn is None:
            # check_same_thread=False: close() may close it from another thread.
            conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
            conn.execute("PRAGMA query_only=ON")  # read-only guard
            self._read_local.conn = conn
            with self._lock:
                self._read_conns.append(conn)
        return conn

    def execute_in_memory(
        self, sql: str, params: tuple | list | dict | None = None
    ) -> tuple[list[str], list[tuple]]:
        """Run a read query against the cache; return ``(column_names, rows)``.

        In-memory mode serializes with writers on the single connection. Disk
        mode reads from a per-thread WAL connection, so reads run concurrently
        with writers and each other (see :meth:`_read_conn`).
        """
        bound = coerce_params(params)
        if self._in_memory:
            with self._lock:
                cursor = (
                    self._conn.execute(sql) if bound is None else self._conn.execute(sql, bound)
                )
                col_names = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()
            return col_names, rows
        conn = self._read_conn()
        cursor = conn.execute(sql) if bound is None else conn.execute(sql, bound)
        col_names = [desc[0] for desc in cursor.description]
        rows = cursor.fetchall()
        return col_names, rows

    # --- delta refresh support ---------------------------------------------

    def get_table_columns(self, table: str) -> list[str]:
        """Authoritative ordered column list of a cached table (via PRAGMA)."""
        with self._lock:
            rows = self._conn.execute(f"PRAGMA table_info({quote(table)})").fetchall()
        return [r[1] for r in rows]

    def create_unique_index(self, table: str, key_columns: list[str]) -> None:
        """Create the unique index on *key_columns* that makes upsert-by-key work."""
        cols = quote_list(key_columns)
        index = quote(f"idx_{table}_pk")
        with self._lock:
            self._conn.execute(
                f"CREATE UNIQUE INDEX IF NOT EXISTS {index} ON {quote(table)} ({cols})"
            )
            self._conn.commit()

    def get_last_synced_at(self, table: str) -> str | None:
        """Stored delta-sync marker for *table*, or None if unset / not cached."""
        row = self._fetchone(
            "SELECT last_synced_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
        )
        return row[0] if row else None

    def set_last_synced_at(self, table: str, value: str | None) -> None:
        """Store *value* as *table*'s delta-sync marker (no-op if it has no bookkeeping row)."""
        with self._lock:
            self._conn.execute(
                "UPDATE _sqlmem_tables SET last_synced_at = ? WHERE table_name = ?",
                (value, table),
            )
            self._conn.commit()

    def max_value(self, table: str, column: str) -> str | None:
        """Maximum value of *column* across cached rows (the delta watermark)."""
        row = self._fetchone(f"SELECT MAX({quote(column)}) FROM {quote(table)}")
        return row[0] if row else None

    def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
        """Insert-or-replace one batch of *rows* by the table's unique key."""
        col_list = quote_list(columns)
        placeholders = ", ".join("?" * len(columns))
        clean_rows = [coerce_row(row) for row in rows]
        with self._lock:
            self._conn.executemany(
                f"INSERT OR REPLACE INTO {quote(table)} ({col_list}) VALUES ({placeholders})",
                clean_rows,
            )
            self._conn.commit()

    def count_rows(self, table: str) -> int:
        """Number of rows currently cached for *table*."""
        row = self._fetchone(f"SELECT COUNT(*) FROM {quote(table)}")
        return int(row[0]) if row else 0

    def reset(self) -> None:
        """Wipe the entire cache — every cached table plus the on-disk data.

        The file is deleted in memory mode and VACUUMed in place in disk mode.
        Per-table states and last errors are cleared too; ``error_total`` and
        registered secondary indexes are kept.
        """
        logger.info("Resetting cache — dropping all cached tables.")
        with self._lock:
            user_tables = [
                r[0]
                for r in self._conn.execute(
                    "SELECT name FROM sqlite_master "
                    r"WHERE type = 'table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\' "
                    r"AND name NOT LIKE '\_sqlmem\_%' ESCAPE '\'"
                ).fetchall()
            ]
            for name in user_tables:
                self._conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
            self._conn.execute("DELETE FROM _sqlmem_tables")
            self._conn.execute("DELETE FROM _sqlmem_columns")
            self._conn.commit()
        self._states.clear()
        self._errors.clear()  # as in clear_state: errors go with the table state
        if self._in_memory:
            try:
                if self._db_path.exists():
                    self._db_path.unlink()
            except OSError as e:
                logger.error(f"Failed to delete cache file {self._db_path}: {e}")
        else:
            # The open connection *is* the file — dropping the tables persisted
            # the wipe; VACUUM reclaims the freed pages on disk.
            try:
                with self._lock:
                    self._conn.execute("VACUUM")
            except sqlite3.Error as e:
                logger.error(f"Failed to VACUUM cache file {self._db_path}: {e}")

    def close(self) -> None:
        """Flush the cache, stop the backup thread, and close every connection.

        Safe to call more than once: later calls skip the flush and re-close
        already-closed connections, which is a no-op.
        """
        self._backup_to_disk()
        self._stop_backup.set()
        with self._lock:
            # Set under the lock so a backup waiting on it sees the flag and bails.
            self._closed = True
            for conn in self._read_conns:
                with contextlib.suppress(sqlite3.Error):
                    conn.close()
            self._read_conns.clear()
            self._conn.close()


def _now() -> str:
    """Current UTC time as an ISO-8601 string (timezone-aware)."""
    return datetime.now(timezone.utc).isoformat()


def _schema_matches(value: object) -> bool:
    """True if a stored ``schema_version`` meta value equals ``SCHEMA_VERSION``.

    Non-numeric values count as a mismatch instead of raising.
    """
    try:
        return int(value) == SCHEMA_VERSION
    except (TypeError, ValueError):
        return False