Fix cache stampede with double-checked locking in load_table

This commit is contained in:
Jan Doubravský
2026-06-11 13:03:22 +02:00
parent a68b8994e3
commit 46370fe651
7 changed files with 139 additions and 7 deletions
+15
View File
@@ -2,6 +2,7 @@ import atexit
import signal
import sqlite3
import threading
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
@@ -425,6 +426,7 @@ class CacheManager:
columns: list[str],
source_conn: sqlite3.Connection,
full: bool = False,
recheck: Callable[[], bool] | None = None,
) -> None:
"""Stream the source table into the cache in batches.
@@ -433,6 +435,13 @@ class CacheManager:
``fetchall`` of a huge table) and readers keep seeing the previous copy
until the swap. Concurrent loads are serialized by ``_load_lock``; the
connection lock is only held for the brief per-batch inserts and the swap.
*recheck* implements double-checked locking against a cache stampede: the
decision to load is made by the caller *before* ``_load_lock`` is held, so
on a slow cold load a second request for the same table can queue behind
the lock and then redundantly reload it. If given, ``recheck()`` is
re-evaluated *after* the lock is acquired; when it returns ``True`` the
table is already loaded and fresh, so the load is skipped.
"""
src_cols = ", ".join(quote_source(c, self._dialect) for c in columns)
dt_cols = set(self._datetime_columns.get(table, ()))
@@ -446,6 +455,12 @@ class CacheManager:
q_table = quote(table)
with self._load_lock:
if recheck is not None and recheck():
logger.info(
f"Skipping load of {table!r}: a concurrent loader already "
"satisfied it (double-checked lock)."
)
return
self.set_state(table, TableState.LOADING)
logger.info(f"Fetching {table!r} columns {columns} from source DB (batch={self._fetch_batch})")
try:
+41 -5
View File
@@ -1,3 +1,4 @@
from collections.abc import Callable
from typing import Any
from loguru import logger
@@ -47,6 +48,20 @@ class QueryExecutor:
else:
self._ensure_columns(table, parsed.columns_by_table[table])
def _full_satisfied(self, table: str) -> bool:
    """Report whether a full-table (SELECT *) request for *table* is a cache hit.

    The table must be present in the cache, be marked as fully cached, and
    not be TTL-expired. The checks short-circuit in that order, so later
    checks are not evaluated once an earlier one fails.
    """
    cache = self._cache
    # Presence first, then "full" status; only then is the TTL consulted.
    cached_in_full = cache.is_table_cached(table) and cache.is_table_full(table)
    return cached_in_full and not self._ttl_expired(table)
def _columns_satisfied(self, table: str, columns: list[str]) -> bool:
    """Report whether the cached copy of *table* already covers *columns*.

    Returns ``False`` when the table is not cached or its TTL has expired;
    otherwise returns whether every requested column is among the columns
    the cache reports for the table.
    """
    cache = self._cache
    # Guard clauses, in the same order the checks must be evaluated.
    if not cache.is_table_cached(table):
        return False
    if self._ttl_expired(table):
        return False
    requested = set(columns)
    return requested.issubset(cache.get_table_columns(table))
def _ensure_full(self, table: str) -> None:
"""Load every column of *table* (SELECT * / t.*), refetching unless already full."""
cached = self._cache.is_table_cached(table)
@@ -67,7 +82,7 @@ class QueryExecutor:
self._stats.record_miss()
columns = self._cache.discover_columns(table, self._source_conn)
self._load(table, columns, full=True)
self._load(table, columns, full=True, satisfied=lambda cols: self._full_satisfied(table))
def _ensure_columns(self, table: str, columns: list[str]) -> None:
"""Load *table* with at least *columns*, refetching on new columns or TTL expiry."""
@@ -95,10 +110,27 @@ class QueryExecutor:
all_columns = list(self._registry.get_columns(table)) + missing
# Preserve a fully-cached table's status across a TTL reload.
full = table_cached and self._cache.is_table_full(table)
self._load(table, all_columns, full=full)
self._load(
table,
all_columns,
full=full,
satisfied=lambda cols: self._columns_satisfied(table, cols),
)
def _load(self, table: str, columns: list[str], full: bool) -> None:
"""Fetch *table* into cache, adding delta key/timestamp and index columns."""
def _load(
self,
table: str,
columns: list[str],
full: bool,
satisfied: Callable[[list[str]], bool] | None = None,
) -> None:
"""Fetch *table* into cache, adding delta key/timestamp and index columns.
*satisfied* is the double-checked-locking predicate evaluated under the
load lock (see :meth:`CacheManager.load_table`); it is given the final,
augmented column list so a concurrent loader that already produced an
equivalent (or wider) cache is detected and the redundant reload skipped.
"""
cfg = self._delta.get(table)
extra = list(self._index_columns.get(table, []))
if cfg:
@@ -108,7 +140,11 @@ class QueryExecutor:
if extra:
columns = list(dict.fromkeys([*columns, *extra]))
self._cache.load_table(table, columns, self._source_conn, full=full)
recheck: Callable[[], bool] | None = None
if satisfied is not None:
final_columns = columns
recheck = lambda: satisfied(final_columns) # noqa: E731
self._cache.load_table(table, columns, self._source_conn, full=full, recheck=recheck)
self._registry.update(table, columns)
if cfg: