Batch large-table loads to bound memory and add per-table state to stats
This commit is contained in:
+82
-16
@@ -9,6 +9,8 @@ from loguru import logger
|
||||
|
||||
import sqlmem._meta as _meta
|
||||
from ._coerce import coerce_params, coerce_row
|
||||
from .config import FETCH_BATCH_SIZE
|
||||
from .stats import TableState
|
||||
|
||||
SCHEMA_VERSION = 3
|
||||
|
||||
@@ -18,11 +20,14 @@ class CacheManager:
|
||||
self._db_path = db_path
|
||||
self._backup_interval = backup_interval
|
||||
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
self._lock = threading.Lock()
|
||||
self._lock = threading.Lock() # serializes connection access
|
||||
self._load_lock = threading.Lock() # serializes full table loads
|
||||
self._states: dict[str, str] = {} # table → live processing state
|
||||
self._closed = False
|
||||
|
||||
self._ensure_meta_tables()
|
||||
self._load_from_disk()
|
||||
self._drop_orphan_staging()
|
||||
self._start_backup_thread()
|
||||
|
||||
atexit.register(self._backup_to_disk)
|
||||
@@ -88,6 +93,21 @@ class CacheManager:
|
||||
finally:
|
||||
disk_conn.close()
|
||||
|
||||
def _drop_orphan_staging(self) -> None:
|
||||
"""Drop staging tables left by a load that was interrupted (e.g. crash mid-load)."""
|
||||
orphans = [
|
||||
r[0]
|
||||
for r in self._mem_conn.execute(
|
||||
r"SELECT name FROM sqlite_master "
|
||||
r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
|
||||
).fetchall()
|
||||
]
|
||||
for name in orphans:
|
||||
logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {name}")
|
||||
if orphans:
|
||||
self._mem_conn.commit()
|
||||
|
||||
def _backup_to_disk(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
@@ -161,6 +181,15 @@ class CacheManager:
|
||||
logger.debug(f"{table!r} has columns: {columns}")
|
||||
return columns
|
||||
|
||||
def set_state(self, table: str, state: str) -> None:
|
||||
self._states[table] = state
|
||||
|
||||
def get_states(self) -> dict[str, str]:
|
||||
return dict(self._states)
|
||||
|
||||
def clear_state(self, table: str) -> None:
|
||||
self._states.pop(table, None)
|
||||
|
||||
def load_table(
|
||||
self,
|
||||
table: str,
|
||||
@@ -168,21 +197,55 @@ class CacheManager:
|
||||
source_conn: sqlite3.Connection,
|
||||
full: bool = False,
|
||||
) -> None:
|
||||
"""Stream the source table into the cache in batches.
|
||||
|
||||
Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging
|
||||
table and swapped in atomically, so peak memory stays bounded (no
|
||||
``fetchall`` of a huge table) and readers keep seeing the previous copy
|
||||
until the swap. Concurrent loads are serialized by ``_load_lock``; the
|
||||
connection lock is only held for the brief per-batch inserts and the swap.
|
||||
"""
|
||||
cols = ", ".join(columns)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
|
||||
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
|
||||
clean_rows = [coerce_row(row) for row in rows]
|
||||
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
staging = f"{table}__sqlmem_load"
|
||||
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
||||
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows)
|
||||
self._mem_conn.commit()
|
||||
with self._load_lock:
|
||||
self.set_state(table, TableState.LOADING)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB (batch={FETCH_BATCH_SIZE})")
|
||||
try:
|
||||
cursor = source_conn.execute(f"SELECT {cols} FROM {table}")
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
|
||||
self._mem_conn.execute(f"CREATE TABLE {staging} ({col_defs})")
|
||||
self._mem_conn.commit()
|
||||
|
||||
self.mark_table_refreshed(table, len(rows), full)
|
||||
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
|
||||
total = 0
|
||||
insert_sql = f"INSERT INTO {staging} VALUES ({placeholders})"
|
||||
while True:
|
||||
batch = cursor.fetchmany(FETCH_BATCH_SIZE) # network outside _lock
|
||||
if not batch:
|
||||
break
|
||||
clean = [coerce_row(row) for row in batch]
|
||||
with self._lock:
|
||||
self._mem_conn.executemany(insert_sql, clean)
|
||||
self._mem_conn.commit()
|
||||
total += len(batch)
|
||||
|
||||
with self._lock: # atomic swap — readers see old or new, never partial
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
self._mem_conn.execute(f"ALTER TABLE {staging} RENAME TO {table}")
|
||||
self._mem_conn.commit()
|
||||
except BaseException:
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
|
||||
self._mem_conn.commit()
|
||||
self.set_state(table, TableState.ERROR)
|
||||
raise
|
||||
|
||||
self.mark_table_refreshed(table, total, full)
|
||||
self.set_state(table, TableState.READY)
|
||||
logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
|
||||
|
||||
def execute_in_memory(
|
||||
self, sql: str, params: tuple | list | dict | None = None
|
||||
@@ -232,7 +295,7 @@ class CacheManager:
|
||||
return row[0] if row else None
|
||||
|
||||
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
|
||||
"""Insert-or-replace *rows* by the table's unique key, then refresh row_count."""
|
||||
"""Insert-or-replace one batch of *rows* by the table's unique key."""
|
||||
col_list = ", ".join(columns)
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
clean_rows = [coerce_row(row) for row in rows]
|
||||
@@ -242,8 +305,10 @@ class CacheManager:
|
||||
clean_rows,
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
|
||||
self.mark_table_refreshed(table, count, self.is_table_full(table))
|
||||
|
||||
def count_rows(self, table: str) -> int:
|
||||
row = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
|
||||
return int(row[0]) if row else 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Wipe the entire cache — every cached table plus the on-disk file."""
|
||||
@@ -262,6 +327,7 @@ class CacheManager:
|
||||
self._mem_conn.execute("DELETE FROM _sqlmem_tables")
|
||||
self._mem_conn.execute("DELETE FROM _sqlmem_columns")
|
||||
self._mem_conn.commit()
|
||||
self._states.clear()
|
||||
try:
|
||||
if self._db_path.exists():
|
||||
self._db_path.unlink()
|
||||
|
||||
Reference in New Issue
Block a user