# File: SQLmem/src/sqlmem/cache.py (537 lines, 22 KiB, Python)

import atexit
import signal
import sqlite3
import threading
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from loguru import logger
import sqlmem._meta as _meta
from ._coerce import coerce_params, coerce_row
from ._sql import quote, quote_list, quote_source
from .config import FETCH_BATCH_SIZE, SQL_DIALECT
from .stats import TableState
# Version of the cache layout, stored under the ``schema_version`` key of
# ``_sqlmem_meta``. An existing cache file carrying a different value is wiped
# (disk mode) or not loaded (memory mode) when a CacheManager starts up.
SCHEMA_VERSION = 3
@dataclass(frozen=True)
class _Index:
    """A secondary index registered for a cached table (see ``CacheManager.add_index``)."""

    name: str  # index name used in the CREATE INDEX statement
    columns: tuple[str, ...]  # indexed column names, in index order
@dataclass(frozen=True)
class TableError:
    """Most recent load/refresh failure for a table (see ``CacheManager.get_errors``)."""

    message: str  # failure description recorded via ``CacheManager.record_error``
    at: str  # ISO-8601 UTC timestamp of when the failure was recorded
    consecutive: int  # current failure streak; reset to 0 by ``record_success``
class CacheManager:
    """SQLite-backed cache of source-database tables.

    In memory mode the working database is ``:memory:``; it is seeded from
    ``db_path`` at start-up and copied back there periodically and at exit.
    In disk mode the file at ``db_path`` is opened directly in WAL mode.
    """
def __init__(
    self,
    db_path: Path,
    backup_interval: int,
    in_memory: bool = True,
    dialect: str = SQL_DIALECT,
    fetch_batch: int = FETCH_BATCH_SIZE,
) -> None:
    """Open (or create) the cache and prepare it for use.

    Args:
        db_path: Location of the on-disk cache file.
        backup_interval: Seconds between periodic backups (memory mode only).
        in_memory: When true, work on a ``:memory:`` database seeded from
            *db_path*; otherwise operate on the file directly in WAL mode.
        dialect: Source-database dialect used for identifier quoting.
        fetch_batch: Number of rows fetched from the source per batch.
    """
    self._db_path = db_path
    self._backup_interval = backup_interval
    self._in_memory = in_memory
    self._dialect = dialect  # source-DB dialect, for identifier quoting
    self._fetch_batch = fetch_batch  # rows fetched per source batch
    self._lock = threading.Lock()  # serializes connection access
    self._load_lock = threading.Lock()  # serializes full table loads
    self._states: dict[str, str] = {}  # table → live processing state
    self._errors: dict[str, TableError] = {}  # table → last load/refresh failure
    self._error_total = 0  # process-wide failure counter
    self._index_defs: dict[str, list[_Index]] = {}  # table → secondary indexes
    self._read_local = threading.local()  # per-thread read conn (disk mode)
    self._read_conns: list[sqlite3.Connection] = []  # read conns, for cleanup
    self._closed = False
    if in_memory:
        self._conn = sqlite3.connect(":memory:", check_same_thread=False)
    else:
        # Disk-backed: query the on-disk file directly — no RAM copy, every
        # write persists immediately, and the cache can exceed available RAM.
        self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA synchronous=NORMAL")
        self._discard_if_schema_mismatch()
    self._ensure_meta_tables()
    if in_memory:
        self._load_from_disk()
    self._drop_orphan_staging()
    if in_memory:
        self._start_backup_thread()
        atexit.register(self._backup_to_disk)
        # signal.signal() may only be called from the main thread of the main
        # interpreter and raises ValueError anywhere else. A manager built on
        # a worker thread must still come up; it keeps the periodic and
        # atexit backups and only loses the SIGTERM flush.
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, self._on_sigterm)
        else:
            logger.warning(
                "CacheManager created off the main thread — SIGTERM flush handler not installed."
            )
    else:
        atexit.register(self.close)
@property
def connection(self) -> sqlite3.Connection:
    """The primary SQLite connection (``:memory:`` in memory mode, the cache file in disk mode)."""
    return self._conn
def _ensure_meta_tables(self) -> None:
    """Create the bookkeeping tables if missing and seed the metadata rows.

    ``_sqlmem_meta`` receives ``app_version``, ``schema_version`` and
    ``created_at``; ``INSERT OR IGNORE`` leaves values already present in an
    existing cache untouched.
    """
    schema_ddl = """
        CREATE TABLE IF NOT EXISTS _sqlmem_meta (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS _sqlmem_tables (
            table_name TEXT PRIMARY KEY,
            last_refresh_at TEXT NOT NULL,
            row_count INTEGER,
            is_full INTEGER NOT NULL DEFAULT 0,
            last_synced_at TEXT
        );
        CREATE TABLE IF NOT EXISTS _sqlmem_columns (
            table_name TEXT NOT NULL,
            column_name TEXT NOT NULL,
            PRIMARY KEY (table_name, column_name)
        );
    """
    self._conn.executescript(schema_ddl)
    seed_rows = (
        ("app_version", _meta.__version__),
        ("schema_version", str(SCHEMA_VERSION)),
        ("created_at", _now()),
    )
    for entry in seed_rows:
        self._conn.execute(
            "INSERT OR IGNORE INTO _sqlmem_meta (key, value) VALUES (?, ?)",
            entry,
        )
    self._conn.commit()
def _discard_if_schema_mismatch(self) -> None:
    """Disk mode: wipe an existing cache file written by an incompatible schema.

    In memory mode the equivalent check lives in :meth:`_load_from_disk`; here
    we operate on the live on-disk connection, dropping every table so the
    meta tables are recreated fresh by :meth:`_ensure_meta_tables`.

    A missing or unparseable ``schema_version`` value counts as a mismatch.
    """
    meta_exists = self._conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = '_sqlmem_meta'"
    ).fetchone()
    if not meta_exists:
        return  # fresh file — nothing to validate
    row = self._conn.execute(
        "SELECT value FROM _sqlmem_meta WHERE key = 'schema_version'"
    ).fetchone()
    try:
        stored_version = int(row[0]) if row is not None else None
    except (TypeError, ValueError):
        # A corrupt version marker must lead to a wipe, not a crash while
        # constructing the manager.
        stored_version = None
    if stored_version == SCHEMA_VERSION:
        return
    logger.warning(
        "Cache schema version mismatch — wiping on-disk cache, starting fresh."
    )
    names = [
        r[0]
        for r in self._conn.execute(
            r"SELECT name FROM sqlite_master WHERE type = 'table' "
            r"AND name NOT LIKE 'sqlite\_%' ESCAPE '\'"
        ).fetchall()
    ]
    for name in names:
        self._conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
    self._conn.commit()
def _load_from_disk(self) -> None:
    """Memory mode: copy the on-disk cache file into the in-memory database.

    Nothing is loaded when the file is absent or its ``schema_version`` does
    not equal ``SCHEMA_VERSION``. Any exception raised while reading is logged
    and otherwise ignored.
    """
    if not self._db_path.exists():
        logger.info(f"No cache file found at {self._db_path}, starting fresh.")
        return
    logger.info(f"Loading cache from {self._db_path}")
    disk_conn = sqlite3.connect(self._db_path)
    try:
        version_row = disk_conn.execute(
            "SELECT value FROM _sqlmem_meta WHERE key = 'schema_version'"
        ).fetchone()
        compatible = version_row is not None and int(version_row[0]) == SCHEMA_VERSION
        if compatible:
            disk_conn.backup(self._conn)
            logger.info("Cache loaded from disk successfully.")
        else:
            logger.warning("Cache schema version mismatch — discarding cache file, starting fresh.")
    except Exception as e:
        logger.error(f"Failed to load cache from disk: {e} — starting fresh.")
    finally:
        # Single close point for every path through the try block.
        disk_conn.close()
def _drop_orphan_staging(self) -> None:
    """Remove leftover ``…__sqlmem_load`` staging tables from an interrupted load."""
    leftovers = self._conn.execute(
        r"SELECT name FROM sqlite_master "
        r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
    ).fetchall()
    if not leftovers:
        return  # nothing dropped, so nothing to commit
    for (name,) in leftovers:
        logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
        self._conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
    self._conn.commit()
def _backup_to_disk(self) -> None:
    """Persist the cache to ``db_path``; a no-op once the manager is closed.

    Disk mode only commits any pending transaction on the live connection.
    Memory mode copies the whole in-memory database into the cache file.
    Backup failures are logged, never raised.
    """
    if self._closed:
        return
    if not self._in_memory:
        # Disk-backed: writes already target the file; just commit anything pending.
        with self._lock:
            if not self._closed:  # close() may have won the race for the lock
                self._conn.commit()
        return
    logger.info(f"Backing up cache to {self._db_path}")
    try:
        with self._lock:
            if self._closed:
                # close() finished while we waited for the lock; the in-memory
                # connection is gone, so there is nothing left to copy.
                return
            disk_conn = sqlite3.connect(self._db_path)
            try:
                self._conn.backup(disk_conn)
            finally:
                # Always release the file handle, including when backup() raises.
                disk_conn.close()
        logger.info("Cache backup complete.")
    except Exception as e:
        logger.error(f"Cache backup failed: {e}")
def _start_backup_thread(self) -> None:
    """Start the daemon thread that backs the cache up every ``backup_interval`` seconds.

    The thread stops on the first tick after the manager has been closed, so
    a closed manager does not keep a thread waking up forever.
    """

    def loop() -> None:
        # The event is never set; wait() is used purely as an interruptible sleep.
        event = threading.Event()
        while not event.wait(self._backup_interval):
            if self._closed:
                return
            self._backup_to_disk()

    t = threading.Thread(target=loop, daemon=True, name="sqlmem-backup")
    t.start()
    logger.debug(f"Backup thread started (interval={self._backup_interval}s)")
def _on_sigterm(self, signum, frame) -> None:
    """SIGTERM handler: flush the cache to disk, then let the process exit.

    Installing a Python handler replaces the default "terminate" action, so
    returning normally would leave the process running after SIGTERM. Raising
    ``SystemExit`` restores termination and still runs ``atexit`` hooks.

    Raises:
        SystemExit: always, with the conventional ``128 + signum`` status.
    """
    logger.info("SIGTERM received — flushing cache to disk.")
    # NOTE(review): handlers run on the main thread; if that thread is already
    # inside a ``with self._lock`` block, this flush waits on the same
    # non-reentrant lock — confirm that cannot happen in practice.
    self._backup_to_disk()
    raise SystemExit(128 + signum)
def mark_table_refreshed(self, table: str, row_count: int, full: bool = False) -> None:
    """Upsert *table*'s row in ``_sqlmem_tables`` with the current time, row count and full flag."""
    upsert_sql = """
        INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count, is_full)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(table_name) DO UPDATE SET
            last_refresh_at = excluded.last_refresh_at,
            row_count = excluded.row_count,
            is_full = excluded.is_full
    """
    with self._lock:
        values = (table, _now(), row_count, int(full))
        self._conn.execute(upsert_sql, values)
        self._conn.commit()
def is_table_cached(self, table: str) -> bool:
row = self._conn.execute(
"SELECT 1 FROM _sqlmem_tables WHERE table_name = ?", (table,)
).fetchone()
return row is not None
def is_table_full(self, table: str) -> bool:
"""True if the whole table (all columns) is cached — a SELECT * cache hit."""
row = self._conn.execute(
"SELECT is_full FROM _sqlmem_tables WHERE table_name = ?", (table,)
).fetchone()
return bool(row and row[0])
def seconds_since_refresh(self, table: str) -> float | None:
"""Age of a cached table in seconds, or None if it is not cached."""
row = self._conn.execute(
"SELECT last_refresh_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
).fetchone()
if not row or not row[0]:
return None
last = datetime.fromisoformat(row[0])
return (datetime.now(timezone.utc) - last).total_seconds()
def discover_columns(self, table: str, source_conn: sqlite3.Connection) -> list[str]:
    """List *table*'s column names in the source DB using a query that matches no rows."""
    logger.debug(f"Discovering columns of {table!r} from source DB")
    # ``WHERE 1 = 0`` yields the cursor description without transferring data.
    probe_sql = f"SELECT * FROM {quote_source(table, self._dialect)} WHERE 1 = 0"
    description = source_conn.execute(probe_sql).description
    columns = [entry[0] for entry in description]
    logger.debug(f"{table!r} has columns: {columns}")
    return columns
def set_state(self, table: str, state: str) -> None:
    """Record *state* as the live processing state of *table* (kept in process memory only)."""
    self._states[table] = state
def get_states(self) -> dict[str, str]:
return dict(self._states)
def clear_state(self, table: str) -> None:
self._states.pop(table, None)
self._errors.pop(table, None)
def record_error(self, table: str, message: str) -> None:
    """Store *message* as *table*'s latest failure, extending its streak and the global counter."""
    previous = self._errors.get(table)
    streak = 1 if previous is None else previous.consecutive + 1
    self._errors[table] = TableError(message=message, at=_now(), consecutive=streak)
    self._error_total += 1
    logger.debug(f"Recorded error for {table!r} (streak {streak}): {message}")
def record_success(self, table: str) -> None:
"""Reset *table*'s failure streak to 0 after a successful load/refresh."""
prev = self._errors.get(table)
if prev and prev.consecutive:
self._errors[table] = TableError(prev.message, prev.at, 0)
def get_errors(self) -> dict[str, TableError]:
    """Return a snapshot copy of the table → last-error mapping."""
    snapshot = self._errors.copy()
    return snapshot
@property
def error_total(self) -> int:
    """Total number of failures recorded via ``record_error`` by this manager."""
    return self._error_total
def add_index(self, table: str, columns: list[str]) -> None:
"""Register a secondary index to (re)create on *columns* after each load."""
name = "sqlmem_idx_" + "_".join([table, *columns])
defs = self._index_defs.setdefault(table, [])
if all(d.name != name for d in defs):
defs.append(_Index(name=name, columns=tuple(columns)))
def _create_indexes(self, table: str, available: list[str]) -> None:
    """Build each registered index of *table* whose columns all appear in *available*.

    Indexes that reference a column missing from *available* are skipped
    with a warning.
    """
    cached = set(available)
    for idx in self._index_defs.get(table, []):
        if any(column not in cached for column in idx.columns):
            logger.warning(
                f"Skipping index {idx.name!r}: columns {idx.columns} not all cached."
            )
            continue
        cols = quote_list(idx.columns)
        with self._lock:
            self._conn.execute(
                f"CREATE INDEX IF NOT EXISTS {quote(idx.name)} ON {quote(table)} ({cols})"
            )
            self._conn.commit()
        logger.debug(f"Index {idx.name!r} ready on {table} ({cols})")
def load_table(
    self,
    table: str,
    columns: list[str],
    source_conn: sqlite3.Connection,
    full: bool = False,
) -> None:
    """Stream the source table into the cache in batches.

    Rows are fetched ``fetch_batch`` at a time into a private staging table
    and swapped in inside one explicit transaction, so peak memory stays
    bounded (no ``fetchall`` of a huge table) and readers keep seeing the
    previous copy until the swap commits. Concurrent loads are serialized by
    ``_load_lock``; the connection lock is only held for the brief per-batch
    inserts and the swap.

    Args:
        table: Name of the table, in both the source DB and the cache.
        columns: Columns to copy; every cached column is declared ``TEXT``.
        source_conn: Connection to the source database.
        full: Stored as the table's ``is_full`` flag.

    Raises:
        BaseException: whatever the load raised, re-raised after the staging
            table is dropped and the failure is recorded (state ``ERROR``).
    """
    src_cols = ", ".join(quote_source(c, self._dialect) for c in columns)
    col_defs = ", ".join(f"{quote(c)} TEXT" for c in columns)
    placeholders = ", ".join("?" * len(columns))
    staging = f"{table}__sqlmem_load"
    q_staging = quote(staging)
    q_table = quote(table)
    with self._load_lock:
        self.set_state(table, TableState.LOADING)
        logger.info(f"Fetching {table!r} columns {columns} from source DB (batch={self._fetch_batch})")
        try:
            cursor = source_conn.execute(
                f"SELECT {src_cols} FROM {quote_source(table, self._dialect)}"
            )
            with self._lock:
                self._conn.execute(f"DROP TABLE IF EXISTS {q_staging}")
                self._conn.execute(f"CREATE TABLE {q_staging} ({col_defs})")
                self._conn.commit()
            total = 0
            insert_sql = f"INSERT INTO {q_staging} VALUES ({placeholders})"
            while True:
                batch = cursor.fetchmany(self._fetch_batch)  # network outside _lock
                if not batch:
                    break
                clean = [coerce_row(row) for row in batch]
                with self._lock:
                    self._conn.executemany(insert_sql, clean)
                    self._conn.commit()
                total += len(batch)
            with self._lock:
                # sqlite3 only opens an implicit transaction before DML, so
                # without an explicit BEGIN the DROP and the RENAME would each
                # autocommit and other connections (or a crash) could observe
                # the table missing in between.
                self._conn.execute("BEGIN")
                self._conn.execute(f"DROP TABLE IF EXISTS {q_table}")
                self._conn.execute(f"ALTER TABLE {q_staging} RENAME TO {q_table}")
                self._conn.commit()
            # Finalization stays inside the try so a failure here is recorded
            # as an error instead of leaving the state stuck at LOADING.
            self._create_indexes(table, columns)
            self.mark_table_refreshed(table, total, full)
        except BaseException as exc:
            with self._lock:
                if self._conn.in_transaction:
                    # Undo a half-finished batch insert or swap first.
                    self._conn.rollback()
                self._conn.execute(f"DROP TABLE IF EXISTS {q_staging}")
                self._conn.commit()
            self.set_state(table, TableState.ERROR)
            self.record_error(table, f"{type(exc).__name__}: {exc}")
            raise
        self.set_state(table, TableState.READY)
        self.record_success(table)
        logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
def _read_conn(self) -> sqlite3.Connection:
"""A per-thread, read-only connection used for cache reads in disk mode.
Disk mode runs in WAL, which allows many concurrent readers alongside one
writer. Giving each thread its own read connection (rather than sharing the
single write connection under ``_lock``) means a slow ``SELECT`` no longer
blocks writers (loads/upserts) or other readers. In-memory mode can't do
this — each ``:memory:`` connection is a separate database — so it keeps
using the single locked connection.
"""
conn = getattr(self._read_local, "conn", None)
if conn is None:
conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
conn.execute("PRAGMA query_only=ON") # read-only guard
self._read_local.conn = conn
with self._lock:
self._read_conns.append(conn)
return conn
def execute_in_memory(
    self, sql: str, params: tuple | list | dict | None = None
) -> tuple[list[str], list[tuple]]:
    """Run a read query against the cache.

    In-memory mode serializes with writers on the single connection. Disk mode
    reads from a per-thread WAL connection, so reads run concurrently with
    writers and each other (see :meth:`_read_conn`).

    Args:
        sql: Statement to execute.
        params: Optional bind parameters, passed through ``coerce_params``.

    Returns:
        ``(column_names, rows)``. ``column_names`` is empty when the statement
        produces no result columns.
    """
    bound = coerce_params(params)
    args = (sql,) if bound is None else (sql, bound)
    if self._in_memory:
        with self._lock:
            cursor = self._conn.execute(*args)
            # description is None for statements that return no columns.
            col_names = [desc[0] for desc in cursor.description or ()]
            rows = cursor.fetchall()
        return col_names, rows
    cursor = self._read_conn().execute(*args)
    col_names = [desc[0] for desc in cursor.description or ()]
    rows = cursor.fetchall()
    return col_names, rows
# --- delta refresh support ---------------------------------------------
def get_table_columns(self, table: str) -> list[str]:
    """Authoritative ordered column list of a cached table (via PRAGMA).

    Runs under ``_lock`` because it inspects a user table on the shared
    connection, and a load replaces such tables with a locked DROP + RENAME;
    an unlocked read could land between the two statements.
    """
    with self._lock:
        rows = self._conn.execute(f"PRAGMA table_info({quote(table)})").fetchall()
    return [r[1] for r in rows]
def create_unique_index(self, table: str, key_columns: list[str]) -> None:
    """Ensure the unique index ``idx_<table>_pk`` exists on *key_columns* (needed for upsert-by-key)."""
    index = quote(f"idx_{table}_pk")
    cols = quote_list(key_columns)
    ddl = f"CREATE UNIQUE INDEX IF NOT EXISTS {index} ON {quote(table)} ({cols})"
    with self._lock:
        self._conn.execute(ddl)
        self._conn.commit()
def get_last_synced_at(self, table: str) -> str | None:
row = self._conn.execute(
"SELECT last_synced_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
).fetchone()
return row[0] if row else None
def set_last_synced_at(self, table: str, value: str | None) -> None:
with self._lock:
self._conn.execute(
"UPDATE _sqlmem_tables SET last_synced_at = ? WHERE table_name = ?",
(value, table),
)
self._conn.commit()
def max_value(self, table: str, column: str) -> str | None:
    """Maximum value of *column* across cached rows (the delta watermark).

    Runs under ``_lock`` because it queries a user table on the shared
    connection, and a load replaces such tables with a locked DROP + RENAME;
    an unlocked read could land between the two statements.
    """
    with self._lock:
        row = self._conn.execute(
            f"SELECT MAX({quote(column)}) FROM {quote(table)}"
        ).fetchone()
    return row[0] if row else None
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
    """Write one batch of *rows* with ``INSERT OR REPLACE``, so a row with an existing unique key is replaced."""
    col_list = quote_list(columns)
    placeholders = ", ".join("?" for _ in columns)
    clean_rows = list(map(coerce_row, rows))
    with self._lock:
        self._conn.executemany(
            f"INSERT OR REPLACE INTO {quote(table)} ({col_list}) VALUES ({placeholders})",
            clean_rows,
        )
        self._conn.commit()
def count_rows(self, table: str) -> int:
    """Number of rows currently cached for *table*.

    Runs under ``_lock`` because it queries a user table on the shared
    connection, and a load replaces such tables with a locked DROP + RENAME;
    an unlocked read could land between the two statements.
    """
    with self._lock:
        row = self._conn.execute(f"SELECT COUNT(*) FROM {quote(table)}").fetchone()
    return int(row[0]) if row else 0
def reset(self) -> None:
    """Wipe the entire cache — every cached table plus the on-disk data
    (the file is deleted in memory mode, VACUUMed in place in disk mode).

    The ``_sqlmem_*`` bookkeeping tables are kept but emptied (except
    ``_sqlmem_meta``). Per-table live states and recorded errors are cleared
    too; the process-wide ``error_total`` counter is left untouched.
    """
    logger.info("Resetting cache — dropping all cached tables.")
    with self._lock:
        user_tables = [
            r[0]
            for r in self._conn.execute(
                "SELECT name FROM sqlite_master "
                r"WHERE type = 'table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\' "
                r"AND name NOT LIKE '\_sqlmem\_%' ESCAPE '\'"
            ).fetchall()
        ]
        for name in user_tables:
            self._conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
        self._conn.execute("DELETE FROM _sqlmem_tables")
        self._conn.execute("DELETE FROM _sqlmem_columns")
        self._conn.commit()
    self._states.clear()
    # Errors belong to tables that no longer exist; drop them together with
    # the states, as clear_state() does for a single table.
    self._errors.clear()
    if self._in_memory:
        try:
            if self._db_path.exists():
                self._db_path.unlink()
        except OSError as e:
            logger.error(f"Failed to delete cache file {self._db_path}: {e}")
    else:
        # The open connection *is* the file — drop tables persisted the wipe;
        # VACUUM reclaims the freed pages on disk.
        try:
            with self._lock:
                self._conn.execute("VACUUM")
        except sqlite3.Error as e:
            logger.error(f"Failed to VACUUM cache file {self._db_path}: {e}")
def close(self) -> None:
self._backup_to_disk()
self._closed = True
with self._lock:
for conn in self._read_conns:
try:
conn.close()
except sqlite3.Error:
pass
self._read_conns.clear()
self._conn.close()
def _now() -> str:
return datetime.now(timezone.utc).isoformat()