537 lines
22 KiB
Python
537 lines
22 KiB
Python
import atexit
|
|
import signal
|
|
import sqlite3
|
|
import threading
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from loguru import logger
|
|
|
|
import sqlmem._meta as _meta
|
|
from ._coerce import coerce_params, coerce_row
|
|
from ._sql import quote, quote_list, quote_source
|
|
from .config import FETCH_BATCH_SIZE, SQL_DIALECT
|
|
from .stats import TableState
|
|
|
|
# Version of the cache database layout. It is stored under the
# ``schema_version`` key of ``_sqlmem_meta``; a cache file whose stored value
# differs from this constant is discarded on startup instead of being reused.
SCHEMA_VERSION = 3
|
|
|
|
|
|
@dataclass(frozen=True)
class _Index:
    """Definition of a secondary index registered for a cached table."""

    # Index name used in the ``CREATE INDEX`` statement.
    name: str
    # Indexed column names, in index order.
    columns: tuple[str, ...]
|
|
|
|
|
|
@dataclass(frozen=True)
class TableError:
    """Most recent load/refresh failure for a table (see ``CacheManager.get_errors``).

    Instances are immutable; ``CacheManager.record_error`` and
    ``CacheManager.record_success`` store a new instance instead of mutating
    the previous one.
    """

    # Failure description; ``load_table`` records it as "<ExceptionType>: <text>".
    message: str
    # UTC ISO-8601 timestamp taken when the failure was recorded.
    at: str
    # Length of the current failure streak; set back to 0 by ``record_success``.
    consecutive: int
|
|
|
|
|
|
class CacheManager:
|
|
def __init__(
    self,
    db_path: Path,
    backup_interval: int,
    in_memory: bool = True,
    dialect: str = SQL_DIALECT,
    fetch_batch: int = FETCH_BATCH_SIZE,
) -> None:
    """Open the cache database and register its persistence hooks.

    Args:
        db_path: Cache file. In memory mode it is the file backups are
            written to and loaded from; in disk mode it is the live database.
        backup_interval: Seconds between periodic backups (memory mode only).
        in_memory: When true the cache lives in a ``:memory:`` database,
            otherwise the on-disk file is queried directly.
        dialect: Source-DB dialect, used for identifier quoting.
        fetch_batch: Number of rows fetched per batch from the source DB.
    """
    self._db_path = db_path
    self._backup_interval = backup_interval
    self._in_memory = in_memory
    self._dialect = dialect  # source-DB dialect, for identifier quoting
    self._fetch_batch = fetch_batch  # rows fetched per source batch
    self._lock = threading.Lock()  # serializes connection access
    self._load_lock = threading.Lock()  # serializes full table loads
    self._states: dict[str, str] = {}  # table → live processing state
    self._errors: dict[str, TableError] = {}  # table → last load/refresh failure
    self._error_total = 0  # process-wide failure counter
    self._index_defs: dict[str, list[_Index]] = {}  # table → secondary indexes
    self._read_local = threading.local()  # per-thread read conn (disk mode)
    self._read_conns: list[sqlite3.Connection] = []  # read conns, for cleanup
    self._closed = False

    if in_memory:
        self._conn = sqlite3.connect(":memory:", check_same_thread=False)
    else:
        # Disk-backed: query the on-disk file directly — no RAM copy, every
        # write persists immediately, and the cache can exceed available RAM.
        self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA synchronous=NORMAL")
        self._discard_if_schema_mismatch()

    self._ensure_meta_tables()
    if in_memory:
        self._load_from_disk()
    self._drop_orphan_staging()

    if in_memory:
        self._start_backup_thread()
        atexit.register(self._backup_to_disk)
        # signal.signal() raises ValueError when called outside the main
        # thread, which would abort construction after the backup thread
        # has already been started. Install the handler only where allowed.
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, self._on_sigterm)
        else:
            logger.warning(
                "CacheManager created outside the main thread — "
                "SIGTERM flush handler not installed."
            )
    else:
        atexit.register(self.close)
|
|
|
|
@property
def connection(self) -> sqlite3.Connection:
    """The primary SQLite connection opened by ``__init__``."""
    primary = self._conn
    return primary
|
|
|
|
def _ensure_meta_tables(self) -> None:
    """Create the bookkeeping tables if missing and seed the metadata rows.

    ``_sqlmem_meta`` holds key/value metadata, ``_sqlmem_tables`` one row per
    cached table and ``_sqlmem_columns`` the (table, column) pairs. The seed
    rows use ``INSERT OR IGNORE``, so values already present are kept.
    """
    conn = self._conn
    conn.executescript("""
        CREATE TABLE IF NOT EXISTS _sqlmem_meta (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS _sqlmem_tables (
            table_name TEXT PRIMARY KEY,
            last_refresh_at TEXT NOT NULL,
            row_count INTEGER,
            is_full INTEGER NOT NULL DEFAULT 0,
            last_synced_at TEXT
        );
        CREATE TABLE IF NOT EXISTS _sqlmem_columns (
            table_name TEXT NOT NULL,
            column_name TEXT NOT NULL,
            PRIMARY KEY (table_name, column_name)
        );
    """)
    seed_sql = "INSERT OR IGNORE INTO _sqlmem_meta (key, value) VALUES (?, ?)"
    seeds = (
        ("app_version", _meta.__version__),
        ("schema_version", str(SCHEMA_VERSION)),
        ("created_at", _now()),
    )
    for entry in seeds:
        conn.execute(seed_sql, entry)
    conn.commit()
|
|
|
|
def _discard_if_schema_mismatch(self) -> None:
    """Disk mode: wipe an existing cache file written by an incompatible schema.

    In memory mode the equivalent check lives in :meth:`_load_from_disk`; here
    we operate on the live on-disk connection, dropping every table so the
    meta tables are recreated fresh by :meth:`_ensure_meta_tables`.

    A missing or unparsable ``schema_version`` value is treated as a
    mismatch rather than raised, so a damaged metadata row cannot prevent
    the cache from starting.
    """
    conn = self._conn
    meta_exists = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = '_sqlmem_meta'"
    ).fetchone()
    if not meta_exists:
        return  # fresh file — nothing to validate

    row = conn.execute(
        "SELECT value FROM _sqlmem_meta WHERE key = 'schema_version'"
    ).fetchone()
    try:
        stored = int(row[0]) if row is not None else None
    except (TypeError, ValueError):
        stored = None  # not an integer — cannot be a compatible schema
    if stored == SCHEMA_VERSION:
        return

    logger.warning(
        "Cache schema version mismatch — wiping on-disk cache, starting fresh."
    )
    # Collect the names first so no SELECT is still being stepped while
    # tables are dropped. SQLite's own ``sqlite_*`` tables are left alone.
    names = [
        r[0]
        for r in conn.execute(
            r"SELECT name FROM sqlite_master WHERE type = 'table' "
            r"AND name NOT LIKE 'sqlite\_%' ESCAPE '\'"
        ).fetchall()
    ]
    for name in names:
        conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
    conn.commit()
|
|
|
|
def _load_from_disk(self) -> None:
    """Memory mode: copy the on-disk cache file into the live connection.

    Nothing is loaded when the file does not exist or when its stored
    ``schema_version`` is missing or different from ``SCHEMA_VERSION``.
    Any exception while reading the file is logged and the cache simply
    starts empty.
    """
    if not self._db_path.exists():
        logger.info(f"No cache file found at {self._db_path}, starting fresh.")
        return

    logger.info(f"Loading cache from {self._db_path}")
    disk_conn = sqlite3.connect(self._db_path)
    try:
        stored = disk_conn.execute(
            "SELECT value FROM _sqlmem_meta WHERE key = 'schema_version'"
        ).fetchone()
        compatible = stored is not None and int(stored[0]) == SCHEMA_VERSION
        if compatible:
            disk_conn.backup(self._conn)
            logger.info("Cache loaded from disk successfully.")
        else:
            logger.warning("Cache schema version mismatch — discarding cache file, starting fresh.")
    except Exception as e:
        logger.error(f"Failed to load cache from disk: {e} — starting fresh.")
    finally:
        # Single exit point: the file connection is closed on every path.
        disk_conn.close()
|
|
|
|
def _drop_orphan_staging(self) -> None:
    """Drop staging tables left by a load that was interrupted (e.g. crash mid-load).

    Staging tables are recognised by the ``__sqlmem_load`` name suffix.
    """
    conn = self._conn
    leftovers = conn.execute(
        r"SELECT name FROM sqlite_master "
        r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
    ).fetchall()
    if not leftovers:
        return  # nothing to clean up, so nothing to commit
    for row in leftovers:
        name = row[0]
        logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
        conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
    conn.commit()
|
|
|
|
def _backup_to_disk(self) -> None:
    """Persist the cache to ``self._db_path``.

    Does nothing once the manager is closed. In disk mode the data already
    lives in the file, so only the pending transaction is committed. In
    memory mode the whole database is copied to the file with SQLite's
    backup API while holding the connection lock. Failures are logged,
    never raised.
    """
    if self._closed:
        return
    if not self._in_memory:
        # Disk-backed: every write already lands on disk; just flush the WAL.
        with self._lock:
            self._conn.commit()
        return
    logger.info(f"Backing up cache to {self._db_path}")
    try:
        with self._lock:
            disk_conn = sqlite3.connect(self._db_path)
            try:
                self._conn.backup(disk_conn)
            finally:
                # Close the file connection even when backup() raises, so a
                # failed backup does not leak a handle on the cache file.
                disk_conn.close()
        logger.info("Cache backup complete.")
    except Exception as e:
        logger.error(f"Cache backup failed: {e}")
|
|
|
|
def _start_backup_thread(self) -> None:
    """Start the daemon thread that backs the cache up every ``backup_interval`` seconds."""

    def _run_periodic_backups() -> None:
        # The event is never set anywhere; wait() is used purely as a timer
        # and returns False on every timeout.
        ticker = threading.Event()
        while True:
            if ticker.wait(self._backup_interval):
                break
            self._backup_to_disk()

    worker = threading.Thread(
        target=_run_periodic_backups,
        daemon=True,
        name="sqlmem-backup",
    )
    worker.start()
    logger.debug(f"Backup thread started (interval={self._backup_interval}s)")
|
|
|
|
def _on_sigterm(self, signum, frame) -> None:
    """SIGTERM handler: flush the cache to disk, then let the process terminate.

    Args:
        signum: Number of the delivered signal.
        frame: Interrupted stack frame (unused).
    """
    logger.info("SIGTERM received — flushing cache to disk.")
    try:
        self._backup_to_disk()
    finally:
        # Installing a Python-level handler replaces SIGTERM's default
        # action, so simply returning would leave the process running.
        # Restore the default disposition and re-deliver the signal so the
        # process still terminates with the conventional signal status.
        signal.signal(signum, signal.SIG_DFL)
        signal.raise_signal(signum)
|
|
|
|
def mark_table_refreshed(self, table: str, row_count: int, full: bool = False) -> None:
    """Insert or update the ``_sqlmem_tables`` row for *table*.

    Stores the current timestamp as ``last_refresh_at`` together with
    *row_count* and the *full* flag (as 0/1). ``last_synced_at`` is not
    touched by the update.
    """
    upsert_sql = """
        INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count, is_full)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(table_name) DO UPDATE SET
            last_refresh_at = excluded.last_refresh_at,
            row_count = excluded.row_count,
            is_full = excluded.is_full
        """
    with self._lock:
        values = (table, _now(), row_count, int(full))
        self._conn.execute(upsert_sql, values)
        self._conn.commit()
|
|
|
|
def is_table_cached(self, table: str) -> bool:
    """True if *table* has a row in ``_sqlmem_tables``."""
    lookup = self._conn.execute(
        "SELECT 1 FROM _sqlmem_tables WHERE table_name = ?", (table,)
    )
    return lookup.fetchone() is not None
|
|
|
|
def is_table_full(self, table: str) -> bool:
    """True if the whole table (all columns) is cached — a SELECT * cache hit."""
    found = self._conn.execute(
        "SELECT is_full FROM _sqlmem_tables WHERE table_name = ?", (table,)
    ).fetchone()
    if not found:
        return False  # table is not cached at all
    return bool(found[0])
|
|
|
|
def seconds_since_refresh(self, table: str) -> float | None:
    """Age of a cached table in seconds, or None if it is not cached.

    The age is measured from the stored ``last_refresh_at`` ISO timestamp
    to the current UTC time.
    """
    record = self._conn.execute(
        "SELECT last_refresh_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
    ).fetchone()
    stamp = record[0] if record else None
    if not stamp:
        return None
    refreshed_at = datetime.fromisoformat(stamp)
    elapsed = datetime.now(timezone.utc) - refreshed_at
    return elapsed.total_seconds()
|
|
|
|
def discover_columns(self, table: str, source_conn: sqlite3.Connection) -> list[str]:
    """Return all column names of *table* from the source DB without fetching rows."""
    logger.debug(f"Discovering columns of {table!r} from source DB")
    # ``WHERE 1 = 0`` matches no row, so only the result description is used.
    probe_sql = f"SELECT * FROM {quote_source(table, self._dialect)} WHERE 1 = 0"
    description = source_conn.execute(probe_sql).description
    columns = [entry[0] for entry in description]
    logger.debug(f"{table!r} has columns: {columns}")
    return columns
|
|
|
|
def set_state(self, table: str, state: str) -> None:
    """Record *state* as the live processing state of *table*."""
    self._states.update({table: state})
|
|
|
|
def get_states(self) -> dict[str, str]:
    """Return a copy of the table → state mapping (safe for the caller to mutate)."""
    return {table: state for table, state in self._states.items()}
|
|
|
|
def clear_state(self, table: str) -> None:
    """Forget both the live state and the recorded error of *table*, if any."""
    for registry in (self._states, self._errors):
        registry.pop(table, None)
|
|
|
|
def record_error(self, table: str, message: str) -> None:
    """Record a load/refresh failure for *table* (increments its failure streak).

    Also increments the process-wide failure counter.
    """
    previous = self._errors.get(table)
    streak = previous.consecutive + 1 if previous else 1
    self._errors[table] = TableError(message=message, at=_now(), consecutive=streak)
    self._error_total += 1
    logger.debug(f"Recorded error for {table!r} (streak {streak}): {message}")
|
|
|
|
def record_success(self, table: str) -> None:
    """Reset *table*'s failure streak to 0 after a successful load/refresh.

    The last error's message and timestamp are kept; only the streak
    counter changes. Tables without an active streak are left untouched.
    """
    last = self._errors.get(table)
    if not last or not last.consecutive:
        return
    self._errors[table] = TableError(last.message, last.at, 0)
|
|
|
|
def get_errors(self) -> dict[str, TableError]:
    """Return a copy of the table → last-failure mapping."""
    return {table: failure for table, failure in self._errors.items()}
|
|
|
|
@property
def error_total(self) -> int:
    """Number of failures recorded by ``record_error`` since construction."""
    count = self._error_total
    return count
|
|
|
|
def add_index(self, table: str, columns: list[str]) -> None:
    """Register a secondary index to (re)create on *columns* after each load.

    The index name is ``sqlmem_idx_<table>_<col>_...``; registering the same
    name twice for a table is a no-op.
    """
    index_name = "sqlmem_idx_" + "_".join((table, *columns))
    registered = self._index_defs.setdefault(table, [])
    if any(existing.name == index_name for existing in registered):
        return
    registered.append(_Index(name=index_name, columns=tuple(columns)))
|
|
|
|
def _create_indexes(self, table: str, available: list[str]) -> None:
    """Create the registered secondary indexes whose columns are all cached.

    An index that needs a column missing from *available* is skipped with a
    warning. Each ``CREATE INDEX`` runs under the connection lock.
    """
    cached = set(available)
    for index in self._index_defs.get(table, []):
        if not cached.issuperset(index.columns):
            logger.warning(
                f"Skipping index {index.name!r}: columns {index.columns} not all cached."
            )
            continue
        cols = quote_list(index.columns)
        with self._lock:
            statement = f"CREATE INDEX IF NOT EXISTS {quote(index.name)} ON {quote(table)} ({cols})"
            self._conn.execute(statement)
            self._conn.commit()
        logger.debug(f"Index {index.name!r} ready on {table} ({cols})")
|
|
|
|
def load_table(
    self,
    table: str,
    columns: list[str],
    source_conn: sqlite3.Connection,
    full: bool = False,
) -> None:
    """Stream the source table into the cache in batches.

    Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging
    table and swapped in atomically, so peak memory stays bounded (no
    ``fetchall`` of a huge table) and readers keep seeing the previous copy
    until the swap. Concurrent loads are serialized by ``_load_lock``; the
    connection lock is only held for the brief per-batch inserts and the swap.

    Args:
        table: Name of the table, in both the source DB and the cache.
        columns: Source columns to copy; each becomes a ``TEXT`` column.
        source_conn: Connection to the source database.
        full: Stored as the table's ``is_full`` flag once the load succeeds.

    Raises:
        BaseException: Whatever the fetch, insert or swap raised. The
            staging table is dropped (best effort) and the failure is
            recorded via ``set_state``/``record_error`` before re-raising.
    """
    src_cols = ", ".join(quote_source(c, self._dialect) for c in columns)
    col_defs = ", ".join(f"{quote(c)} TEXT" for c in columns)
    placeholders = ", ".join("?" * len(columns))
    staging = f"{table}__sqlmem_load"
    q_staging = quote(staging)
    q_table = quote(table)

    with self._load_lock:
        self.set_state(table, TableState.LOADING)
        logger.info(f"Fetching {table!r} columns {columns} from source DB (batch={self._fetch_batch})")
        try:
            cursor = source_conn.execute(
                f"SELECT {src_cols} FROM {quote_source(table, self._dialect)}"
            )
            with self._lock:
                self._conn.execute(f"DROP TABLE IF EXISTS {q_staging}")
                self._conn.execute(f"CREATE TABLE {q_staging} ({col_defs})")
                self._conn.commit()

            total = 0
            insert_sql = f"INSERT INTO {q_staging} VALUES ({placeholders})"
            # Network fetch happens outside _lock; only the insert holds it.
            while batch := cursor.fetchmany(self._fetch_batch):
                clean = [coerce_row(row) for row in batch]
                with self._lock:
                    self._conn.executemany(insert_sql, clean)
                    self._conn.commit()
                total += len(batch)

            with self._lock:  # atomic swap — readers see old or new, never partial
                # sqlite3 opens no implicit transaction for DDL, so without
                # an explicit BEGIN the DROP and the RENAME would each be
                # committed on their own, leaving a window (visible to the
                # per-thread disk-mode readers, or after a crash) in which
                # the table does not exist.
                if not self._conn.in_transaction:
                    self._conn.execute("BEGIN")
                try:
                    self._conn.execute(f"DROP TABLE IF EXISTS {q_table}")
                    self._conn.execute(f"ALTER TABLE {q_staging} RENAME TO {q_table}")
                    self._conn.commit()
                except BaseException:
                    if self._conn.in_transaction:
                        self._conn.rollback()  # keep the previous copy intact
                    raise
        except BaseException as exc:
            try:
                with self._lock:
                    self._conn.execute(f"DROP TABLE IF EXISTS {q_staging}")
                    self._conn.commit()
            except sqlite3.Error as cleanup_exc:
                # Best effort: a failed cleanup must not hide the original
                # error or skip the bookkeeping below.
                logger.warning(f"Could not drop staging table {staging!r}: {cleanup_exc}")
            self.set_state(table, TableState.ERROR)
            self.record_error(table, f"{type(exc).__name__}: {exc}")
            raise

        self._create_indexes(table, columns)
        self.mark_table_refreshed(table, total, full)
        self.set_state(table, TableState.READY)
        self.record_success(table)
        logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
|
|
|
|
def _read_conn(self) -> sqlite3.Connection:
    """A per-thread, read-only connection used for cache reads in disk mode.

    Disk mode runs in WAL, which allows many concurrent readers alongside one
    writer. Giving each thread its own read connection (rather than sharing the
    single write connection under ``_lock``) means a slow ``SELECT`` no longer
    blocks writers (loads/upserts) or other readers. In-memory mode can't do
    this — each ``:memory:`` connection is a separate database — so it keeps
    using the single locked connection.

    The connection is created on the calling thread's first use and is also
    remembered in ``_read_conns`` so ``close`` can shut it down.
    """
    existing = getattr(self._read_local, "conn", None)
    if existing is not None:
        return existing

    fresh = sqlite3.connect(str(self._db_path), check_same_thread=False)
    fresh.execute("PRAGMA query_only=ON")  # read-only guard
    self._read_local.conn = fresh
    with self._lock:
        self._read_conns.append(fresh)
    return fresh
|
|
|
|
def execute_in_memory(
    self, sql: str, params: tuple | list | dict | None = None
) -> tuple[list[str], list[tuple]]:
    """Run a read query against the cache.

    In-memory mode serializes with writers on the single connection. Disk mode
    reads from a per-thread WAL connection, so reads run concurrently with
    writers and each other (see :meth:`_read_conn`).

    Args:
        sql: Statement to execute.
        params: Optional bind parameters, passed through ``coerce_params``.

    Returns:
        ``(column_names, rows)``. Both are empty when the statement produces
        no result set.
    """
    bound = coerce_params(params)
    args = (sql,) if bound is None else (sql, bound)

    def run(conn: sqlite3.Connection) -> tuple[list[str], list[tuple]]:
        cursor = conn.execute(*args)
        # description is None for statements without a result set; treat
        # that as "no columns" instead of failing on iteration.
        col_names = [desc[0] for desc in cursor.description or ()]
        return col_names, cursor.fetchall()

    if self._in_memory:
        with self._lock:
            return run(self._conn)
    return run(self._read_conn())
|
|
|
|
# --- delta refresh support ---------------------------------------------
|
|
|
|
def get_table_columns(self, table: str) -> list[str]:
    """Authoritative ordered column list of a cached table (via PRAGMA)."""
    info = self._conn.execute(f"PRAGMA table_info({quote(table)})")
    # Field 1 of each ``table_info`` row is the column name.
    return [column[1] for column in info.fetchall()]
|
|
|
|
def create_unique_index(self, table: str, key_columns: list[str]) -> None:
    """Create the unique index on *key_columns* that makes upsert-by-key work.

    The index is named ``idx_<table>_pk`` and is only created if missing.
    """
    key_list = quote_list(key_columns)
    index_name = quote(f"idx_{table}_pk")
    with self._lock:
        statement = (
            f"CREATE UNIQUE INDEX IF NOT EXISTS {index_name} ON {quote(table)} ({key_list})"
        )
        self._conn.execute(statement)
        self._conn.commit()
|
|
|
|
def get_last_synced_at(self, table: str) -> str | None:
    """Stored ``last_synced_at`` value of *table*, or None if unset or not cached."""
    found = self._conn.execute(
        "SELECT last_synced_at FROM _sqlmem_tables WHERE table_name = ?", (table,)
    ).fetchone()
    if not found:
        return None
    return found[0]
|
|
|
|
def set_last_synced_at(self, table: str, value: str | None) -> None:
    """Store *value* as ``last_synced_at`` of *table*.

    This is a plain ``UPDATE``: nothing happens when *table* has no row in
    ``_sqlmem_tables``.
    """
    update_sql = "UPDATE _sqlmem_tables SET last_synced_at = ? WHERE table_name = ?"
    with self._lock:
        conn = self._conn
        conn.execute(update_sql, (value, table))
        conn.commit()
|
|
|
|
def max_value(self, table: str, column: str) -> str | None:
    """Maximum value of *column* across cached rows (the delta watermark)."""
    query = f"SELECT MAX({quote(column)}) FROM {quote(table)}"
    result = self._conn.execute(query).fetchone()
    if not result:
        return None
    return result[0]
|
|
|
|
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
    """Insert-or-replace one batch of *rows* by the table's unique key.

    Every row is passed through ``coerce_row`` before it is written; the
    whole batch is committed under the connection lock.
    """
    col_list = quote_list(columns)
    placeholders = ", ".join(["?"] * len(columns))
    prepared = [coerce_row(row) for row in rows]
    with self._lock:
        statement = f"INSERT OR REPLACE INTO {quote(table)} ({col_list}) VALUES ({placeholders})"
        self._conn.executemany(statement, prepared)
        self._conn.commit()
|
|
|
|
def count_rows(self, table: str) -> int:
    """Number of rows currently stored in the cached *table*."""
    result = self._conn.execute(f"SELECT COUNT(*) FROM {quote(table)}").fetchone()
    if not result:
        return 0
    return int(result[0])
|
|
|
|
def reset(self) -> None:
    """Wipe the entire cache — every cached table plus the on-disk data
    (the file is deleted in memory mode, VACUUMed in place in disk mode).

    The ``_sqlmem_meta`` table is kept; ``_sqlmem_tables`` and
    ``_sqlmem_columns`` are emptied and the live state map is cleared.
    Failures while removing or vacuuming the file are logged, not raised.
    """
    logger.info("Resetting cache — dropping all cached tables.")
    with self._lock:
        user_tables = [
            r[0]
            for r in self._conn.execute(
                "SELECT name FROM sqlite_master "
                r"WHERE type = 'table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\' "
                r"AND name NOT LIKE '\_sqlmem\_%' ESCAPE '\'"
            ).fetchall()
        ]
        for name in user_tables:
            self._conn.execute(f"DROP TABLE IF EXISTS {quote(name)}")
        self._conn.execute("DELETE FROM _sqlmem_tables")
        self._conn.execute("DELETE FROM _sqlmem_columns")
        self._conn.commit()
    self._states.clear()
    if self._in_memory:
        try:
            # Backups write this file while holding _lock, so take it here
            # to avoid deleting the file under a running backup.
            # missing_ok replaces the racy exists()-then-unlink() check.
            with self._lock:
                self._db_path.unlink(missing_ok=True)
        except OSError as e:
            logger.error(f"Failed to delete cache file {self._db_path}: {e}")
    else:
        # The open connection *is* the file — drop tables persisted the wipe;
        # VACUUM reclaims the freed pages on disk.
        try:
            with self._lock:
                self._conn.execute("VACUUM")
        except sqlite3.Error as e:
            logger.error(f"Failed to VACUUM cache file {self._db_path}: {e}")
|
|
|
|
def close(self) -> None:
    """Flush the cache to disk, then close every connection.

    After the final backup the manager is marked closed, so later backup
    calls become no-ops. Errors raised while closing a per-thread read
    connection are ignored.
    """
    self._backup_to_disk()
    self._closed = True
    with self._lock:
        readers = self._read_conns
        for reader in readers:
            try:
                reader.close()
            except sqlite3.Error:
                continue  # best effort: keep closing the remaining readers
        readers.clear()
    self._conn.close()
|
|
|
|
|
|
def _now() -> str:
    """Current UTC time as a timezone-aware ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|