Batch large-table loads to bound memory and add per-table state to stats

This commit is contained in:
Jan Doubravský
2026-06-05 14:44:07 +02:00
parent 85bb84a1a6
commit 286a5f207d
11 changed files with 436 additions and 29 deletions
+10 -2
View File
@@ -10,11 +10,19 @@ from sqlmem.cache import CacheManager
class _FakeCursor:
    """Minimal in-memory stand-in for a DB-API cursor.

    Holds a private copy of ``rows`` plus a read position, so ``fetchmany``
    and ``fetchall`` consume the result set progressively instead of handing
    back the same rows on every call.
    """

    def __init__(self, rows):
        # Copy so any iterable is accepted and the caller's list is not aliased.
        self._rows = list(rows)
        self._pos = 0
        self.description = None

    def fetchall(self):
        """Return every row not yet fetched and mark the cursor exhausted."""
        out = self._rows[self._pos :]
        self._pos = len(self._rows)
        return out

    def fetchmany(self, size):
        """Return up to ``size`` unread rows; an empty list once exhausted."""
        out = self._rows[self._pos : self._pos + size]
        self._pos += len(out)
        return out
class FakeSource:
+105
View File
@@ -0,0 +1,105 @@
import sqlite3
import pytest
from sqlmem.cache import CacheManager
@pytest.fixture
def source_conn():
    """Yield an in-memory SQLite source holding a five-row ``big`` table.

    The connection is closed once the consuming test has finished.
    """
    seed_rows = [(str(i), f"v{i}") for i in range(5)]

    db = sqlite3.connect(":memory:")
    db.execute("CREATE TABLE big (id TEXT, val TEXT)")
    db.executemany("INSERT INTO big VALUES (?, ?)", seed_rows)
    db.commit()

    yield db
    db.close()
@pytest.fixture
def cache(tmp_path):
    """Yield a CacheManager stored under ``tmp_path``; closed after the test."""
    manager = CacheManager(
        db_path=tmp_path / "cache.db",
        backup_interval=9999,
    )
    yield manager
    manager.close()
@pytest.fixture
def small_batches(monkeypatch):
    """Set the fetch batch size to 2 so the 5 source rows span several batches."""
    batch_size_target = "sqlmem.cache.FETCH_BATCH_SIZE"
    monkeypatch.setattr(batch_size_target, 2)
def test_batched_load_loads_all_rows(cache, source_conn, small_batches):
    """All five source rows reach the cache when fetched in small batches."""
    cache.load_table("big", ["id", "val"], source_conn)

    ordered_query = "SELECT id, val FROM big ORDER BY CAST(id AS INTEGER)"
    _, loaded = cache.execute_in_memory(ordered_query)

    # Check the count first so an empty result fails on the assertion,
    # not on the indexing below.
    assert len(loaded) == 5
    assert loaded[0] == ("0", "v0")
    assert loaded[-1] == ("4", "v4")
def test_no_staging_table_left_behind(cache, source_conn, small_batches):
    """After a load the real table exists and no ``__sqlmem_load`` table does."""
    cache.load_table("big", ["id", "val"], source_conn)

    listing = cache.connection.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table'"
    ).fetchall()
    table_names = {row[0] for row in listing}

    assert "big" in table_names
    leftovers = [name for name in table_names if name.endswith("__sqlmem_load")]
    assert not leftovers
def test_reload_replaces_data_atomically(cache, source_conn, small_batches):
    """A second load leaves only the source's new contents in the cache."""

    def load_big():
        # A fresh column list per call, exactly as a direct call would pass.
        cache.load_table("big", ["id", "val"], source_conn)

    load_big()

    # Replace the source contents with a single new row before reloading.
    source_conn.execute("DELETE FROM big")
    source_conn.execute("INSERT INTO big VALUES ('99', 'new')")
    source_conn.commit()

    load_big()

    _, cached = cache.execute_in_memory("SELECT id, val FROM big")
    assert cached == [("99", "new")]
def test_load_sets_ready_state(cache, source_conn):
    """A successful load leaves the table reported as ``ready``."""
    cache.load_table("big", ["id", "val"], source_conn)

    states = cache.get_states()
    assert states["big"] == "ready"
def test_orphan_staging_dropped_on_startup(tmp_path, source_conn):
    """A staging table persisted into cache.db is gone after the next startup.

    Simulates a crash mid-load: a ``big__sqlmem_load`` table is created by
    hand on the first manager, which is then closed, and a second manager is
    opened on the same file. Each manager is closed in a ``finally`` block so
    a failure in an intermediate step does not leave it open.
    """
    db_path = tmp_path / "cache.db"

    c1 = CacheManager(db_path=db_path, backup_interval=9999)
    try:
        c1.load_table("big", ["id", "val"], source_conn)
        c1.connection.execute("CREATE TABLE big__sqlmem_load (id TEXT, val TEXT)")
        c1.connection.commit()
    finally:
        c1.close()  # backup writes the staging table to disk

    c2 = CacheManager(db_path=db_path, backup_interval=9999)
    try:
        names = {
            r[0]
            for r in c2.connection.execute(
                "SELECT name FROM sqlite_master WHERE type = 'table'"
            ).fetchall()
        }
    finally:
        c2.close()

    assert "big" in names  # real table survives
    assert not any(n.endswith("__sqlmem_load") for n in names)  # orphan cleaned
def test_failed_load_sets_error_state_and_cleans_staging(cache):
    """A failing load marks the table ``error`` and leaves no staging table."""
    tableless = sqlite3.connect(":memory:")  # has no 'big' table
    try:
        with pytest.raises(sqlite3.OperationalError):
            cache.load_table("big", ["id"], tableless)

        assert cache.get_states()["big"] == "error"

        listing = cache.connection.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table'"
        ).fetchall()
        table_names = {row[0] for row in listing}
        assert not any(name.endswith("__sqlmem_load") for name in table_names)
    finally:
        tableless.close()
+126
View File
@@ -0,0 +1,126 @@
import sqlite3
import threading
import pytest
from sqlalchemy import create_engine
import sqlmem.engine as eng_mod
from sqlmem import CachingEngine, DeltaConfig
from sqlmem.cache import CacheManager
from sqlmem.stats import StatsCollector
@pytest.fixture
def source_engine(tmp_path):
    """Yield a SQLAlchemy engine over a SQLite file seeded with one product.

    The file is created and populated through a plain ``sqlite3`` connection
    first; the engine is disposed once the consuming test has finished.
    """
    source_file = tmp_path / "source.db"

    seeding = sqlite3.connect(source_file)
    seeding.executescript(
        """
        CREATE TABLE products (id TEXT PRIMARY KEY, name TEXT, changed TEXT);
        INSERT INTO products VALUES ('1', 'Widget', '2026-06-01 10:00:00');
        """
    )
    seeding.commit()
    seeding.close()

    sa_engine = create_engine(f"sqlite:///{source_file}")
    yield sa_engine
    sa_engine.dispose()
@pytest.fixture
def patched_cache(tmp_path, monkeypatch):
    """Patch the engine module's cache path and backup interval for a test."""
    overrides = (
        ("CACHE_DB_PATH", tmp_path / "cache.db"),
        ("BACKUP_INTERVAL_SECONDS", 9999),
    )
    for attr_name, value in overrides:
        monkeypatch.setattr(eng_mod, attr_name, value)
def test_static_table_state_and_tracking(source_engine, patched_cache):
    """A table with no delta or ttl config reports static / ready / one row."""
    engine = CachingEngine(source_engine)
    try:
        engine.execute("SELECT id, name FROM products")
        s = engine.stats.tables["products"]
        assert s.state == "ready"
        assert s.tracking == "static"
        assert s.rows == 1
    finally:
        # Close even when an assertion fails so the engine is never left open.
        engine.close()
def test_delta_table_tracking(source_engine, patched_cache):
    """A table given a ``DeltaConfig`` reports tracking ``delta`` and ``ready``."""
    engine = CachingEngine(
        source_engine, delta={"products": DeltaConfig("changed", ["id"])}
    )
    try:
        engine.execute("SELECT id, name FROM products")
        s = engine.stats.tables["products"]
        assert s.tracking == "delta"
        assert s.state == "ready"
    finally:
        # Close even when an assertion fails so the engine is never left open.
        engine.close()
def test_ttl_table_reports_stale(source_engine, patched_cache):
    """A table with ``ttl=0`` reports tracking ``ttl`` and state ``stale``."""
    engine = CachingEngine(source_engine, ttl={"products": 0})
    try:
        engine.execute("SELECT id, name FROM products")
        s = engine.stats.tables["products"]
        assert s.tracking == "ttl"
        assert s.state == "stale"  # ttl=0 → already past its max age
    finally:
        # Close even when an assertion fails so the engine is never left open.
        engine.close()
def test_counters_still_reported(source_engine, patched_cache):
    """Running the same query twice is expected to record one miss, one hit."""
    engine = CachingEngine(source_engine)
    try:
        engine.execute("SELECT id, name FROM products")
        engine.execute("SELECT id, name FROM products")
        stats = engine.stats
        assert stats.misses == 1
        assert stats.hits == 1
    finally:
        # Close even when an assertion fails so the engine is never left open.
        engine.close()
# --- a table being loaded for the first time shows up as "loading" ----------
def test_snapshot_surfaces_a_loading_table(tmp_path):
    """A table passed to ``snapshot`` as ``loading`` appears with zero rows."""
    cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
    try:
        snap = StatsCollector().snapshot(cache.connection, {"pending": "loading"})
        assert "pending" in snap.tables
        assert snap.tables["pending"].state == "loading"
        assert snap.tables["pending"].rows == 0
    finally:
        # Close even when an assertion fails so the cache is never left open.
        cache.close()
def test_loading_state_visible_from_another_thread_during_load(tmp_path):
    """A first load in progress is observable as 'loading' from another thread.

    A loader thread runs ``cache.load_table`` against a fake source whose
    cursor blocks inside ``fetchmany``. While it is blocked, the main thread
    checks the reported state; it then releases the loader and checks the
    final state.
    """
    cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
    # started: set by the fake cursor once the load has reached fetchmany.
    # release: set by the test to let the blocked fetchmany return.
    started = threading.Event()
    release = threading.Event()

    class BlockingCursor:
        # Fake cursor: the first fetchmany blocks, then returns every row;
        # later calls return an empty list.
        def __init__(self, rows):
            self._rows = list(rows)
            self._done = False

        def fetchmany(self, size):
            # `size` is ignored; all rows come back in a single batch.
            if self._done:
                return []
            started.set()
            release.wait(5)  # hold the load open until the test releases it
            self._done = True
            return self._rows

    class BlockingSource:
        # Fake source: every execute() returns a one-row blocking cursor,
        # whatever SQL is passed.
        def execute(self, sql):
            return BlockingCursor([("1", "alice")])

    loader = threading.Thread(
        target=cache.load_table, args=("users", ["id", "name"], BlockingSource())
    )
    loader.start()
    try:
        assert started.wait(5), "load did not start"
        # mid-load: not yet in _sqlmem_tables, but surfaced as loading
        assert cache.get_states()["users"] == "loading"
        snap = StatsCollector().snapshot(cache.connection, cache.get_states())
        assert snap.tables["users"].state == "loading"
    finally:
        # Always unblock and wait for the loader, even if an assertion failed,
        # so the thread is not left parked in release.wait().
        release.set()
        loader.join(5)
    assert not loader.is_alive()
    assert cache.get_states()["users"] == "ready"
    cache.close()