Batch large-table loads to bound memory and add per-table state to stats

This commit is contained in:
Jan Doubravský
2026-06-05 14:44:07 +02:00
parent 85bb84a1a6
commit 286a5f207d
11 changed files with 436 additions and 29 deletions
+82 -16
View File
@@ -9,6 +9,8 @@ from loguru import logger
import sqlmem._meta as _meta
from ._coerce import coerce_params, coerce_row
from .config import FETCH_BATCH_SIZE
from .stats import TableState
SCHEMA_VERSION = 3
@@ -18,11 +20,14 @@ class CacheManager:
self._db_path = db_path
self._backup_interval = backup_interval
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
self._lock = threading.Lock()
self._lock = threading.Lock() # serializes connection access
self._load_lock = threading.Lock() # serializes full table loads
self._states: dict[str, str] = {} # table → live processing state
self._closed = False
self._ensure_meta_tables()
self._load_from_disk()
self._drop_orphan_staging()
self._start_backup_thread()
atexit.register(self._backup_to_disk)
@@ -88,6 +93,21 @@ class CacheManager:
finally:
disk_conn.close()
def _drop_orphan_staging(self) -> None:
    """Drop staging tables left by a load that was interrupted (e.g. crash mid-load).

    Every table in the in-memory database whose name ends with the literal
    suffix ``__sqlmem_load`` is treated as an orphan: a warning is logged and
    the table is dropped. One commit is issued afterwards, and only when at
    least one table was dropped.
    """
    # "_" is a single-character wildcard in LIKE, so each underscore of the
    # suffix is escaped to make it match literally.
    orphans = [
        r[0]
        for r in self._mem_conn.execute(
            r"SELECT name FROM sqlite_master "
            r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
        ).fetchall()
    ]
    for name in orphans:
        logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
        # The name comes back verbatim from sqlite_master, so quote it as an
        # identifier (doubling embedded quotes) instead of splicing it in raw.
        quoted = '"' + name.replace('"', '""') + '"'
        self._mem_conn.execute(f"DROP TABLE IF EXISTS {quoted}")
    if orphans:
        self._mem_conn.commit()
def _backup_to_disk(self) -> None:
if self._closed:
return
@@ -161,6 +181,15 @@ class CacheManager:
logger.debug(f"{table!r} has columns: {columns}")
return columns
def set_state(self, table: str, state: str) -> None:
    """Record *state* as the live processing state of *table*, replacing any previous entry."""
    self._states.update({table: state})
def get_states(self) -> dict[str, str]:
    """Return a shallow copy of the table → state mapping.

    Mutating the returned dict does not affect the manager's own mapping.
    """
    return self._states.copy()
def clear_state(self, table: str) -> None:
    """Forget the recorded state of *table*; a no-op when none is recorded."""
    try:
        del self._states[table]
    except KeyError:
        pass
def load_table(
self,
table: str,
@@ -168,21 +197,55 @@ class CacheManager:
source_conn: sqlite3.Connection,
full: bool = False,
) -> None:
"""Stream the source table into the cache in batches.
Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging
table and swapped in atomically, so peak memory stays bounded (no
``fetchall`` of a huge table) and readers keep seeing the previous copy
until the swap. Concurrent loads are serialized by ``_load_lock``; the
connection lock is only held for the brief per-batch inserts and the swap.
"""
cols = ", ".join(columns)
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
clean_rows = [coerce_row(row) for row in rows]
col_defs = ", ".join(f"{c} TEXT" for c in columns)
placeholders = ", ".join("?" * len(columns))
staging = f"{table}__sqlmem_load"
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
col_defs = ", ".join(f"{c} TEXT" for c in columns)
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
placeholders = ", ".join("?" * len(columns))
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows)
self._mem_conn.commit()
with self._load_lock:
self.set_state(table, TableState.LOADING)
logger.info(f"Fetching {table!r} columns [{cols}] from source DB (batch={FETCH_BATCH_SIZE})")
try:
cursor = source_conn.execute(f"SELECT {cols} FROM {table}")
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
self._mem_conn.execute(f"CREATE TABLE {staging} ({col_defs})")
self._mem_conn.commit()
self.mark_table_refreshed(table, len(rows), full)
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
total = 0
insert_sql = f"INSERT INTO {staging} VALUES ({placeholders})"
while True:
batch = cursor.fetchmany(FETCH_BATCH_SIZE) # network outside _lock
if not batch:
break
clean = [coerce_row(row) for row in batch]
with self._lock:
self._mem_conn.executemany(insert_sql, clean)
self._mem_conn.commit()
total += len(batch)
with self._lock: # atomic swap — readers see old or new, never partial
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
self._mem_conn.execute(f"ALTER TABLE {staging} RENAME TO {table}")
self._mem_conn.commit()
except BaseException:
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
self._mem_conn.commit()
self.set_state(table, TableState.ERROR)
raise
self.mark_table_refreshed(table, total, full)
self.set_state(table, TableState.READY)
logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
def execute_in_memory(
self, sql: str, params: tuple | list | dict | None = None
@@ -232,7 +295,7 @@ class CacheManager:
return row[0] if row else None
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
"""Insert-or-replace *rows* by the table's unique key, then refresh row_count."""
"""Insert-or-replace one batch of *rows* by the table's unique key."""
col_list = ", ".join(columns)
placeholders = ", ".join("?" * len(columns))
clean_rows = [coerce_row(row) for row in rows]
@@ -242,8 +305,10 @@ class CacheManager:
clean_rows,
)
self._mem_conn.commit()
count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
self.mark_table_refreshed(table, count, self.is_table_full(table))
def count_rows(self, table: str) -> int:
    """Return the number of rows currently in the cached *table*.

    Runs ``SELECT COUNT(*)`` against the in-memory connection; yields 0 if the
    query produces no row.
    """
    cursor = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}")
    result = cursor.fetchone()
    if not result:
        return 0
    return int(result[0])
def reset(self) -> None:
"""Wipe the entire cache — every cached table plus the on-disk file."""
@@ -262,6 +327,7 @@ class CacheManager:
self._mem_conn.execute("DELETE FROM _sqlmem_tables")
self._mem_conn.execute("DELETE FROM _sqlmem_columns")
self._mem_conn.commit()
self._states.clear()
try:
if self._db_path.exists():
self._db_path.unlink()
+2
View File
@@ -11,6 +11,8 @@ CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
# How often (seconds) the background thread pulls deltas for delta-tracked tables.
REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300"))
# Rows fetched per batch when loading a table — caps peak memory for huge tables.
# Clamped to at least 1: DB-API does not define what ``cursor.fetchmany`` does
# with a zero or negative size, so a misconfigured value could otherwise yield
# no rows at all or every remaining row at once, depending on the driver.
FETCH_BATCH_SIZE = max(1, int(os.getenv("SQLMEM_FETCH_BATCH", "10000")))
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
+24 -6
View File
@@ -4,6 +4,8 @@ from dataclasses import dataclass, field
from loguru import logger
from .cache import CacheManager
from .config import FETCH_BATCH_SIZE
from .stats import TableState
@dataclass(frozen=True)
@@ -58,21 +60,37 @@ class DeltaRefresher:
col_list = ", ".join(columns)
if watermark is None:
rows = source_conn.execute(f"SELECT {col_list} FROM {table}").fetchall()
cursor = source_conn.execute(f"SELECT {col_list} FROM {table}")
else:
rows = source_conn.execute(
cursor = source_conn.execute(
f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?",
(watermark,),
).fetchall()
)
if not rows:
# Stream the delta in batches so a large catch-up never materializes at once.
total = 0
self._cache.set_state(table, TableState.REFRESHING)
try:
while True:
batch = cursor.fetchmany(FETCH_BATCH_SIZE)
if not batch:
break
self._cache.upsert_rows(table, columns, batch)
total += len(batch)
finally:
self._cache.set_state(table, TableState.READY)
if total == 0:
logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}")
return
self._cache.upsert_rows(table, columns, rows)
# Update row_count / last_refresh once (not per batch) and advance the watermark.
self._cache.mark_table_refreshed(
table, self._cache.count_rows(table), self._cache.is_table_full(table)
)
new_watermark = self._cache.max_value(table, cfg.change_column)
self._cache.set_last_synced_at(table, new_watermark)
logger.info(
f"Delta refresh {table!r}: {len(rows)} row(s) upserted, "
f"Delta refresh {table!r}: {total} row(s) upserted, "
f"watermark {watermark!r}{new_watermark!r}"
)
+22 -2
View File
@@ -1,5 +1,6 @@
import sqlite3
import threading
from dataclasses import replace
from typing import cast
from loguru import logger
@@ -12,7 +13,7 @@ from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
from .executor import QueryExecutor
from .parser import Params, parse
from .registry import ColumnRegistry
from .stats import Stats, StatsCollector
from .stats import Stats, StatsCollector, TableState, TableStats
class CachingEngine:
@@ -68,8 +69,26 @@ class CachingEngine:
@property
def stats(self) -> Stats:
states = self._cache.get_states()
with self._cache._lock:
return self._stats.snapshot(self._cache.connection)
base = self._stats.snapshot(self._cache.connection, states)
return replace(base, tables={n: self._enrich(n, t) for n, t in base.tables.items()})
def _enrich(self, name: str, table_stats: TableStats) -> TableStats:
    """Return a copy of *table_stats* annotated with tracking mode and TTL staleness.

    ``tracking`` is ``"delta"`` when *name* is in ``self._delta``, otherwise
    ``"ttl"`` when it is in ``self._ttl``, otherwise ``"static"``. A ``READY``
    state is reported as ``STALE`` when the table has a TTL entry and the value
    returned by ``seconds_since_refresh`` exceeds it; any other state is passed
    through unchanged.
    """
    has_ttl = name in self._ttl
    # Delta tracking takes precedence over TTL when a table is in both maps.
    tracking = "delta" if name in self._delta else ("ttl" if has_ttl else "static")

    reported = table_stats.state
    if has_ttl and reported == TableState.READY:
        elapsed = self._cache.seconds_since_refresh(name)
        expired = elapsed is not None and elapsed > self._ttl[name]
        if expired:
            reported = TableState.STALE
    return replace(table_stats, tracking=tracking, state=reported)
def execute(self, sql: str, params: Params = None) -> list[dict]:
parsed = parse(sql, params)
@@ -130,6 +149,7 @@ class CachingEngine:
"DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
)
self._cache.connection.commit()
self._cache.clear_state(table)
def reset(self) -> None:
"""Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
+24 -1
View File
@@ -3,11 +3,23 @@ import threading
from dataclasses import dataclass
class TableState:
    """String constants naming the live processing state of a cached table.

    These are the values carried in ``TableStats.state``.
    """

    # A full load is in progress.
    LOADING = "loading"
    # An incremental (delta) refresh is in progress.
    REFRESHING = "refreshing"
    # Cached and idle.
    READY = "ready"
    # TTL expired — will reload on next access.
    STALE = "stale"
    # The last load failed.
    ERROR = "error"
@dataclass(frozen=True)
class TableStats:
    """Immutable statistics entry describing one cached table."""

    # Row count reported for the table (0 when none is recorded).
    rows: int
    # Column entries reported for the table — presumably column names; confirm
    # against the query in StatsCollector.snapshot.
    columns: list[str]
    # Last-refresh value as stored in the cache metadata; empty string for a
    # table that is not yet recorded there.
    last_refresh: str
    # One of the TableState constants.
    state: str = TableState.READY
    tracking: str = "static"  # "delta" | "ttl" | "static"
@dataclass(frozen=True)
@@ -37,14 +49,19 @@ class StatsCollector:
with self._lock:
self.refetches += 1
def snapshot(self, conn: sqlite3.Connection) -> Stats:
def snapshot(
self, conn: sqlite3.Connection, states: dict[str, str] | None = None
) -> Stats:
states = states or {}
with self._lock:
hits, misses, refetches = self.hits, self.misses, self.refetches
tables: dict[str, TableStats] = {}
cached: set[str] = set()
for table_name, row_count, last_refresh in conn.execute(
"SELECT table_name, row_count, last_refresh_at FROM _sqlmem_tables"
).fetchall():
cached.add(table_name)
columns = [
r[0]
for r in conn.execute(
@@ -56,6 +73,12 @@ class StatsCollector:
rows=row_count or 0,
columns=columns,
last_refresh=last_refresh,
state=states.get(table_name, TableState.READY),
)
# Surface tables that are mid-first-load (not yet in _sqlmem_tables) or failed.
for name, state in states.items():
if name not in cached and state in (TableState.LOADING, TableState.ERROR):
tables[name] = TableStats(rows=0, columns=[], last_refresh="", state=state)
return Stats(hits=hits, misses=misses, refetches=refetches, tables=tables)