Fix cache stampede with double-checked locking in load_table
This commit is contained in:
@@ -290,6 +290,74 @@ def test_vacuum_in_memory_is_noop(cache, source_conn):
|
||||
assert cache.is_table_cached("users") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Double-checked locking against cache stampede (1.15.0)
# ---------------------------------------------------------------------------
|
||||
|
||||
class _ExplodingSource:
|
||||
def execute(self, *args):
|
||||
raise AssertionError("source must not be queried when recheck() is True")
|
||||
|
||||
|
||||
def test_load_table_recheck_true_skips_load(cache, source_conn):
    """load_table must skip the reload when recheck() says the table is satisfied."""
    # Prime the cache with a genuine load from the real source.
    cache.load_table("users", ["name"], source_conn)

    def already_satisfied():
        return True

    # The stub raises as soon as it is queried, so this call can only succeed
    # if the source is left completely untouched.
    tripwire = _ExplodingSource()
    cache.load_table("users", ["name"], tripwire, recheck=already_satisfied)

    assert cache.is_table_cached("users") is True
|
||||
|
||||
|
||||
def test_concurrent_loads_dedup_via_double_checked_lock(tmp_path):
    """A second loader queued behind a slow cold load must not reload the table.

    Thread A starts a load that is held open inside the source's ``execute``.
    Thread B is started while A is still blocked. Once A is released, B's
    ``recheck`` should report the table as already cached, so the source is
    queried exactly once.
    """
    import time

    c = CacheManager(db_path=tmp_path / "c.db", backup_interval=9999)
    started = threading.Event()
    release = threading.Event()
    loads: list[str] = []

    class _GatedCursor:
        """Cursor stub: the first fetchmany() returns every row, later calls return []."""

        def __init__(self, rows):
            self._rows = list(rows)
            self._done = False

        def fetchmany(self, n):
            if self._done:
                return []
            self._done = True
            return self._rows

    class _GatedSource:
        """Source stub whose execute() blocks until the test sets ``release``."""

        def execute(self, sql):
            loads.append(sql)  # one entry per *actual* source load
            started.set()
            release.wait(5)  # hold the load open (and _load_lock) until released
            return _GatedCursor([("alice",), ("bob",)])

    def recheck() -> bool:
        return c.is_table_cached("users") and "name" in c.get_table_columns("users")

    def load() -> None:
        c.load_table("users", ["name"], _GatedSource(), recheck=recheck)

    # Daemon threads: a loader that deadlocks must not keep the interpreter
    # alive after the test has already failed.
    a = threading.Thread(target=load, daemon=True)
    b = threading.Thread(target=load, daemon=True)
    try:
        a.start()
        assert started.wait(5), "first load never started"  # A holds _load_lock, mid-fetch
        b.start()
        time.sleep(0.2)  # give B time to queue on _load_lock
        release.set()  # let A finish; B then re-checks and skips
        a.join(5)
        b.join(5)
        assert not a.is_alive() and not b.is_alive()

        assert len(loads) == 1  # the redundant second load was skipped
        assert c.is_table_cached("users") is True
        _, rows = c.execute_in_memory("SELECT name FROM users ORDER BY name")
        assert [r[0] for r in rows] == ["alice", "bob"]
    finally:
        # Always unblock a gated loader and close the cache, even when an
        # assertion above failed, so nothing leaks into later tests.
        release.set()
        for worker in (a, b):
            if worker.is_alive():  # False for a thread that was never started
                worker.join(5)
        c.close()
|
||||
|
||||
|
||||
def test_incremental_vacuum_warns_without_incremental_auto_vacuum(tmp_path, source_conn):
|
||||
"""Incremental vacuum on a DB that isn't auto_vacuum=INCREMENTAL warns and skips."""
|
||||
from loguru import logger
|
||||
|
||||
Reference in New Issue
Block a user