Batch large-table loads to bound memory and add per-table state to stats

Jan Doubravský
2026-06-05 14:44:07 +02:00
parent 85bb84a1a6
commit 286a5f207d
11 changed files with 436 additions and 29 deletions
+14
@@ -6,6 +6,20 @@ All notable changes to this project will be documented in this file.
--- ---
## [1.5.0] - 2026-06-05
### Added
- **Per-table processing state in `stats`** — `TableStats` now carries `state` (`loading` / `refreshing` / `ready` / `stale` / `error`) and `tracking` (`delta` / `ttl` / `static`), so callers can see whether each table is up to date or being processed. In-progress first loads and failed loads also surface in `stats.tables`.
- `SQLMEM_FETCH_BATCH` env var (default `10000`) — rows fetched per batch when loading a table.
### Changed
- `pyproject.toml` — bumped version to `1.5.0`
- **Large-table loads are streamed in batches** — `load_table` no longer calls `fetchall()` on the whole table, which double-buffered every row in Python and could run out of memory or crash on tens of millions of rows. Rows are now fetched `SQLMEM_FETCH_BATCH` at a time into a staging table and swapped in atomically. Peak memory stays bounded, the previous copy stays queryable during a reload, and the network fetch no longer holds the cache lock. Delta catch-ups are streamed the same way.
- Orphan staging tables left by an interrupted load (crash/backup mid-load) are dropped on startup.
- Delta upserts compute `row_count` once per refresh instead of a full `COUNT(*)` after every batch (avoids O(rows×batches) work on large catch-ups).
---
## [1.4.0] - 2026-06-05 ## [1.4.0] - 2026-06-05
### Fixed ### Fixed
+26 -1
@@ -245,9 +245,33 @@ Use `reset()` after a **structural change** in the source (columns added/removed
stats = engine.stats # Stats snapshot stats = engine.stats # Stats snapshot
print(stats.hits, stats.misses, stats.refetches) print(stats.hits, stats.misses, stats.refetches)
for name, t in stats.tables.items(): for name, t in stats.tables.items():
print(name, t.rows, t.columns, t.last_refresh) print(name, t.rows, t.state, t.tracking, t.last_refresh)
``` ```
Each `TableStats` reports a live processing **state** and how the table is kept fresh (**tracking**):
| `state` | Meaning |
|---|---|
| `loading` | a full load is in progress |
| `refreshing` | an incremental (delta) refresh is in progress |
| `ready` | cached and idle (up to date) |
| `stale` | a TTL table whose cache has expired; reloads on next access |
| `error` | the last load failed |
| `tracking` | Meaning |
|---|---|
| `delta` | kept in sync incrementally via a change column |
| `ttl` | full-reloaded when older than its TTL |
| `static` | loaded on demand, never auto-refreshed |
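The two tables above can be turned into a simple readiness check. The following is a minimal sketch, not part of the library: `TableStatsSketch` is a stand-in that mirrors only the `rows`, `state`, and `tracking` fields, and `group_by_state` / `all_ready` are hypothetical helper names. In real use one would presumably pass `engine.stats.tables` instead of the hand-built dict.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TableStatsSketch:
    """Stand-in for TableStats (assumption: only these fields are needed here)."""
    rows: int
    state: str = "ready"
    tracking: str = "static"


def group_by_state(tables):
    """Return {state: sorted table names} for a name -> stats mapping."""
    grouped = {}
    for name, table in tables.items():
        grouped.setdefault(table.state, []).append(name)
    return {state: sorted(names) for state, names in grouped.items()}


def all_ready(tables):
    """True when every table is cached and idle."""
    return all(table.state == "ready" for table in tables.values())


# Hand-built example data; the table names are illustrative only.
tables = {
    "products": TableStatsSketch(rows=120),
    "orders": TableStatsSketch(rows=0, state="loading", tracking="delta"),
    "prices": TableStatsSketch(rows=40, state="stale", tracking="ttl"),
}
print(group_by_state(tables))
print(all_ready(tables))
```

A health endpoint or a startup gate could call `all_ready` in a loop and report the grouped names while it waits.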
## Memory and very large tables
The cache is **in-memory SQLite**, so a cached table lives in RAM — it must fit in available memory. To keep huge tables manageable:
- **Loads are streamed in batches** (`SQLMEM_FETCH_BATCH` rows at a time, default 10 000) into a staging table and swapped in atomically. A multi-million-row table never gets fully materialized in Python at once, so the load doesn't spike memory or crash the process, and readers keep seeing the previous copy until the swap completes.
- Use **[delta refresh](#incremental-delta-refresh)** for large tables that have a change column — after the first load only changed rows are pulled, so restarts and refreshes don't re-read the whole table.
- A **single query that returns a huge result set** (e.g. `SELECT *` over a multi-million-row cached table) still materializes that result as a list of dicts; bound it with a `WHERE`/`LIMIT` rather than selecting everything.
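The batched load described in the first bullet can be sketched with the standard `sqlite3` module alone. This is an illustration under stated assumptions, not the library's implementation: `load_in_batches` and the `__staging` suffix are hypothetical names, both connections are plain SQLite, and the sketch omits the locking, type coercion, and per-batch commits that the real loader performs.

```python
import sqlite3


def load_in_batches(source, cache, table, columns, batch_size=10_000):
    """Copy `table` from `source` into `cache` via a staging table, then swap."""
    cols = ", ".join(columns)
    col_defs = ", ".join(f"{c} TEXT" for c in columns)
    placeholders = ", ".join("?" * len(columns))
    staging = f"{table}__staging"  # hypothetical staging name

    cache.execute(f"DROP TABLE IF EXISTS {staging}")
    cache.execute(f"CREATE TABLE {staging} ({col_defs})")

    cursor = source.execute(f"SELECT {cols} FROM {table}")
    total = 0
    while True:
        batch = cursor.fetchmany(batch_size)  # at most batch_size rows in memory
        if not batch:
            break
        cache.executemany(f"INSERT INTO {staging} VALUES ({placeholders})", batch)
        total += len(batch)

    # Swap: the old copy stays in place until the staging table is complete.
    cache.execute(f"DROP TABLE IF EXISTS {table}")
    cache.execute(f"ALTER TABLE {staging} RENAME TO {table}")
    cache.commit()
    return total


source = sqlite3.connect(":memory:")
source.execute("CREATE TABLE big (id TEXT, val TEXT)")
source.executemany("INSERT INTO big VALUES (?, ?)", [(str(i), f"v{i}") for i in range(5)])
source.commit()

cache = sqlite3.connect(":memory:")
print(load_in_batches(source, cache, "big", ["id", "val"], batch_size=2))
```

The point of the staging table is that a failed or interrupted load only ever damages the staging copy, so the previously cached table remains queryable.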
## Configuration ## Configuration
Set via environment variables or a `.env` file: Set via environment variables or a `.env` file:
@@ -259,6 +283,7 @@ Set via environment variables or a `.env` file:
| `SQLMEM_BACKUP_INTERVAL` | `3600` | Disk backup interval in seconds | | `SQLMEM_BACKUP_INTERVAL` | `3600` | Disk backup interval in seconds |
| `SQLMEM_SQL_DIALECT` | `tsql` | sqlglot dialect used to parse incoming SQL (e.g. `tsql`, `postgres`, `mysql`) | | `SQLMEM_SQL_DIALECT` | `tsql` | sqlglot dialect used to parse incoming SQL (e.g. `tsql`, `postgres`, `mysql`) |
| `SQLMEM_REFRESH_INTERVAL` | `300` | background refresh tick (seconds) — delta pulls and proactive TTL reloads | | `SQLMEM_REFRESH_INTERVAL` | `300` | background refresh tick (seconds) — delta pulls and proactive TTL reloads |
| `SQLMEM_FETCH_BATCH` | `10000` | rows fetched per batch when loading a table — caps peak memory for huge tables |
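In this commit `SQLMEM_FETCH_BATCH` is read once into a module-level constant when `sqlmem.config` is imported, so it has to be set before the package is imported. The sketch below shows the same lookup with an added sanity check; `fetch_batch_size` is a hypothetical helper, and the rejection of values below 1 is an assumption of this example rather than documented library behavior.

```python
import os


def fetch_batch_size(environ=os.environ, default=10_000):
    """Read SQLMEM_FETCH_BATCH from `environ`, falling back to `default`."""
    raw = environ.get("SQLMEM_FETCH_BATCH")
    size = int(raw) if raw else default
    if size < 1:
        # Assumption of this sketch: a non-positive batch size is a config error.
        raise ValueError(f"SQLMEM_FETCH_BATCH must be >= 1, got {size}")
    return size


print(fetch_batch_size({"SQLMEM_FETCH_BATCH": "2500"}))
```

A smaller batch lowers peak memory per fetch at the cost of more round trips to the source database.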
## Exceptions ## Exceptions
+1 -1
@@ -1,6 +1,6 @@
[project] [project]
name = "sqlmem" name = "sqlmem"
version = "1.4.0" version = "1.5.0"
description = "" description = ""
authors = [ authors = [
{name = "jan.doubravsky@gmail.com"} {name = "jan.doubravsky@gmail.com"}
+82 -16
@@ -9,6 +9,8 @@ from loguru import logger
import sqlmem._meta as _meta import sqlmem._meta as _meta
from ._coerce import coerce_params, coerce_row from ._coerce import coerce_params, coerce_row
from .config import FETCH_BATCH_SIZE
from .stats import TableState
SCHEMA_VERSION = 3 SCHEMA_VERSION = 3
@@ -18,11 +20,14 @@ class CacheManager:
self._db_path = db_path self._db_path = db_path
self._backup_interval = backup_interval self._backup_interval = backup_interval
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False) self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
self._lock = threading.Lock() self._lock = threading.Lock() # serializes connection access
self._load_lock = threading.Lock() # serializes full table loads
self._states: dict[str, str] = {} # table → live processing state
self._closed = False self._closed = False
self._ensure_meta_tables() self._ensure_meta_tables()
self._load_from_disk() self._load_from_disk()
self._drop_orphan_staging()
self._start_backup_thread() self._start_backup_thread()
atexit.register(self._backup_to_disk) atexit.register(self._backup_to_disk)
@@ -88,6 +93,21 @@ class CacheManager:
finally: finally:
disk_conn.close() disk_conn.close()
def _drop_orphan_staging(self) -> None:
"""Drop staging tables left by a load that was interrupted (e.g. crash mid-load)."""
orphans = [
r[0]
for r in self._mem_conn.execute(
r"SELECT name FROM sqlite_master "
r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
).fetchall()
]
for name in orphans:
logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
self._mem_conn.execute(f"DROP TABLE IF EXISTS {name}")
if orphans:
self._mem_conn.commit()
def _backup_to_disk(self) -> None: def _backup_to_disk(self) -> None:
if self._closed: if self._closed:
return return
@@ -161,6 +181,15 @@ class CacheManager:
logger.debug(f"{table!r} has columns: {columns}") logger.debug(f"{table!r} has columns: {columns}")
return columns return columns
def set_state(self, table: str, state: str) -> None:
self._states[table] = state
def get_states(self) -> dict[str, str]:
return dict(self._states)
def clear_state(self, table: str) -> None:
self._states.pop(table, None)
def load_table( def load_table(
self, self,
table: str, table: str,
@@ -168,21 +197,55 @@ class CacheManager:
source_conn: sqlite3.Connection, source_conn: sqlite3.Connection,
full: bool = False, full: bool = False,
) -> None: ) -> None:
"""Stream the source table into the cache in batches.
Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging
table and swapped in atomically, so peak memory stays bounded (no
``fetchall`` of a huge table) and readers keep seeing the previous copy
until the swap. Concurrent loads are serialized by ``_load_lock``; the
connection lock is only held for the brief per-batch inserts and the swap.
"""
cols = ", ".join(columns) cols = ", ".join(columns)
logger.info(f"Fetching {table!r} columns [{cols}] from source DB") col_defs = ", ".join(f"{c} TEXT" for c in columns)
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall() placeholders = ", ".join("?" * len(columns))
clean_rows = [coerce_row(row) for row in rows] staging = f"{table}__sqlmem_load"
with self._lock: with self._load_lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}") self.set_state(table, TableState.LOADING)
col_defs = ", ".join(f"{c} TEXT" for c in columns) logger.info(f"Fetching {table!r} columns [{cols}] from source DB (batch={FETCH_BATCH_SIZE})")
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})") try:
placeholders = ", ".join("?" * len(columns)) cursor = source_conn.execute(f"SELECT {cols} FROM {table}")
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows) with self._lock:
self._mem_conn.commit() self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
self._mem_conn.execute(f"CREATE TABLE {staging} ({col_defs})")
self._mem_conn.commit()
self.mark_table_refreshed(table, len(rows), full) total = 0
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})") insert_sql = f"INSERT INTO {staging} VALUES ({placeholders})"
while True:
batch = cursor.fetchmany(FETCH_BATCH_SIZE) # network outside _lock
if not batch:
break
clean = [coerce_row(row) for row in batch]
with self._lock:
self._mem_conn.executemany(insert_sql, clean)
self._mem_conn.commit()
total += len(batch)
with self._lock: # atomic swap — readers see old or new, never partial
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
self._mem_conn.execute(f"ALTER TABLE {staging} RENAME TO {table}")
self._mem_conn.commit()
except BaseException:
with self._lock:
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
self._mem_conn.commit()
self.set_state(table, TableState.ERROR)
raise
self.mark_table_refreshed(table, total, full)
self.set_state(table, TableState.READY)
logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
def execute_in_memory( def execute_in_memory(
self, sql: str, params: tuple | list | dict | None = None self, sql: str, params: tuple | list | dict | None = None
@@ -232,7 +295,7 @@ class CacheManager:
return row[0] if row else None return row[0] if row else None
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None: def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
"""Insert-or-replace *rows* by the table's unique key, then refresh row_count.""" """Insert-or-replace one batch of *rows* by the table's unique key."""
col_list = ", ".join(columns) col_list = ", ".join(columns)
placeholders = ", ".join("?" * len(columns)) placeholders = ", ".join("?" * len(columns))
clean_rows = [coerce_row(row) for row in rows] clean_rows = [coerce_row(row) for row in rows]
@@ -242,8 +305,10 @@ class CacheManager:
clean_rows, clean_rows,
) )
self._mem_conn.commit() self._mem_conn.commit()
count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
self.mark_table_refreshed(table, count, self.is_table_full(table)) def count_rows(self, table: str) -> int:
row = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
return int(row[0]) if row else 0
def reset(self) -> None: def reset(self) -> None:
"""Wipe the entire cache — every cached table plus the on-disk file.""" """Wipe the entire cache — every cached table plus the on-disk file."""
@@ -262,6 +327,7 @@ class CacheManager:
self._mem_conn.execute("DELETE FROM _sqlmem_tables") self._mem_conn.execute("DELETE FROM _sqlmem_tables")
self._mem_conn.execute("DELETE FROM _sqlmem_columns") self._mem_conn.execute("DELETE FROM _sqlmem_columns")
self._mem_conn.commit() self._mem_conn.commit()
self._states.clear()
try: try:
if self._db_path.exists(): if self._db_path.exists():
self._db_path.unlink() self._db_path.unlink()
+2
@@ -11,6 +11,8 @@ CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600")) BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
# How often (seconds) the background thread pulls deltas for delta-tracked tables. # How often (seconds) the background thread pulls deltas for delta-tracked tables.
REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300")) REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300"))
# Rows fetched per batch when loading a table — caps peak memory for huge tables.
FETCH_BATCH_SIZE = int(os.getenv("SQLMEM_FETCH_BATCH", "10000"))
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server), # Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite. # which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql") SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
+24 -6
@@ -4,6 +4,8 @@ from dataclasses import dataclass, field
from loguru import logger from loguru import logger
from .cache import CacheManager from .cache import CacheManager
from .config import FETCH_BATCH_SIZE
from .stats import TableState
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -58,21 +60,37 @@ class DeltaRefresher:
col_list = ", ".join(columns) col_list = ", ".join(columns)
if watermark is None: if watermark is None:
rows = source_conn.execute(f"SELECT {col_list} FROM {table}").fetchall() cursor = source_conn.execute(f"SELECT {col_list} FROM {table}")
else: else:
rows = source_conn.execute( cursor = source_conn.execute(
f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?", f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?",
(watermark,), (watermark,),
).fetchall() )
if not rows: # Stream the delta in batches so a large catch-up never materializes at once.
total = 0
self._cache.set_state(table, TableState.REFRESHING)
try:
while True:
batch = cursor.fetchmany(FETCH_BATCH_SIZE)
if not batch:
break
self._cache.upsert_rows(table, columns, batch)
total += len(batch)
finally:
self._cache.set_state(table, TableState.READY)
if total == 0:
logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}") logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}")
return return
self._cache.upsert_rows(table, columns, rows) # Update row_count / last_refresh once (not per batch) and advance the watermark.
self._cache.mark_table_refreshed(
table, self._cache.count_rows(table), self._cache.is_table_full(table)
)
new_watermark = self._cache.max_value(table, cfg.change_column) new_watermark = self._cache.max_value(table, cfg.change_column)
self._cache.set_last_synced_at(table, new_watermark) self._cache.set_last_synced_at(table, new_watermark)
logger.info( logger.info(
f"Delta refresh {table!r}: {len(rows)} row(s) upserted, " f"Delta refresh {table!r}: {total} row(s) upserted, "
f"watermark {watermark!r} -> {new_watermark!r}" f"watermark {watermark!r} -> {new_watermark!r}"
) )
+22 -2
@@ -1,5 +1,6 @@
import sqlite3 import sqlite3
import threading import threading
from dataclasses import replace
from typing import cast from typing import cast
from loguru import logger from loguru import logger
@@ -12,7 +13,7 @@ from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
from .executor import QueryExecutor from .executor import QueryExecutor
from .parser import Params, parse from .parser import Params, parse
from .registry import ColumnRegistry from .registry import ColumnRegistry
from .stats import Stats, StatsCollector from .stats import Stats, StatsCollector, TableState, TableStats
class CachingEngine: class CachingEngine:
@@ -68,8 +69,26 @@ class CachingEngine:
@property @property
def stats(self) -> Stats: def stats(self) -> Stats:
states = self._cache.get_states()
with self._cache._lock: with self._cache._lock:
return self._stats.snapshot(self._cache.connection) base = self._stats.snapshot(self._cache.connection, states)
return replace(base, tables={n: self._enrich(n, t) for n, t in base.tables.items()})
def _enrich(self, name: str, table_stats: TableStats) -> TableStats:
"""Annotate a TableStats with how it is refreshed and TTL staleness."""
if name in self._delta:
tracking = "delta"
elif name in self._ttl:
tracking = "ttl"
else:
tracking = "static"
state = table_stats.state
if state == TableState.READY and name in self._ttl:
age = self._cache.seconds_since_refresh(name)
if age is not None and age > self._ttl[name]:
state = TableState.STALE
return replace(table_stats, tracking=tracking, state=state)
def execute(self, sql: str, params: Params = None) -> list[dict]: def execute(self, sql: str, params: Params = None) -> list[dict]:
parsed = parse(sql, params) parsed = parse(sql, params)
@@ -130,6 +149,7 @@ class CachingEngine:
"DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,) "DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
) )
self._cache.connection.commit() self._cache.connection.commit()
self._cache.clear_state(table)
def reset(self) -> None: def reset(self) -> None:
"""Wipe the whole cache (RAM + cache.db). Use after structural source changes.""" """Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
+24 -1
@@ -3,11 +3,23 @@ import threading
from dataclasses import dataclass from dataclasses import dataclass
class TableState:
"""Live processing state of a cached table (value of ``TableStats.state``)."""
LOADING = "loading" # a full load is in progress
REFRESHING = "refreshing" # an incremental (delta) refresh is in progress
READY = "ready" # cached and idle
STALE = "stale" # TTL expired — will reload on next access
ERROR = "error" # the last load failed
@dataclass(frozen=True) @dataclass(frozen=True)
class TableStats: class TableStats:
rows: int rows: int
columns: list[str] columns: list[str]
last_refresh: str last_refresh: str
state: str = TableState.READY
tracking: str = "static" # "delta" | "ttl" | "static"
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -37,14 +49,19 @@ class StatsCollector:
with self._lock: with self._lock:
self.refetches += 1 self.refetches += 1
def snapshot(self, conn: sqlite3.Connection) -> Stats: def snapshot(
self, conn: sqlite3.Connection, states: dict[str, str] | None = None
) -> Stats:
states = states or {}
with self._lock: with self._lock:
hits, misses, refetches = self.hits, self.misses, self.refetches hits, misses, refetches = self.hits, self.misses, self.refetches
tables: dict[str, TableStats] = {} tables: dict[str, TableStats] = {}
cached: set[str] = set()
for table_name, row_count, last_refresh in conn.execute( for table_name, row_count, last_refresh in conn.execute(
"SELECT table_name, row_count, last_refresh_at FROM _sqlmem_tables" "SELECT table_name, row_count, last_refresh_at FROM _sqlmem_tables"
).fetchall(): ).fetchall():
cached.add(table_name)
columns = [ columns = [
r[0] r[0]
for r in conn.execute( for r in conn.execute(
@@ -56,6 +73,12 @@ class StatsCollector:
rows=row_count or 0, rows=row_count or 0,
columns=columns, columns=columns,
last_refresh=last_refresh, last_refresh=last_refresh,
state=states.get(table_name, TableState.READY),
) )
# Surface tables that are mid-first-load (not yet in _sqlmem_tables) or failed.
for name, state in states.items():
if name not in cached and state in (TableState.LOADING, TableState.ERROR):
tables[name] = TableStats(rows=0, columns=[], last_refresh="", state=state)
return Stats(hits=hits, misses=misses, refetches=refetches, tables=tables) return Stats(hits=hits, misses=misses, refetches=refetches, tables=tables)
+10 -2
@@ -10,11 +10,19 @@ from sqlmem.cache import CacheManager
class _FakeCursor: class _FakeCursor:
def __init__(self, rows): def __init__(self, rows):
self._rows = rows self._rows = list(rows)
self._pos = 0
self.description = None self.description = None
def fetchall(self): def fetchall(self):
return self._rows out = self._rows[self._pos :]
self._pos = len(self._rows)
return out
def fetchmany(self, size):
out = self._rows[self._pos : self._pos + size]
self._pos += len(out)
return out
class FakeSource: class FakeSource:
+105
@@ -0,0 +1,105 @@
import sqlite3
import pytest
from sqlmem.cache import CacheManager
@pytest.fixture
def source_conn():
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE big (id TEXT, val TEXT)")
conn.executemany(
"INSERT INTO big VALUES (?, ?)", [(str(i), f"v{i}") for i in range(5)]
)
conn.commit()
yield conn
conn.close()
@pytest.fixture
def cache(tmp_path):
c = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
yield c
c.close()
@pytest.fixture
def small_batches(monkeypatch):
# Force multiple fetch batches over the 5 source rows.
monkeypatch.setattr("sqlmem.cache.FETCH_BATCH_SIZE", 2)
def test_batched_load_loads_all_rows(cache, source_conn, small_batches):
cache.load_table("big", ["id", "val"], source_conn)
_, rows = cache.execute_in_memory(
"SELECT id, val FROM big ORDER BY CAST(id AS INTEGER)"
)
assert len(rows) == 5
assert rows[0] == ("0", "v0")
assert rows[-1] == ("4", "v4")
def test_no_staging_table_left_behind(cache, source_conn, small_batches):
cache.load_table("big", ["id", "val"], source_conn)
names = {
r[0]
for r in cache.connection.execute(
"SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
}
assert "big" in names
assert not any(n.endswith("__sqlmem_load") for n in names)
def test_reload_replaces_data_atomically(cache, source_conn, small_batches):
cache.load_table("big", ["id", "val"], source_conn)
source_conn.execute("DELETE FROM big")
source_conn.execute("INSERT INTO big VALUES ('99', 'new')")
source_conn.commit()
cache.load_table("big", ["id", "val"], source_conn)
_, rows = cache.execute_in_memory("SELECT id, val FROM big")
assert rows == [("99", "new")]
def test_load_sets_ready_state(cache, source_conn):
cache.load_table("big", ["id", "val"], source_conn)
assert cache.get_states()["big"] == "ready"
def test_orphan_staging_dropped_on_startup(tmp_path, source_conn):
# Simulate a crash mid-load: a staging table persisted into cache.db.
db_path = tmp_path / "cache.db"
c1 = CacheManager(db_path=db_path, backup_interval=9999)
c1.load_table("big", ["id", "val"], source_conn)
c1.connection.execute("CREATE TABLE big__sqlmem_load (id TEXT, val TEXT)")
c1.connection.commit()
c1.close() # backup writes the staging table to disk
c2 = CacheManager(db_path=db_path, backup_interval=9999)
names = {
r[0]
for r in c2.connection.execute(
"SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
}
c2.close()
assert "big" in names # real table survives
assert not any(n.endswith("__sqlmem_load") for n in names) # orphan cleaned
def test_failed_load_sets_error_state_and_cleans_staging(cache):
empty_source = sqlite3.connect(":memory:") # has no 'big' table
try:
with pytest.raises(sqlite3.OperationalError):
cache.load_table("big", ["id"], empty_source)
assert cache.get_states()["big"] == "error"
names = {
r[0]
for r in cache.connection.execute(
"SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
}
assert not any(n.endswith("__sqlmem_load") for n in names)
finally:
empty_source.close()
+126
@@ -0,0 +1,126 @@
import sqlite3
import threading
import pytest
from sqlalchemy import create_engine
import sqlmem.engine as eng_mod
from sqlmem import CachingEngine, DeltaConfig
from sqlmem.cache import CacheManager
from sqlmem.stats import StatsCollector
@pytest.fixture
def source_engine(tmp_path):
db_path = tmp_path / "source.db"
conn = sqlite3.connect(db_path)
conn.executescript(
"""
CREATE TABLE products (id TEXT PRIMARY KEY, name TEXT, changed TEXT);
INSERT INTO products VALUES ('1', 'Widget', '2026-06-01 10:00:00');
"""
)
conn.commit()
conn.close()
engine = create_engine(f"sqlite:///{db_path}")
yield engine
engine.dispose()
@pytest.fixture
def patched_cache(tmp_path, monkeypatch):
monkeypatch.setattr(eng_mod, "CACHE_DB_PATH", tmp_path / "cache.db")
monkeypatch.setattr(eng_mod, "BACKUP_INTERVAL_SECONDS", 9999)
def test_static_table_state_and_tracking(source_engine, patched_cache):
engine = CachingEngine(source_engine)
engine.execute("SELECT id, name FROM products")
s = engine.stats.tables["products"]
assert s.state == "ready"
assert s.tracking == "static"
assert s.rows == 1
engine.close()
def test_delta_table_tracking(source_engine, patched_cache):
engine = CachingEngine(
source_engine, delta={"products": DeltaConfig("changed", ["id"])}
)
engine.execute("SELECT id, name FROM products")
s = engine.stats.tables["products"]
assert s.tracking == "delta"
assert s.state == "ready"
engine.close()
def test_ttl_table_reports_stale(source_engine, patched_cache):
engine = CachingEngine(source_engine, ttl={"products": 0})
engine.execute("SELECT id, name FROM products")
s = engine.stats.tables["products"]
assert s.tracking == "ttl"
assert s.state == "stale" # ttl=0 → already past its max age
engine.close()
def test_counters_still_reported(source_engine, patched_cache):
engine = CachingEngine(source_engine)
engine.execute("SELECT id, name FROM products")
engine.execute("SELECT id, name FROM products")
stats = engine.stats
assert stats.misses == 1
assert stats.hits == 1
engine.close()
# --- a table being loaded for the first time shows up as "loading" ----------
def test_snapshot_surfaces_a_loading_table(tmp_path):
cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
snap = StatsCollector().snapshot(cache.connection, {"pending": "loading"})
assert "pending" in snap.tables
assert snap.tables["pending"].state == "loading"
assert snap.tables["pending"].rows == 0
cache.close()
def test_loading_state_visible_from_another_thread_during_load(tmp_path):
"""A first load in progress is observable as 'loading' from another thread."""
cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
started = threading.Event()
release = threading.Event()
class BlockingCursor:
def __init__(self, rows):
self._rows = list(rows)
self._done = False
def fetchmany(self, size):
if self._done:
return []
started.set()
release.wait(5) # hold the load open until the test releases it
self._done = True
return self._rows
class BlockingSource:
def execute(self, sql):
return BlockingCursor([("1", "alice")])
loader = threading.Thread(
target=cache.load_table, args=("users", ["id", "name"], BlockingSource())
)
loader.start()
try:
assert started.wait(5), "load did not start"
# mid-load: not yet in _sqlmem_tables, but surfaced as loading
assert cache.get_states()["users"] == "loading"
snap = StatsCollector().snapshot(cache.connection, cache.get_states())
assert snap.tables["users"].state == "loading"
finally:
release.set()
loader.join(5)
assert not loader.is_alive()
assert cache.get_states()["users"] == "ready"
cache.close()