Batch large-table loads to bound memory and add per-table state to stats
This commit is contained in:
+82
-16
@@ -9,6 +9,8 @@ from loguru import logger
|
||||
|
||||
import sqlmem._meta as _meta
|
||||
from ._coerce import coerce_params, coerce_row
|
||||
from .config import FETCH_BATCH_SIZE
|
||||
from .stats import TableState
|
||||
|
||||
SCHEMA_VERSION = 3
|
||||
|
||||
@@ -18,11 +20,14 @@ class CacheManager:
|
||||
self._db_path = db_path
|
||||
self._backup_interval = backup_interval
|
||||
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
self._lock = threading.Lock()
|
||||
self._lock = threading.Lock() # serializes connection access
|
||||
self._load_lock = threading.Lock() # serializes full table loads
|
||||
self._states: dict[str, str] = {} # table → live processing state
|
||||
self._closed = False
|
||||
|
||||
self._ensure_meta_tables()
|
||||
self._load_from_disk()
|
||||
self._drop_orphan_staging()
|
||||
self._start_backup_thread()
|
||||
|
||||
atexit.register(self._backup_to_disk)
|
||||
@@ -88,6 +93,21 @@ class CacheManager:
|
||||
finally:
|
||||
disk_conn.close()
|
||||
|
||||
def _drop_orphan_staging(self) -> None:
|
||||
"""Drop staging tables left by a load that was interrupted (e.g. crash mid-load)."""
|
||||
orphans = [
|
||||
r[0]
|
||||
for r in self._mem_conn.execute(
|
||||
r"SELECT name FROM sqlite_master "
|
||||
r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
|
||||
).fetchall()
|
||||
]
|
||||
for name in orphans:
|
||||
logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {name}")
|
||||
if orphans:
|
||||
self._mem_conn.commit()
|
||||
|
||||
def _backup_to_disk(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
@@ -161,6 +181,15 @@ class CacheManager:
|
||||
logger.debug(f"{table!r} has columns: {columns}")
|
||||
return columns
|
||||
|
||||
def set_state(self, table: str, state: str) -> None:
|
||||
self._states[table] = state
|
||||
|
||||
def get_states(self) -> dict[str, str]:
|
||||
return dict(self._states)
|
||||
|
||||
def clear_state(self, table: str) -> None:
|
||||
self._states.pop(table, None)
|
||||
|
||||
def load_table(
|
||||
self,
|
||||
table: str,
|
||||
@@ -168,21 +197,55 @@ class CacheManager:
|
||||
source_conn: sqlite3.Connection,
|
||||
full: bool = False,
|
||||
) -> None:
|
||||
"""Stream the source table into the cache in batches.
|
||||
|
||||
Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging
|
||||
table and swapped in atomically, so peak memory stays bounded (no
|
||||
``fetchall`` of a huge table) and readers keep seeing the previous copy
|
||||
until the swap. Concurrent loads are serialized by ``_load_lock``; the
|
||||
connection lock is only held for the brief per-batch inserts and the swap.
|
||||
"""
|
||||
cols = ", ".join(columns)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
|
||||
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
|
||||
clean_rows = [coerce_row(row) for row in rows]
|
||||
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
staging = f"{table}__sqlmem_load"
|
||||
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
||||
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows)
|
||||
self._mem_conn.commit()
|
||||
with self._load_lock:
|
||||
self.set_state(table, TableState.LOADING)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB (batch={FETCH_BATCH_SIZE})")
|
||||
try:
|
||||
cursor = source_conn.execute(f"SELECT {cols} FROM {table}")
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
|
||||
self._mem_conn.execute(f"CREATE TABLE {staging} ({col_defs})")
|
||||
self._mem_conn.commit()
|
||||
|
||||
self.mark_table_refreshed(table, len(rows), full)
|
||||
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
|
||||
total = 0
|
||||
insert_sql = f"INSERT INTO {staging} VALUES ({placeholders})"
|
||||
while True:
|
||||
batch = cursor.fetchmany(FETCH_BATCH_SIZE) # network outside _lock
|
||||
if not batch:
|
||||
break
|
||||
clean = [coerce_row(row) for row in batch]
|
||||
with self._lock:
|
||||
self._mem_conn.executemany(insert_sql, clean)
|
||||
self._mem_conn.commit()
|
||||
total += len(batch)
|
||||
|
||||
with self._lock: # atomic swap — readers see old or new, never partial
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
self._mem_conn.execute(f"ALTER TABLE {staging} RENAME TO {table}")
|
||||
self._mem_conn.commit()
|
||||
except BaseException:
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
|
||||
self._mem_conn.commit()
|
||||
self.set_state(table, TableState.ERROR)
|
||||
raise
|
||||
|
||||
self.mark_table_refreshed(table, total, full)
|
||||
self.set_state(table, TableState.READY)
|
||||
logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
|
||||
|
||||
def execute_in_memory(
|
||||
self, sql: str, params: tuple | list | dict | None = None
|
||||
@@ -232,7 +295,7 @@ class CacheManager:
|
||||
return row[0] if row else None
|
||||
|
||||
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
|
||||
"""Insert-or-replace *rows* by the table's unique key, then refresh row_count."""
|
||||
"""Insert-or-replace one batch of *rows* by the table's unique key."""
|
||||
col_list = ", ".join(columns)
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
clean_rows = [coerce_row(row) for row in rows]
|
||||
@@ -242,8 +305,10 @@ class CacheManager:
|
||||
clean_rows,
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
|
||||
self.mark_table_refreshed(table, count, self.is_table_full(table))
|
||||
|
||||
def count_rows(self, table: str) -> int:
|
||||
row = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
|
||||
return int(row[0]) if row else 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Wipe the entire cache — every cached table plus the on-disk file."""
|
||||
@@ -262,6 +327,7 @@ class CacheManager:
|
||||
self._mem_conn.execute("DELETE FROM _sqlmem_tables")
|
||||
self._mem_conn.execute("DELETE FROM _sqlmem_columns")
|
||||
self._mem_conn.commit()
|
||||
self._states.clear()
|
||||
try:
|
||||
if self._db_path.exists():
|
||||
self._db_path.unlink()
|
||||
|
||||
@@ -11,6 +11,8 @@ CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
|
||||
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
|
||||
# How often (seconds) the background thread pulls deltas for delta-tracked tables.
|
||||
REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300"))
|
||||
# Rows fetched per batch when loading a table — caps peak memory for huge tables.
|
||||
FETCH_BATCH_SIZE = int(os.getenv("SQLMEM_FETCH_BATCH", "10000"))
|
||||
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
|
||||
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
|
||||
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
|
||||
|
||||
+24
-6
@@ -4,6 +4,8 @@ from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
from .cache import CacheManager
|
||||
from .config import FETCH_BATCH_SIZE
|
||||
from .stats import TableState
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -58,21 +60,37 @@ class DeltaRefresher:
|
||||
col_list = ", ".join(columns)
|
||||
|
||||
if watermark is None:
|
||||
rows = source_conn.execute(f"SELECT {col_list} FROM {table}").fetchall()
|
||||
cursor = source_conn.execute(f"SELECT {col_list} FROM {table}")
|
||||
else:
|
||||
rows = source_conn.execute(
|
||||
cursor = source_conn.execute(
|
||||
f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?",
|
||||
(watermark,),
|
||||
).fetchall()
|
||||
)
|
||||
|
||||
if not rows:
|
||||
# Stream the delta in batches so a large catch-up never materializes at once.
|
||||
total = 0
|
||||
self._cache.set_state(table, TableState.REFRESHING)
|
||||
try:
|
||||
while True:
|
||||
batch = cursor.fetchmany(FETCH_BATCH_SIZE)
|
||||
if not batch:
|
||||
break
|
||||
self._cache.upsert_rows(table, columns, batch)
|
||||
total += len(batch)
|
||||
finally:
|
||||
self._cache.set_state(table, TableState.READY)
|
||||
|
||||
if total == 0:
|
||||
logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}")
|
||||
return
|
||||
|
||||
self._cache.upsert_rows(table, columns, rows)
|
||||
# Update row_count / last_refresh once (not per batch) and advance the watermark.
|
||||
self._cache.mark_table_refreshed(
|
||||
table, self._cache.count_rows(table), self._cache.is_table_full(table)
|
||||
)
|
||||
new_watermark = self._cache.max_value(table, cfg.change_column)
|
||||
self._cache.set_last_synced_at(table, new_watermark)
|
||||
logger.info(
|
||||
f"Delta refresh {table!r}: {len(rows)} row(s) upserted, "
|
||||
f"Delta refresh {table!r}: {total} row(s) upserted, "
|
||||
f"watermark {watermark!r} → {new_watermark!r}"
|
||||
)
|
||||
|
||||
+22
-2
@@ -1,5 +1,6 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import replace
|
||||
from typing import cast
|
||||
|
||||
from loguru import logger
|
||||
@@ -12,7 +13,7 @@ from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
|
||||
from .executor import QueryExecutor
|
||||
from .parser import Params, parse
|
||||
from .registry import ColumnRegistry
|
||||
from .stats import Stats, StatsCollector
|
||||
from .stats import Stats, StatsCollector, TableState, TableStats
|
||||
|
||||
|
||||
class CachingEngine:
|
||||
@@ -68,8 +69,26 @@ class CachingEngine:
|
||||
|
||||
@property
|
||||
def stats(self) -> Stats:
|
||||
states = self._cache.get_states()
|
||||
with self._cache._lock:
|
||||
return self._stats.snapshot(self._cache.connection)
|
||||
base = self._stats.snapshot(self._cache.connection, states)
|
||||
return replace(base, tables={n: self._enrich(n, t) for n, t in base.tables.items()})
|
||||
|
||||
def _enrich(self, name: str, table_stats: TableStats) -> TableStats:
|
||||
"""Annotate a TableStats with how it is refreshed and TTL staleness."""
|
||||
if name in self._delta:
|
||||
tracking = "delta"
|
||||
elif name in self._ttl:
|
||||
tracking = "ttl"
|
||||
else:
|
||||
tracking = "static"
|
||||
|
||||
state = table_stats.state
|
||||
if state == TableState.READY and name in self._ttl:
|
||||
age = self._cache.seconds_since_refresh(name)
|
||||
if age is not None and age > self._ttl[name]:
|
||||
state = TableState.STALE
|
||||
return replace(table_stats, tracking=tracking, state=state)
|
||||
|
||||
def execute(self, sql: str, params: Params = None) -> list[dict]:
|
||||
parsed = parse(sql, params)
|
||||
@@ -130,6 +149,7 @@ class CachingEngine:
|
||||
"DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
|
||||
)
|
||||
self._cache.connection.commit()
|
||||
self._cache.clear_state(table)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
|
||||
|
||||
+24
-1
@@ -3,11 +3,23 @@ import threading
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TableState:
|
||||
"""Live processing state of a cached table (value of ``TableStats.state``)."""
|
||||
|
||||
LOADING = "loading" # a full load is in progress
|
||||
REFRESHING = "refreshing" # an incremental (delta) refresh is in progress
|
||||
READY = "ready" # cached and idle
|
||||
STALE = "stale" # TTL expired — will reload on next access
|
||||
ERROR = "error" # the last load failed
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TableStats:
|
||||
rows: int
|
||||
columns: list[str]
|
||||
last_refresh: str
|
||||
state: str = TableState.READY
|
||||
tracking: str = "static" # "delta" | "ttl" | "static"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -37,14 +49,19 @@ class StatsCollector:
|
||||
with self._lock:
|
||||
self.refetches += 1
|
||||
|
||||
def snapshot(self, conn: sqlite3.Connection) -> Stats:
|
||||
def snapshot(
|
||||
self, conn: sqlite3.Connection, states: dict[str, str] | None = None
|
||||
) -> Stats:
|
||||
states = states or {}
|
||||
with self._lock:
|
||||
hits, misses, refetches = self.hits, self.misses, self.refetches
|
||||
|
||||
tables: dict[str, TableStats] = {}
|
||||
cached: set[str] = set()
|
||||
for table_name, row_count, last_refresh in conn.execute(
|
||||
"SELECT table_name, row_count, last_refresh_at FROM _sqlmem_tables"
|
||||
).fetchall():
|
||||
cached.add(table_name)
|
||||
columns = [
|
||||
r[0]
|
||||
for r in conn.execute(
|
||||
@@ -56,6 +73,12 @@ class StatsCollector:
|
||||
rows=row_count or 0,
|
||||
columns=columns,
|
||||
last_refresh=last_refresh,
|
||||
state=states.get(table_name, TableState.READY),
|
||||
)
|
||||
|
||||
# Surface tables that are mid-first-load (not yet in _sqlmem_tables) or failed.
|
||||
for name, state in states.items():
|
||||
if name not in cached and state in (TableState.LOADING, TableState.ERROR):
|
||||
tables[name] = TableStats(rows=0, columns=[], last_refresh="", state=state)
|
||||
|
||||
return Stats(hits=hits, misses=misses, refetches=refetches, tables=tables)
|
||||
|
||||
Reference in New Issue
Block a user