From 286a5f207d996d5da9340e1ad8fe35c76113c2c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Doubravsk=C3=BD?= Date: Fri, 5 Jun 2026 14:44:07 +0200 Subject: [PATCH] Batch large-table loads to bound memory and add per-table state to stats --- CHANGELOG.md | 14 +++++ README.md | 27 +++++++++- pyproject.toml | 2 +- src/sqlmem/cache.py | 98 +++++++++++++++++++++++++++------ src/sqlmem/config.py | 2 + src/sqlmem/delta.py | 30 ++++++++--- src/sqlmem/engine.py | 24 ++++++++- src/sqlmem/stats.py | 25 ++++++++- tests/test_coerce.py | 12 ++++- tests/test_large.py | 105 ++++++++++++++++++++++++++++++++++++ tests/test_stats.py | 126 +++++++++++++++++++++++++++++++++++++++++++ 11 files changed, 436 insertions(+), 29 deletions(-) create mode 100644 tests/test_large.py create mode 100644 tests/test_stats.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fa5e33e..d1dfb71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,20 @@ All notable changes to this project will be documented in this file. --- +## [1.5.0] - 2026-06-05 + +### Added +- **Per-table processing state in `stats`** — `TableStats` now carries `state` (`loading` / `refreshing` / `ready` / `stale` / `error`) and `tracking` (`delta` / `ttl` / `static`), so callers can see whether each table is up to date or being processed. In-progress first loads and failed loads also surface in `stats.tables`. +- `SQLMEM_FETCH_BATCH` env var (default `10000`) — rows fetched per batch when loading a table. + +### Changed +- `pyproject.toml` — bumped version to `1.5.0` +- **Large-table loads are streamed in batches** — `load_table` no longer `fetchall()`s the whole table (which double-buffered every row in Python and could OOM/crash on tens of millions of rows). Rows are now fetched `SQLMEM_FETCH_BATCH` at a time into a staging table and swapped in atomically, so peak memory stays bounded, the previous copy stays queryable during a reload, and the network fetch no longer holds the cache lock. Delta catch-ups are streamed the same way. +- Orphan staging tables left by an interrupted load (crash/backup mid-load) are dropped on startup. +- Delta upserts compute `row_count` once per refresh instead of a full `COUNT(*)` after every batch (avoids O(rows×batches) work on large catch-ups). + +--- + ## [1.4.0] - 2026-06-05 ### Fixed diff --git a/README.md b/README.md index c057aee..9408ad3 100644 --- a/README.md +++ b/README.md @@ -245,9 +245,33 @@ Use `reset()` after a **structural change** in the source (columns added/removed stats = engine.stats # Stats snapshot print(stats.hits, stats.misses, stats.refetches) for name, t in stats.tables.items(): - print(name, t.rows, t.columns, t.last_refresh) + print(name, t.rows, t.state, t.tracking, t.last_refresh) ``` +Each `TableStats` reports a live processing **state** and how the table is kept fresh (**tracking**): + +| `state` | Meaning | +|---|---| +| `loading` | a full load is in progress | +| `refreshing` | an incremental (delta) refresh is in progress | +| `ready` | cached and idle (up to date) | +| `stale` | a TTL table whose cache has expired; reloads on next access | +| `error` | the last load failed | + +| `tracking` | Meaning | +|---|---| +| `delta` | kept in sync incrementally via a change column | +| `ttl` | full-reloaded when older than its TTL | +| `static` | loaded on demand, never auto-refreshed | + +## Memory and very large tables + +The cache is **in-memory SQLite**, so a cached table lives in RAM — it must fit in available memory. To keep huge tables manageable: + +- **Loads are streamed in batches** (`SQLMEM_FETCH_BATCH` rows at a time, default 10 000) into a staging table and swapped in atomically. A multi-million-row table never gets fully materialized in Python at once, so the load doesn't spike memory or crash the process, and readers keep seeing the previous copy until the swap completes. +- Use **[delta refresh](#incremental-delta-refresh)** for large tables that have a change column — after the first load only changed rows are pulled, so restarts and refreshes don't re-read the whole table. +- A **single query that returns a huge result set** (e.g. `SELECT *` over a multi-million-row cached table) still materializes that result as a list of dicts; bound it with a `WHERE`/`LIMIT` rather than selecting everything. + ## Configuration Set via environment variables or a `.env` file: @@ -259,6 +283,7 @@ Set via environment variables or a `.env` file: | `SQLMEM_BACKUP_INTERVAL` | `3600` | Disk backup interval in seconds | | `SQLMEM_SQL_DIALECT` | `tsql` | sqlglot dialect used to parse incoming SQL (e.g. `tsql`, `postgres`, `mysql`) | | `SQLMEM_REFRESH_INTERVAL` | `300` | background refresh tick (seconds) — delta pulls and proactive TTL reloads | +| `SQLMEM_FETCH_BATCH` | `10000` | rows fetched per batch when loading a table — caps peak memory for huge tables | ## Exceptions diff --git a/pyproject.toml b/pyproject.toml index f44f894..c1d5b12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlmem" -version = "1.4.0" +version = "1.5.0" description = "" authors = [ {name = "jan.doubravsky@gmail.com"} diff --git a/src/sqlmem/cache.py b/src/sqlmem/cache.py index e00e907..c94772a 100644 --- a/src/sqlmem/cache.py +++ b/src/sqlmem/cache.py @@ -9,6 +9,8 @@ from loguru import logger import sqlmem._meta as _meta from ._coerce import coerce_params, coerce_row +from .config import FETCH_BATCH_SIZE +from .stats import TableState SCHEMA_VERSION = 3 @@ -18,11 +20,14 @@ class CacheManager: self._db_path = db_path self._backup_interval = backup_interval self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False) - self._lock = threading.Lock() + self._lock = threading.Lock() # serializes connection access + self._load_lock = threading.Lock() # serializes full table loads + self._states: dict[str, str] = {} # table → live processing state self._closed = False self._ensure_meta_tables() self._load_from_disk() + self._drop_orphan_staging() self._start_backup_thread() atexit.register(self._backup_to_disk) @@ -88,6 +93,21 @@ class CacheManager: finally: disk_conn.close() + def _drop_orphan_staging(self) -> None: + """Drop staging tables left by a load that was interrupted (e.g. crash mid-load).""" + orphans = [ + r[0] + for r in self._mem_conn.execute( + r"SELECT name FROM sqlite_master " + r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'" + ).fetchall() + ] + for name in orphans: + logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.") + self._mem_conn.execute(f"DROP TABLE IF EXISTS {name}") + if orphans: + self._mem_conn.commit() + def _backup_to_disk(self) -> None: if self._closed: return @@ -161,6 +181,15 @@ class CacheManager: logger.debug(f"{table!r} has columns: {columns}") return columns + def set_state(self, table: str, state: str) -> None: + self._states[table] = state + + def get_states(self) -> dict[str, str]: + return dict(self._states) + + def clear_state(self, table: str) -> None: + self._states.pop(table, None) + def load_table( self, table: str, @@ -168,21 +197,55 @@ class CacheManager: source_conn: sqlite3.Connection, full: bool = False, ) -> None: + """Stream the source table into the cache in batches. + + Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging + table and swapped in atomically, so peak memory stays bounded (no + ``fetchall`` of a huge table) and readers keep seeing the previous copy + until the swap. Concurrent loads are serialized by ``_load_lock``; the + connection lock is only held for the brief per-batch inserts and the swap. + """ cols = ", ".join(columns) - logger.info(f"Fetching {table!r} columns [{cols}] from source DB") - rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall() - clean_rows = [coerce_row(row) for row in rows] + col_defs = ", ".join(f"{c} TEXT" for c in columns) + placeholders = ", ".join("?" * len(columns)) + staging = f"{table}__sqlmem_load" - with self._lock: - self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}") - col_defs = ", ".join(f"{c} TEXT" for c in columns) - self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})") - placeholders = ", ".join("?" * len(columns)) - self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows) - self._mem_conn.commit() + with self._load_lock: + self.set_state(table, TableState.LOADING) + logger.info(f"Fetching {table!r} columns [{cols}] from source DB (batch={FETCH_BATCH_SIZE})") + try: + cursor = source_conn.execute(f"SELECT {cols} FROM {table}") + with self._lock: + self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}") + self._mem_conn.execute(f"CREATE TABLE {staging} ({col_defs})") + self._mem_conn.commit() - self.mark_table_refreshed(table, len(rows), full) - logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})") + total = 0 + insert_sql = f"INSERT INTO {staging} VALUES ({placeholders})" + while True: + batch = cursor.fetchmany(FETCH_BATCH_SIZE) # network outside _lock + if not batch: + break + clean = [coerce_row(row) for row in batch] + with self._lock: + self._mem_conn.executemany(insert_sql, clean) + self._mem_conn.commit() + total += len(batch) + + with self._lock: # atomic swap — readers see old or new, never partial + self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}") + self._mem_conn.execute(f"ALTER TABLE {staging} RENAME TO {table}") + self._mem_conn.commit() + except BaseException: + with self._lock: + self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}") + self._mem_conn.commit() + self.set_state(table, TableState.ERROR) + raise + + self.mark_table_refreshed(table, total, full) + self.set_state(table, TableState.READY) + logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})") def execute_in_memory( self, sql: str, params: tuple | list | dict | None = None @@ -232,7 +295,7 @@ class CacheManager: return row[0] if row else None def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None: - """Insert-or-replace *rows* by the table's unique key, then refresh row_count.""" + """Insert-or-replace one batch of *rows* by the table's unique key.""" col_list = ", ".join(columns) placeholders = ", ".join("?" * len(columns)) clean_rows = [coerce_row(row) for row in rows] @@ -242,8 +305,10 @@ class CacheManager: clean_rows, ) self._mem_conn.commit() - count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] - self.mark_table_refreshed(table, count, self.is_table_full(table)) + + def count_rows(self, table: str) -> int: + row = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() + return int(row[0]) if row else 0 def reset(self) -> None: """Wipe the entire cache — every cached table plus the on-disk file.""" @@ -262,6 +327,7 @@ class CacheManager: self._mem_conn.execute("DELETE FROM _sqlmem_tables") self._mem_conn.execute("DELETE FROM _sqlmem_columns") self._mem_conn.commit() + self._states.clear() try: if self._db_path.exists(): self._db_path.unlink() diff --git a/src/sqlmem/config.py b/src/sqlmem/config.py index 482dd60..32d181b 100644 --- a/src/sqlmem/config.py +++ b/src/sqlmem/config.py @@ -11,6 +11,8 @@ CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db")) BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600")) # How often (seconds) the background thread pulls deltas for delta-tracked tables. REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300")) +# Rows fetched per batch when loading a table — caps peak memory for huge tables. +FETCH_BATCH_SIZE = int(os.getenv("SQLMEM_FETCH_BATCH", "10000")) # Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server), # which also accepts ANSI SQL. In-memory queries are always rendered to SQLite. SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql") diff --git a/src/sqlmem/delta.py b/src/sqlmem/delta.py index 0c6a7bd..539ac03 100644 --- a/src/sqlmem/delta.py +++ b/src/sqlmem/delta.py @@ -4,6 +4,8 @@ from dataclasses import dataclass, field from loguru import logger from .cache import CacheManager +from .config import FETCH_BATCH_SIZE +from .stats import TableState @dataclass(frozen=True) @@ -58,21 +60,37 @@ class DeltaRefresher: col_list = ", ".join(columns) if watermark is None: - rows = source_conn.execute(f"SELECT {col_list} FROM {table}").fetchall() + cursor = source_conn.execute(f"SELECT {col_list} FROM {table}") else: - rows = source_conn.execute( + cursor = source_conn.execute( f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?", (watermark,), - ).fetchall() + ) - if not rows: + # Stream the delta in batches so a large catch-up never materializes at once. + total = 0 + self._cache.set_state(table, TableState.REFRESHING) + try: + while True: + batch = cursor.fetchmany(FETCH_BATCH_SIZE) + if not batch: + break + self._cache.upsert_rows(table, columns, batch) + total += len(batch) + finally: + self._cache.set_state(table, TableState.READY) + + if total == 0: logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}") return - self._cache.upsert_rows(table, columns, rows) + # Update row_count / last_refresh once (not per batch) and advance the watermark. + self._cache.mark_table_refreshed( + table, self._cache.count_rows(table), self._cache.is_table_full(table) + ) new_watermark = self._cache.max_value(table, cfg.change_column) self._cache.set_last_synced_at(table, new_watermark) logger.info( - f"Delta refresh {table!r}: {len(rows)} row(s) upserted, " + f"Delta refresh {table!r}: {total} row(s) upserted, " f"watermark {watermark!r} → {new_watermark!r}" ) diff --git a/src/sqlmem/engine.py b/src/sqlmem/engine.py index 981182a..ef7aead 100644 --- a/src/sqlmem/engine.py +++ b/src/sqlmem/engine.py @@ -1,5 +1,6 @@ import sqlite3 import threading +from dataclasses import replace from typing import cast from loguru import logger @@ -12,7 +13,7 @@ from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta from .executor import QueryExecutor from .parser import Params, parse from .registry import ColumnRegistry -from .stats import Stats, StatsCollector +from .stats import Stats, StatsCollector, TableState, TableStats class CachingEngine: @@ -68,8 +69,26 @@ class CachingEngine: @property def stats(self) -> Stats: + states = self._cache.get_states() with self._cache._lock: - return self._stats.snapshot(self._cache.connection) + base = self._stats.snapshot(self._cache.connection, states) + return replace(base, tables={n: self._enrich(n, t) for n, t in base.tables.items()}) + + def _enrich(self, name: str, table_stats: TableStats) -> TableStats: + """Annotate a TableStats with how it is refreshed and TTL staleness.""" + if name in self._delta: + tracking = "delta" + elif name in self._ttl: + tracking = "ttl" + else: + tracking = "static" + + state = table_stats.state + if state == TableState.READY and name in self._ttl: + age = self._cache.seconds_since_refresh(name) + if age is not None and age > self._ttl[name]: + state = TableState.STALE + return replace(table_stats, tracking=tracking, state=state) def execute(self, sql: str, params: Params = None) -> list[dict]: parsed = parse(sql, params) @@ -130,6 +149,7 @@ class CachingEngine: "DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,) ) self._cache.connection.commit() + self._cache.clear_state(table) def reset(self) -> None: """Wipe the whole cache (RAM + cache.db). Use after structural source changes.""" diff --git a/src/sqlmem/stats.py b/src/sqlmem/stats.py index a64affb..5fc81d6 100644 --- a/src/sqlmem/stats.py +++ b/src/sqlmem/stats.py @@ -3,11 +3,23 @@ import threading from dataclasses import dataclass +class TableState: + """Live processing state of a cached table (value of ``TableStats.state``).""" + + LOADING = "loading" # a full load is in progress + REFRESHING = "refreshing" # an incremental (delta) refresh is in progress + READY = "ready" # cached and idle + STALE = "stale" # TTL expired — will reload on next access + ERROR = "error" # the last load failed + + @dataclass(frozen=True) class TableStats: rows: int columns: list[str] last_refresh: str + state: str = TableState.READY + tracking: str = "static" # "delta" | "ttl" | "static" @dataclass(frozen=True) @@ -37,14 +49,19 @@ class StatsCollector: with self._lock: self.refetches += 1 - def snapshot(self, conn: sqlite3.Connection) -> Stats: + def snapshot( + self, conn: sqlite3.Connection, states: dict[str, str] | None = None + ) -> Stats: + states = states or {} with self._lock: hits, misses, refetches = self.hits, self.misses, self.refetches tables: dict[str, TableStats] = {} + cached: set[str] = set() for table_name, row_count, last_refresh in conn.execute( "SELECT table_name, row_count, last_refresh_at FROM _sqlmem_tables" ).fetchall(): + cached.add(table_name) columns = [ r[0] for r in conn.execute( @@ -56,6 +73,12 @@ class StatsCollector: rows=row_count or 0, columns=columns, last_refresh=last_refresh, + state=states.get(table_name, TableState.READY), ) + # Surface tables that are mid-first-load (not yet in _sqlmem_tables) or failed. + for name, state in states.items(): + if name not in cached and state in (TableState.LOADING, TableState.ERROR): + tables[name] = TableStats(rows=0, columns=[], last_refresh="", state=state) + return Stats(hits=hits, misses=misses, refetches=refetches, tables=tables) diff --git a/tests/test_coerce.py b/tests/test_coerce.py index 459a708..3914e60 100644 --- a/tests/test_coerce.py +++ b/tests/test_coerce.py @@ -10,11 +10,19 @@ from sqlmem.cache import CacheManager class _FakeCursor: def __init__(self, rows): - self._rows = rows + self._rows = list(rows) + self._pos = 0 self.description = None def fetchall(self): - return self._rows + out = self._rows[self._pos :] + self._pos = len(self._rows) + return out + + def fetchmany(self, size): + out = self._rows[self._pos : self._pos + size] + self._pos += len(out) + return out class FakeSource: diff --git a/tests/test_large.py b/tests/test_large.py new file mode 100644 index 0000000..17fab97 --- /dev/null +++ b/tests/test_large.py @@ -0,0 +1,105 @@ +import sqlite3 + +import pytest + +from sqlmem.cache import CacheManager + + +@pytest.fixture +def source_conn(): + conn = sqlite3.connect(":memory:") + conn.execute("CREATE TABLE big (id TEXT, val TEXT)") + conn.executemany( + "INSERT INTO big VALUES (?, ?)", [(str(i), f"v{i}") for i in range(5)] + ) + conn.commit() + yield conn + conn.close() + + +@pytest.fixture +def cache(tmp_path): + c = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999) + yield c + c.close() + + +@pytest.fixture +def small_batches(monkeypatch): + # Force multiple fetch batches over the 5 source rows. + monkeypatch.setattr("sqlmem.cache.FETCH_BATCH_SIZE", 2) + + +def test_batched_load_loads_all_rows(cache, source_conn, small_batches): + cache.load_table("big", ["id", "val"], source_conn) + _, rows = cache.execute_in_memory( + "SELECT id, val FROM big ORDER BY CAST(id AS INTEGER)" + ) + assert len(rows) == 5 + assert rows[0] == ("0", "v0") + assert rows[-1] == ("4", "v4") + + +def test_no_staging_table_left_behind(cache, source_conn, small_batches): + cache.load_table("big", ["id", "val"], source_conn) + names = { + r[0] + for r in cache.connection.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ).fetchall() + } + assert "big" in names + assert not any(n.endswith("__sqlmem_load") for n in names) + + +def test_reload_replaces_data_atomically(cache, source_conn, small_batches): + cache.load_table("big", ["id", "val"], source_conn) + source_conn.execute("DELETE FROM big") + source_conn.execute("INSERT INTO big VALUES ('99', 'new')") + source_conn.commit() + cache.load_table("big", ["id", "val"], source_conn) + _, rows = cache.execute_in_memory("SELECT id, val FROM big") + assert rows == [("99", "new")] + + +def test_load_sets_ready_state(cache, source_conn): + cache.load_table("big", ["id", "val"], source_conn) + assert cache.get_states()["big"] == "ready" + + +def test_orphan_staging_dropped_on_startup(tmp_path, source_conn): + # Simulate a crash mid-load: a staging table persisted into cache.db. + db_path = tmp_path / "cache.db" + c1 = CacheManager(db_path=db_path, backup_interval=9999) + c1.load_table("big", ["id", "val"], source_conn) + c1.connection.execute("CREATE TABLE big__sqlmem_load (id TEXT, val TEXT)") + c1.connection.commit() + c1.close() # backup writes the staging table to disk + + c2 = CacheManager(db_path=db_path, backup_interval=9999) + names = { + r[0] + for r in c2.connection.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ).fetchall() + } + c2.close() + assert "big" in names # real table survives + assert not any(n.endswith("__sqlmem_load") for n in names) # orphan cleaned + + +def test_failed_load_sets_error_state_and_cleans_staging(cache): + empty_source = sqlite3.connect(":memory:") # has no 'big' table + try: + with pytest.raises(sqlite3.OperationalError): + cache.load_table("big", ["id"], empty_source) + assert cache.get_states()["big"] == "error" + names = { + r[0] + for r in cache.connection.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ).fetchall() + } + assert not any(n.endswith("__sqlmem_load") for n in names) + finally: + empty_source.close() diff --git a/tests/test_stats.py b/tests/test_stats.py new file mode 100644 index 0000000..fce74a0 --- /dev/null +++ b/tests/test_stats.py @@ -0,0 +1,126 @@ +import sqlite3 +import threading + +import pytest +from sqlalchemy import create_engine + +import sqlmem.engine as eng_mod +from sqlmem import CachingEngine, DeltaConfig +from sqlmem.cache import CacheManager +from sqlmem.stats import StatsCollector + + +@pytest.fixture +def source_engine(tmp_path): + db_path = tmp_path / "source.db" + conn = sqlite3.connect(db_path) + conn.executescript( + """ + CREATE TABLE products (id TEXT PRIMARY KEY, name TEXT, changed TEXT); + INSERT INTO products VALUES ('1', 'Widget', '2026-06-01 10:00:00'); + """ + ) + conn.commit() + conn.close() + engine = create_engine(f"sqlite:///{db_path}") + yield engine + engine.dispose() + + +@pytest.fixture +def patched_cache(tmp_path, monkeypatch): + monkeypatch.setattr(eng_mod, "CACHE_DB_PATH", tmp_path / "cache.db") + monkeypatch.setattr(eng_mod, "BACKUP_INTERVAL_SECONDS", 9999) + + +def test_static_table_state_and_tracking(source_engine, patched_cache): + engine = CachingEngine(source_engine) + engine.execute("SELECT id, name FROM products") + s = engine.stats.tables["products"] + assert s.state == "ready" + assert s.tracking == "static" + assert s.rows == 1 + engine.close() + + +def test_delta_table_tracking(source_engine, patched_cache): + engine = CachingEngine( + source_engine, delta={"products": DeltaConfig("changed", ["id"])} + ) + engine.execute("SELECT id, name FROM products") + s = engine.stats.tables["products"] + assert s.tracking == "delta" + assert s.state == "ready" + engine.close() + + +def test_ttl_table_reports_stale(source_engine, patched_cache): + engine = CachingEngine(source_engine, ttl={"products": 0}) + engine.execute("SELECT id, name FROM products") + s = engine.stats.tables["products"] + assert s.tracking == "ttl" + assert s.state == "stale" # ttl=0 → already past its max age + engine.close() + + +def test_counters_still_reported(source_engine, patched_cache): + engine = CachingEngine(source_engine) + engine.execute("SELECT id, name FROM products") + engine.execute("SELECT id, name FROM products") + stats = engine.stats + assert stats.misses == 1 + assert stats.hits == 1 + engine.close() + + +# --- a table being loaded for the first time shows up as "loading" ---------- + + +def test_snapshot_surfaces_a_loading_table(tmp_path): + cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999) + snap = StatsCollector().snapshot(cache.connection, {"pending": "loading"}) + assert "pending" in snap.tables + assert snap.tables["pending"].state == "loading" + assert snap.tables["pending"].rows == 0 + cache.close() + + +def test_loading_state_visible_from_another_thread_during_load(tmp_path): + """A first load in progress is observable as 'loading' from another thread.""" + cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999) + started = threading.Event() + release = threading.Event() + + class BlockingCursor: + def __init__(self, rows): + self._rows = list(rows) + self._done = False + + def fetchmany(self, size): + if self._done: + return [] + started.set() + release.wait(5) # hold the load open until the test releases it + self._done = True + return self._rows + + class BlockingSource: + def execute(self, sql): + return BlockingCursor([("1", "alice")]) + + loader = threading.Thread( + target=cache.load_table, args=("users", ["id", "name"], BlockingSource()) + ) + loader.start() + try: + assert started.wait(5), "load did not start" + # mid-load: not yet in _sqlmem_tables, but surfaced as loading + assert cache.get_states()["users"] == "loading" + snap = StatsCollector().snapshot(cache.connection, cache.get_states()) + assert snap.tables["users"].state == "loading" + finally: + release.set() + loader.join(5) + assert not loader.is_alive() + assert cache.get_states()["users"] == "ready" + cache.close()