Batch large-table loads to bound memory and add per-table state to stats

This commit is contained in:
Jan Doubravský
2026-06-05 14:44:07 +02:00
parent 85bb84a1a6
commit 286a5f207d
11 changed files with 436 additions and 29 deletions
+82 -16
View File
@@ -9,6 +9,8 @@ from loguru import logger
import sqlmem._meta as _meta
from ._coerce import coerce_params, coerce_row
from .config import FETCH_BATCH_SIZE
from .stats import TableState
SCHEMA_VERSION = 3
@@ -18,11 +20,14 @@ class CacheManager:
self._db_path = db_path
self._backup_interval = backup_interval
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
self._lock = threading.Lock()
self._lock = threading.Lock() # serializes connection access
self._load_lock = threading.Lock() # serializes full table loads
self._states: dict[str, str] = {} # table → live processing state
self._closed = False
self._ensure_meta_tables()
self._load_from_disk()
self._drop_orphan_staging()
self._start_backup_thread()
atexit.register(self._backup_to_disk)
@@ -88,6 +93,21 @@ class CacheManager:
finally:
disk_conn.close()
def _drop_orphan_staging(self) -> None:
    """Drop staging tables left by a load that was interrupted (e.g. crash mid-load).

    Every table in the in-memory database whose name ends with the literal
    suffix ``__sqlmem_load`` is logged with a warning and dropped. A commit
    is issued only when at least one table was dropped.
    """
    # "_" is a single-character wildcard in LIKE, so each underscore of the
    # suffix is escaped to make the match literal.
    orphans = [
        r[0]
        for r in self._mem_conn.execute(
            r"SELECT name FROM sqlite_master "
            r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
        ).fetchall()
    ]
    for name in orphans:
        logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
        # The name comes straight from the catalog and may not be a bare
        # identifier; quote it (doubling embedded quotes) so the DROP cannot
        # fail with a syntax error.
        quoted = '"' + name.replace('"', '""') + '"'
        self._mem_conn.execute(f"DROP TABLE IF EXISTS {quoted}")
    if orphans:
        self._mem_conn.commit()
def _backup_to_disk(self) -> None:
if self._closed:
return
@@ -161,6 +181,15 @@ class CacheManager:
logger.debug(f"{table!r} has columns: {columns}")
return columns
def set_state(self, table: str, state: str) -> None:
    """Record *state* as the current processing state of *table*.

    Any state previously stored for the same table is overwritten.
    """
    self._states.update({table: state})
def get_states(self) -> dict[str, str]:
    """Return a shallow copy of the table-to-state mapping.

    Mutating the returned dict does not affect the manager's own mapping.
    """
    snapshot = self._states.copy()
    return snapshot
def clear_state(self, table: str) -> None:
    """Forget the stored state of *table*; a table with no state is ignored."""
    try:
        del self._states[table]
    except KeyError:
        # Nothing recorded for this table — clearing is a no-op.
        pass
def load_table(
self,
table: str,
@@ -168,21 +197,55 @@ class CacheManager:
source_conn: sqlite3.Connection,
full: bool = False,
) -> None:
"""Stream the source table into the cache in batches.
Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging
table and swapped in atomically, so peak memory stays bounded (no
``fetchall`` of a huge table) and readers keep seeing the previous copy
until the swap. Concurrent loads are serialized by ``_load_lock``; the
connection lock is only held for the brief per-batch inserts and the swap.
"""
cols = ", ".join(columns)
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
clean_rows = [coerce_row(row) for row in rows]
col_defs = ", ".join(f"{c} TEXT" for c in columns)
placeholders = ", ".join("?" * len(columns))
staging = f"{table}__sqlmem_load"
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
col_defs = ", ".join(f"{c} TEXT" for c in columns)
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
placeholders = ", ".join("?" * len(columns))
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows)
self._mem_conn.commit()
with self._load_lock:
self.set_state(table, TableState.LOADING)
logger.info(f"Fetching {table!r} columns [{cols}] from source DB (batch={FETCH_BATCH_SIZE})")
try:
cursor = source_conn.execute(f"SELECT {cols} FROM {table}")
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
self._mem_conn.execute(f"CREATE TABLE {staging} ({col_defs})")
self._mem_conn.commit()
self.mark_table_refreshed(table, len(rows), full)
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
total = 0
insert_sql = f"INSERT INTO {staging} VALUES ({placeholders})"
while True:
batch = cursor.fetchmany(FETCH_BATCH_SIZE) # network outside _lock
if not batch:
break
clean = [coerce_row(row) for row in batch]
with self._lock:
self._mem_conn.executemany(insert_sql, clean)
self._mem_conn.commit()
total += len(batch)
with self._lock: # atomic swap — readers see old or new, never partial
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
self._mem_conn.execute(f"ALTER TABLE {staging} RENAME TO {table}")
self._mem_conn.commit()
except BaseException:
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
self._mem_conn.commit()
self.set_state(table, TableState.ERROR)
raise
self.mark_table_refreshed(table, total, full)
self.set_state(table, TableState.READY)
logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
def execute_in_memory(
self, sql: str, params: tuple | list | dict | None = None
@@ -232,7 +295,7 @@ class CacheManager:
return row[0] if row else None
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
"""Insert-or-replace *rows* by the table's unique key, then refresh row_count."""
"""Insert-or-replace one batch of *rows* by the table's unique key."""
col_list = ", ".join(columns)
placeholders = ", ".join("?" * len(columns))
clean_rows = [coerce_row(row) for row in rows]
@@ -242,8 +305,10 @@ class CacheManager:
clean_rows,
)
self._mem_conn.commit()
count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
self.mark_table_refreshed(table, count, self.is_table_full(table))
def count_rows(self, table: str) -> int:
    """Return the number of rows in *table* of the in-memory database.

    *table* is interpolated into the SQL text as-is. Returns 0 if the query
    yields no row.
    """
    cursor = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}")
    first = cursor.fetchone()
    if not first:
        return 0
    return int(first[0])
def reset(self) -> None:
"""Wipe the entire cache — every cached table plus the on-disk file."""
@@ -262,6 +327,7 @@ class CacheManager:
self._mem_conn.execute("DELETE FROM _sqlmem_tables")
self._mem_conn.execute("DELETE FROM _sqlmem_columns")
self._mem_conn.commit()
self._states.clear()
try:
if self._db_path.exists():
self._db_path.unlink()