Batch large-table loads to bound memory and add per-table state to stats
This commit is contained in:
@@ -6,6 +6,20 @@ All notable changes to this project will be documented in this file.
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## [1.5.0] - 2026-06-05
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- **Per-table processing state in `stats`** — `TableStats` now carries `state` (`loading` / `refreshing` / `ready` / `stale` / `error`) and `tracking` (`delta` / `ttl` / `static`), so callers can see whether each table is up to date or being processed. In-progress first loads and failed loads also surface in `stats.tables`.
|
||||||
|
- `SQLMEM_FETCH_BATCH` env var (default `10000`) — rows fetched per batch when loading a table.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- `pyproject.toml` — bumped version to `1.5.0`
|
||||||
|
- **Large-table loads are streamed in batches** — `load_table` no longer `fetchall()`s the whole table (which double-buffered every row in Python and could OOM/crash on tens of millions of rows). Rows are now fetched `SQLMEM_FETCH_BATCH` at a time into a staging table and swapped in atomically, so peak memory stays bounded, the previous copy stays queryable during a reload, and the network fetch no longer holds the cache lock. Delta catch-ups are streamed the same way.
|
||||||
|
- Orphan staging tables left by an interrupted load (crash/backup mid-load) are dropped on startup.
|
||||||
|
- Delta upserts compute `row_count` once per refresh instead of a full `COUNT(*)` after every batch (avoids O(rows×batches) work on large catch-ups).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## [1.4.0] - 2026-06-05
|
## [1.4.0] - 2026-06-05
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|||||||
@@ -245,9 +245,33 @@ Use `reset()` after a **structural change** in the source (columns added/removed
|
|||||||
stats = engine.stats # Stats snapshot
|
stats = engine.stats # Stats snapshot
|
||||||
print(stats.hits, stats.misses, stats.refetches)
|
print(stats.hits, stats.misses, stats.refetches)
|
||||||
for name, t in stats.tables.items():
|
for name, t in stats.tables.items():
|
||||||
print(name, t.rows, t.columns, t.last_refresh)
|
print(name, t.rows, t.state, t.tracking, t.last_refresh)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Each `TableStats` reports a live processing **state** and how the table is kept fresh (**tracking**):
|
||||||
|
|
||||||
|
| `state` | Meaning |
|
||||||
|
|---|---|
|
||||||
|
| `loading` | a full load is in progress |
|
||||||
|
| `refreshing` | an incremental (delta) refresh is in progress |
|
||||||
|
| `ready` | cached and idle (up to date) |
|
||||||
|
| `stale` | a TTL table whose cache has expired; reloads on next access |
|
||||||
|
| `error` | the last load failed |
|
||||||
|
|
||||||
|
| `tracking` | Meaning |
|
||||||
|
|---|---|
|
||||||
|
| `delta` | kept in sync incrementally via a change column |
|
||||||
|
| `ttl` | full-reloaded when older than its TTL |
|
||||||
|
| `static` | loaded on demand, never auto-refreshed |
|
||||||
|
|
||||||
|
## Memory and very large tables
|
||||||
|
|
||||||
|
The cache is **in-memory SQLite**, so a cached table lives in RAM — it must fit in available memory. To keep huge tables manageable:
|
||||||
|
|
||||||
|
- **Loads are streamed in batches** (`SQLMEM_FETCH_BATCH` rows at a time, default 10 000) into a staging table and swapped in atomically. A multi-million-row table never gets fully materialized in Python at once, so the load doesn't spike memory or crash the process, and readers keep seeing the previous copy until the swap completes.
|
||||||
|
- Use **[delta refresh](#incremental-delta-refresh)** for large tables that have a change column — after the first load only changed rows are pulled, so restarts and refreshes don't re-read the whole table.
|
||||||
|
- A **single query that returns a huge result set** (e.g. `SELECT *` over a multi-million-row cached table) still materializes that result as a list of dicts; bound it with a `WHERE`/`LIMIT` rather than selecting everything.
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
Set via environment variables or a `.env` file:
|
Set via environment variables or a `.env` file:
|
||||||
@@ -259,6 +283,7 @@ Set via environment variables or a `.env` file:
|
|||||||
| `SQLMEM_BACKUP_INTERVAL` | `3600` | Disk backup interval in seconds |
|
| `SQLMEM_BACKUP_INTERVAL` | `3600` | Disk backup interval in seconds |
|
||||||
| `SQLMEM_SQL_DIALECT` | `tsql` | sqlglot dialect used to parse incoming SQL (e.g. `tsql`, `postgres`, `mysql`) |
|
| `SQLMEM_SQL_DIALECT` | `tsql` | sqlglot dialect used to parse incoming SQL (e.g. `tsql`, `postgres`, `mysql`) |
|
||||||
| `SQLMEM_REFRESH_INTERVAL` | `300` | background refresh tick (seconds) — delta pulls and proactive TTL reloads |
|
| `SQLMEM_REFRESH_INTERVAL` | `300` | background refresh tick (seconds) — delta pulls and proactive TTL reloads |
|
||||||
|
| `SQLMEM_FETCH_BATCH` | `10000` | rows fetched per batch when loading a table — caps peak memory for huge tables |
|
||||||
|
|
||||||
## Exceptions
|
## Exceptions
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "sqlmem"
|
name = "sqlmem"
|
||||||
version = "1.4.0"
|
version = "1.5.0"
|
||||||
description = ""
|
description = ""
|
||||||
authors = [
|
authors = [
|
||||||
{name = "jan.doubravsky@gmail.com"}
|
{name = "jan.doubravsky@gmail.com"}
|
||||||
|
|||||||
+80
-14
@@ -9,6 +9,8 @@ from loguru import logger
|
|||||||
|
|
||||||
import sqlmem._meta as _meta
|
import sqlmem._meta as _meta
|
||||||
from ._coerce import coerce_params, coerce_row
|
from ._coerce import coerce_params, coerce_row
|
||||||
|
from .config import FETCH_BATCH_SIZE
|
||||||
|
from .stats import TableState
|
||||||
|
|
||||||
SCHEMA_VERSION = 3
|
SCHEMA_VERSION = 3
|
||||||
|
|
||||||
@@ -18,11 +20,14 @@ class CacheManager:
|
|||||||
self._db_path = db_path
|
self._db_path = db_path
|
||||||
self._backup_interval = backup_interval
|
self._backup_interval = backup_interval
|
||||||
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
|
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock() # serializes connection access
|
||||||
|
self._load_lock = threading.Lock() # serializes full table loads
|
||||||
|
self._states: dict[str, str] = {} # table → live processing state
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
self._ensure_meta_tables()
|
self._ensure_meta_tables()
|
||||||
self._load_from_disk()
|
self._load_from_disk()
|
||||||
|
self._drop_orphan_staging()
|
||||||
self._start_backup_thread()
|
self._start_backup_thread()
|
||||||
|
|
||||||
atexit.register(self._backup_to_disk)
|
atexit.register(self._backup_to_disk)
|
||||||
@@ -88,6 +93,21 @@ class CacheManager:
|
|||||||
finally:
|
finally:
|
||||||
disk_conn.close()
|
disk_conn.close()
|
||||||
|
|
||||||
|
def _drop_orphan_staging(self) -> None:
|
||||||
|
"""Drop staging tables left by a load that was interrupted (e.g. crash mid-load)."""
|
||||||
|
orphans = [
|
||||||
|
r[0]
|
||||||
|
for r in self._mem_conn.execute(
|
||||||
|
r"SELECT name FROM sqlite_master "
|
||||||
|
r"WHERE type = 'table' AND name LIKE '%\_\_sqlmem\_load' ESCAPE '\'"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
for name in orphans:
|
||||||
|
logger.warning(f"Dropping orphan staging table {name!r} from a previous interrupted load.")
|
||||||
|
self._mem_conn.execute(f"DROP TABLE IF EXISTS {name}")
|
||||||
|
if orphans:
|
||||||
|
self._mem_conn.commit()
|
||||||
|
|
||||||
def _backup_to_disk(self) -> None:
|
def _backup_to_disk(self) -> None:
|
||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
@@ -161,6 +181,15 @@ class CacheManager:
|
|||||||
logger.debug(f"{table!r} has columns: {columns}")
|
logger.debug(f"{table!r} has columns: {columns}")
|
||||||
return columns
|
return columns
|
||||||
|
|
||||||
|
def set_state(self, table: str, state: str) -> None:
|
||||||
|
self._states[table] = state
|
||||||
|
|
||||||
|
def get_states(self) -> dict[str, str]:
|
||||||
|
return dict(self._states)
|
||||||
|
|
||||||
|
def clear_state(self, table: str) -> None:
|
||||||
|
self._states.pop(table, None)
|
||||||
|
|
||||||
def load_table(
|
def load_table(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
@@ -168,21 +197,55 @@ class CacheManager:
|
|||||||
source_conn: sqlite3.Connection,
|
source_conn: sqlite3.Connection,
|
||||||
full: bool = False,
|
full: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
cols = ", ".join(columns)
|
"""Stream the source table into the cache in batches.
|
||||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
|
|
||||||
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
|
|
||||||
clean_rows = [coerce_row(row) for row in rows]
|
|
||||||
|
|
||||||
with self._lock:
|
Rows are fetched ``FETCH_BATCH_SIZE`` at a time into a private staging
|
||||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
table and swapped in atomically, so peak memory stays bounded (no
|
||||||
|
``fetchall`` of a huge table) and readers keep seeing the previous copy
|
||||||
|
until the swap. Concurrent loads are serialized by ``_load_lock``; the
|
||||||
|
connection lock is only held for the brief per-batch inserts and the swap.
|
||||||
|
"""
|
||||||
|
cols = ", ".join(columns)
|
||||||
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
||||||
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
|
|
||||||
placeholders = ", ".join("?" * len(columns))
|
placeholders = ", ".join("?" * len(columns))
|
||||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", clean_rows)
|
staging = f"{table}__sqlmem_load"
|
||||||
|
|
||||||
|
with self._load_lock:
|
||||||
|
self.set_state(table, TableState.LOADING)
|
||||||
|
logger.info(f"Fetching {table!r} columns [{cols}] from source DB (batch={FETCH_BATCH_SIZE})")
|
||||||
|
try:
|
||||||
|
cursor = source_conn.execute(f"SELECT {cols} FROM {table}")
|
||||||
|
with self._lock:
|
||||||
|
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
|
||||||
|
self._mem_conn.execute(f"CREATE TABLE {staging} ({col_defs})")
|
||||||
self._mem_conn.commit()
|
self._mem_conn.commit()
|
||||||
|
|
||||||
self.mark_table_refreshed(table, len(rows), full)
|
total = 0
|
||||||
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
|
insert_sql = f"INSERT INTO {staging} VALUES ({placeholders})"
|
||||||
|
while True:
|
||||||
|
batch = cursor.fetchmany(FETCH_BATCH_SIZE) # network outside _lock
|
||||||
|
if not batch:
|
||||||
|
break
|
||||||
|
clean = [coerce_row(row) for row in batch]
|
||||||
|
with self._lock:
|
||||||
|
self._mem_conn.executemany(insert_sql, clean)
|
||||||
|
self._mem_conn.commit()
|
||||||
|
total += len(batch)
|
||||||
|
|
||||||
|
with self._lock: # atomic swap — readers see old or new, never partial
|
||||||
|
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
||||||
|
self._mem_conn.execute(f"ALTER TABLE {staging} RENAME TO {table}")
|
||||||
|
self._mem_conn.commit()
|
||||||
|
except BaseException:
|
||||||
|
with self._lock:
|
||||||
|
self._mem_conn.execute(f"DROP TABLE IF EXISTS {staging}")
|
||||||
|
self._mem_conn.commit()
|
||||||
|
self.set_state(table, TableState.ERROR)
|
||||||
|
raise
|
||||||
|
|
||||||
|
self.mark_table_refreshed(table, total, full)
|
||||||
|
self.set_state(table, TableState.READY)
|
||||||
|
logger.info(f"Table {table!r} cached ({total} rows, columns: {columns})")
|
||||||
|
|
||||||
def execute_in_memory(
|
def execute_in_memory(
|
||||||
self, sql: str, params: tuple | list | dict | None = None
|
self, sql: str, params: tuple | list | dict | None = None
|
||||||
@@ -232,7 +295,7 @@ class CacheManager:
|
|||||||
return row[0] if row else None
|
return row[0] if row else None
|
||||||
|
|
||||||
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
|
def upsert_rows(self, table: str, columns: list[str], rows: list[tuple]) -> None:
|
||||||
"""Insert-or-replace *rows* by the table's unique key, then refresh row_count."""
|
"""Insert-or-replace one batch of *rows* by the table's unique key."""
|
||||||
col_list = ", ".join(columns)
|
col_list = ", ".join(columns)
|
||||||
placeholders = ", ".join("?" * len(columns))
|
placeholders = ", ".join("?" * len(columns))
|
||||||
clean_rows = [coerce_row(row) for row in rows]
|
clean_rows = [coerce_row(row) for row in rows]
|
||||||
@@ -242,8 +305,10 @@ class CacheManager:
|
|||||||
clean_rows,
|
clean_rows,
|
||||||
)
|
)
|
||||||
self._mem_conn.commit()
|
self._mem_conn.commit()
|
||||||
count = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
|
|
||||||
self.mark_table_refreshed(table, count, self.is_table_full(table))
|
def count_rows(self, table: str) -> int:
|
||||||
|
row = self._mem_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
|
||||||
|
return int(row[0]) if row else 0
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Wipe the entire cache — every cached table plus the on-disk file."""
|
"""Wipe the entire cache — every cached table plus the on-disk file."""
|
||||||
@@ -262,6 +327,7 @@ class CacheManager:
|
|||||||
self._mem_conn.execute("DELETE FROM _sqlmem_tables")
|
self._mem_conn.execute("DELETE FROM _sqlmem_tables")
|
||||||
self._mem_conn.execute("DELETE FROM _sqlmem_columns")
|
self._mem_conn.execute("DELETE FROM _sqlmem_columns")
|
||||||
self._mem_conn.commit()
|
self._mem_conn.commit()
|
||||||
|
self._states.clear()
|
||||||
try:
|
try:
|
||||||
if self._db_path.exists():
|
if self._db_path.exists():
|
||||||
self._db_path.unlink()
|
self._db_path.unlink()
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
|
|||||||
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
|
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
|
||||||
# How often (seconds) the background thread pulls deltas for delta-tracked tables.
|
# How often (seconds) the background thread pulls deltas for delta-tracked tables.
|
||||||
REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300"))
|
REFRESH_INTERVAL_SECONDS = int(os.getenv("SQLMEM_REFRESH_INTERVAL", "300"))
|
||||||
|
# Rows fetched per batch when loading a table — caps peak memory for huge tables.
|
||||||
|
FETCH_BATCH_SIZE = int(os.getenv("SQLMEM_FETCH_BATCH", "10000"))
|
||||||
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
|
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
|
||||||
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
|
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
|
||||||
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
|
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
|
||||||
|
|||||||
+24
-6
@@ -4,6 +4,8 @@ from dataclasses import dataclass, field
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .cache import CacheManager
|
from .cache import CacheManager
|
||||||
|
from .config import FETCH_BATCH_SIZE
|
||||||
|
from .stats import TableState
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -58,21 +60,37 @@ class DeltaRefresher:
|
|||||||
col_list = ", ".join(columns)
|
col_list = ", ".join(columns)
|
||||||
|
|
||||||
if watermark is None:
|
if watermark is None:
|
||||||
rows = source_conn.execute(f"SELECT {col_list} FROM {table}").fetchall()
|
cursor = source_conn.execute(f"SELECT {col_list} FROM {table}")
|
||||||
else:
|
else:
|
||||||
rows = source_conn.execute(
|
cursor = source_conn.execute(
|
||||||
f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?",
|
f"SELECT {col_list} FROM {table} WHERE {cfg.change_column} >= ?",
|
||||||
(watermark,),
|
(watermark,),
|
||||||
).fetchall()
|
)
|
||||||
|
|
||||||
if not rows:
|
# Stream the delta in batches so a large catch-up never materializes at once.
|
||||||
|
total = 0
|
||||||
|
self._cache.set_state(table, TableState.REFRESHING)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
batch = cursor.fetchmany(FETCH_BATCH_SIZE)
|
||||||
|
if not batch:
|
||||||
|
break
|
||||||
|
self._cache.upsert_rows(table, columns, batch)
|
||||||
|
total += len(batch)
|
||||||
|
finally:
|
||||||
|
self._cache.set_state(table, TableState.READY)
|
||||||
|
|
||||||
|
if total == 0:
|
||||||
logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}")
|
logger.debug(f"Delta refresh {table!r}: no changes since {watermark!r}")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._cache.upsert_rows(table, columns, rows)
|
# Update row_count / last_refresh once (not per batch) and advance the watermark.
|
||||||
|
self._cache.mark_table_refreshed(
|
||||||
|
table, self._cache.count_rows(table), self._cache.is_table_full(table)
|
||||||
|
)
|
||||||
new_watermark = self._cache.max_value(table, cfg.change_column)
|
new_watermark = self._cache.max_value(table, cfg.change_column)
|
||||||
self._cache.set_last_synced_at(table, new_watermark)
|
self._cache.set_last_synced_at(table, new_watermark)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Delta refresh {table!r}: {len(rows)} row(s) upserted, "
|
f"Delta refresh {table!r}: {total} row(s) upserted, "
|
||||||
f"watermark {watermark!r} → {new_watermark!r}"
|
f"watermark {watermark!r} → {new_watermark!r}"
|
||||||
)
|
)
|
||||||
|
|||||||
+22
-2
@@ -1,5 +1,6 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
|
from dataclasses import replace
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -12,7 +13,7 @@ from .delta import DeltaConfig, DeltaRefresher, ResolvedDelta
|
|||||||
from .executor import QueryExecutor
|
from .executor import QueryExecutor
|
||||||
from .parser import Params, parse
|
from .parser import Params, parse
|
||||||
from .registry import ColumnRegistry
|
from .registry import ColumnRegistry
|
||||||
from .stats import Stats, StatsCollector
|
from .stats import Stats, StatsCollector, TableState, TableStats
|
||||||
|
|
||||||
|
|
||||||
class CachingEngine:
|
class CachingEngine:
|
||||||
@@ -68,8 +69,26 @@ class CachingEngine:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def stats(self) -> Stats:
|
def stats(self) -> Stats:
|
||||||
|
states = self._cache.get_states()
|
||||||
with self._cache._lock:
|
with self._cache._lock:
|
||||||
return self._stats.snapshot(self._cache.connection)
|
base = self._stats.snapshot(self._cache.connection, states)
|
||||||
|
return replace(base, tables={n: self._enrich(n, t) for n, t in base.tables.items()})
|
||||||
|
|
||||||
|
def _enrich(self, name: str, table_stats: TableStats) -> TableStats:
|
||||||
|
"""Annotate a TableStats with how it is refreshed and TTL staleness."""
|
||||||
|
if name in self._delta:
|
||||||
|
tracking = "delta"
|
||||||
|
elif name in self._ttl:
|
||||||
|
tracking = "ttl"
|
||||||
|
else:
|
||||||
|
tracking = "static"
|
||||||
|
|
||||||
|
state = table_stats.state
|
||||||
|
if state == TableState.READY and name in self._ttl:
|
||||||
|
age = self._cache.seconds_since_refresh(name)
|
||||||
|
if age is not None and age > self._ttl[name]:
|
||||||
|
state = TableState.STALE
|
||||||
|
return replace(table_stats, tracking=tracking, state=state)
|
||||||
|
|
||||||
def execute(self, sql: str, params: Params = None) -> list[dict]:
|
def execute(self, sql: str, params: Params = None) -> list[dict]:
|
||||||
parsed = parse(sql, params)
|
parsed = parse(sql, params)
|
||||||
@@ -130,6 +149,7 @@ class CachingEngine:
|
|||||||
"DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
|
"DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
|
||||||
)
|
)
|
||||||
self._cache.connection.commit()
|
self._cache.connection.commit()
|
||||||
|
self._cache.clear_state(table)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
|
"""Wipe the whole cache (RAM + cache.db). Use after structural source changes."""
|
||||||
|
|||||||
+24
-1
@@ -3,11 +3,23 @@ import threading
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
class TableState:
|
||||||
|
"""Live processing state of a cached table (value of ``TableStats.state``)."""
|
||||||
|
|
||||||
|
LOADING = "loading" # a full load is in progress
|
||||||
|
REFRESHING = "refreshing" # an incremental (delta) refresh is in progress
|
||||||
|
READY = "ready" # cached and idle
|
||||||
|
STALE = "stale" # TTL expired — will reload on next access
|
||||||
|
ERROR = "error" # the last load failed
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TableStats:
|
class TableStats:
|
||||||
rows: int
|
rows: int
|
||||||
columns: list[str]
|
columns: list[str]
|
||||||
last_refresh: str
|
last_refresh: str
|
||||||
|
state: str = TableState.READY
|
||||||
|
tracking: str = "static" # "delta" | "ttl" | "static"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -37,14 +49,19 @@ class StatsCollector:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self.refetches += 1
|
self.refetches += 1
|
||||||
|
|
||||||
def snapshot(self, conn: sqlite3.Connection) -> Stats:
|
def snapshot(
|
||||||
|
self, conn: sqlite3.Connection, states: dict[str, str] | None = None
|
||||||
|
) -> Stats:
|
||||||
|
states = states or {}
|
||||||
with self._lock:
|
with self._lock:
|
||||||
hits, misses, refetches = self.hits, self.misses, self.refetches
|
hits, misses, refetches = self.hits, self.misses, self.refetches
|
||||||
|
|
||||||
tables: dict[str, TableStats] = {}
|
tables: dict[str, TableStats] = {}
|
||||||
|
cached: set[str] = set()
|
||||||
for table_name, row_count, last_refresh in conn.execute(
|
for table_name, row_count, last_refresh in conn.execute(
|
||||||
"SELECT table_name, row_count, last_refresh_at FROM _sqlmem_tables"
|
"SELECT table_name, row_count, last_refresh_at FROM _sqlmem_tables"
|
||||||
).fetchall():
|
).fetchall():
|
||||||
|
cached.add(table_name)
|
||||||
columns = [
|
columns = [
|
||||||
r[0]
|
r[0]
|
||||||
for r in conn.execute(
|
for r in conn.execute(
|
||||||
@@ -56,6 +73,12 @@ class StatsCollector:
|
|||||||
rows=row_count or 0,
|
rows=row_count or 0,
|
||||||
columns=columns,
|
columns=columns,
|
||||||
last_refresh=last_refresh,
|
last_refresh=last_refresh,
|
||||||
|
state=states.get(table_name, TableState.READY),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Surface tables that are mid-first-load (not yet in _sqlmem_tables) or failed.
|
||||||
|
for name, state in states.items():
|
||||||
|
if name not in cached and state in (TableState.LOADING, TableState.ERROR):
|
||||||
|
tables[name] = TableStats(rows=0, columns=[], last_refresh="", state=state)
|
||||||
|
|
||||||
return Stats(hits=hits, misses=misses, refetches=refetches, tables=tables)
|
return Stats(hits=hits, misses=misses, refetches=refetches, tables=tables)
|
||||||
|
|||||||
+10
-2
@@ -10,11 +10,19 @@ from sqlmem.cache import CacheManager
|
|||||||
|
|
||||||
class _FakeCursor:
|
class _FakeCursor:
|
||||||
def __init__(self, rows):
|
def __init__(self, rows):
|
||||||
self._rows = rows
|
self._rows = list(rows)
|
||||||
|
self._pos = 0
|
||||||
self.description = None
|
self.description = None
|
||||||
|
|
||||||
def fetchall(self):
|
def fetchall(self):
|
||||||
return self._rows
|
out = self._rows[self._pos :]
|
||||||
|
self._pos = len(self._rows)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def fetchmany(self, size):
|
||||||
|
out = self._rows[self._pos : self._pos + size]
|
||||||
|
self._pos += len(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class FakeSource:
|
class FakeSource:
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from sqlmem.cache import CacheManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def source_conn():
|
||||||
|
conn = sqlite3.connect(":memory:")
|
||||||
|
conn.execute("CREATE TABLE big (id TEXT, val TEXT)")
|
||||||
|
conn.executemany(
|
||||||
|
"INSERT INTO big VALUES (?, ?)", [(str(i), f"v{i}") for i in range(5)]
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
yield conn
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cache(tmp_path):
|
||||||
|
c = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
|
||||||
|
yield c
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def small_batches(monkeypatch):
|
||||||
|
# Force multiple fetch batches over the 5 source rows.
|
||||||
|
monkeypatch.setattr("sqlmem.cache.FETCH_BATCH_SIZE", 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batched_load_loads_all_rows(cache, source_conn, small_batches):
|
||||||
|
cache.load_table("big", ["id", "val"], source_conn)
|
||||||
|
_, rows = cache.execute_in_memory(
|
||||||
|
"SELECT id, val FROM big ORDER BY CAST(id AS INTEGER)"
|
||||||
|
)
|
||||||
|
assert len(rows) == 5
|
||||||
|
assert rows[0] == ("0", "v0")
|
||||||
|
assert rows[-1] == ("4", "v4")
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_staging_table_left_behind(cache, source_conn, small_batches):
|
||||||
|
cache.load_table("big", ["id", "val"], source_conn)
|
||||||
|
names = {
|
||||||
|
r[0]
|
||||||
|
for r in cache.connection.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||||
|
).fetchall()
|
||||||
|
}
|
||||||
|
assert "big" in names
|
||||||
|
assert not any(n.endswith("__sqlmem_load") for n in names)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reload_replaces_data_atomically(cache, source_conn, small_batches):
|
||||||
|
cache.load_table("big", ["id", "val"], source_conn)
|
||||||
|
source_conn.execute("DELETE FROM big")
|
||||||
|
source_conn.execute("INSERT INTO big VALUES ('99', 'new')")
|
||||||
|
source_conn.commit()
|
||||||
|
cache.load_table("big", ["id", "val"], source_conn)
|
||||||
|
_, rows = cache.execute_in_memory("SELECT id, val FROM big")
|
||||||
|
assert rows == [("99", "new")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_sets_ready_state(cache, source_conn):
|
||||||
|
cache.load_table("big", ["id", "val"], source_conn)
|
||||||
|
assert cache.get_states()["big"] == "ready"
|
||||||
|
|
||||||
|
|
||||||
|
def test_orphan_staging_dropped_on_startup(tmp_path, source_conn):
|
||||||
|
# Simulate a crash mid-load: a staging table persisted into cache.db.
|
||||||
|
db_path = tmp_path / "cache.db"
|
||||||
|
c1 = CacheManager(db_path=db_path, backup_interval=9999)
|
||||||
|
c1.load_table("big", ["id", "val"], source_conn)
|
||||||
|
c1.connection.execute("CREATE TABLE big__sqlmem_load (id TEXT, val TEXT)")
|
||||||
|
c1.connection.commit()
|
||||||
|
c1.close() # backup writes the staging table to disk
|
||||||
|
|
||||||
|
c2 = CacheManager(db_path=db_path, backup_interval=9999)
|
||||||
|
names = {
|
||||||
|
r[0]
|
||||||
|
for r in c2.connection.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||||
|
).fetchall()
|
||||||
|
}
|
||||||
|
c2.close()
|
||||||
|
assert "big" in names # real table survives
|
||||||
|
assert not any(n.endswith("__sqlmem_load") for n in names) # orphan cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def test_failed_load_sets_error_state_and_cleans_staging(cache):
|
||||||
|
empty_source = sqlite3.connect(":memory:") # has no 'big' table
|
||||||
|
try:
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
cache.load_table("big", ["id"], empty_source)
|
||||||
|
assert cache.get_states()["big"] == "error"
|
||||||
|
names = {
|
||||||
|
r[0]
|
||||||
|
for r in cache.connection.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||||
|
).fetchall()
|
||||||
|
}
|
||||||
|
assert not any(n.endswith("__sqlmem_load") for n in names)
|
||||||
|
finally:
|
||||||
|
empty_source.close()
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
import sqlmem.engine as eng_mod
|
||||||
|
from sqlmem import CachingEngine, DeltaConfig
|
||||||
|
from sqlmem.cache import CacheManager
|
||||||
|
from sqlmem.stats import StatsCollector
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def source_engine(tmp_path):
|
||||||
|
db_path = tmp_path / "source.db"
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.executescript(
|
||||||
|
"""
|
||||||
|
CREATE TABLE products (id TEXT PRIMARY KEY, name TEXT, changed TEXT);
|
||||||
|
INSERT INTO products VALUES ('1', 'Widget', '2026-06-01 10:00:00');
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
engine = create_engine(f"sqlite:///{db_path}")
|
||||||
|
yield engine
|
||||||
|
engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patched_cache(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setattr(eng_mod, "CACHE_DB_PATH", tmp_path / "cache.db")
|
||||||
|
monkeypatch.setattr(eng_mod, "BACKUP_INTERVAL_SECONDS", 9999)
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_table_state_and_tracking(source_engine, patched_cache):
|
||||||
|
engine = CachingEngine(source_engine)
|
||||||
|
engine.execute("SELECT id, name FROM products")
|
||||||
|
s = engine.stats.tables["products"]
|
||||||
|
assert s.state == "ready"
|
||||||
|
assert s.tracking == "static"
|
||||||
|
assert s.rows == 1
|
||||||
|
engine.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delta_table_tracking(source_engine, patched_cache):
|
||||||
|
engine = CachingEngine(
|
||||||
|
source_engine, delta={"products": DeltaConfig("changed", ["id"])}
|
||||||
|
)
|
||||||
|
engine.execute("SELECT id, name FROM products")
|
||||||
|
s = engine.stats.tables["products"]
|
||||||
|
assert s.tracking == "delta"
|
||||||
|
assert s.state == "ready"
|
||||||
|
engine.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ttl_table_reports_stale(source_engine, patched_cache):
|
||||||
|
engine = CachingEngine(source_engine, ttl={"products": 0})
|
||||||
|
engine.execute("SELECT id, name FROM products")
|
||||||
|
s = engine.stats.tables["products"]
|
||||||
|
assert s.tracking == "ttl"
|
||||||
|
assert s.state == "stale" # ttl=0 → already past its max age
|
||||||
|
engine.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_counters_still_reported(source_engine, patched_cache):
|
||||||
|
engine = CachingEngine(source_engine)
|
||||||
|
engine.execute("SELECT id, name FROM products")
|
||||||
|
engine.execute("SELECT id, name FROM products")
|
||||||
|
stats = engine.stats
|
||||||
|
assert stats.misses == 1
|
||||||
|
assert stats.hits == 1
|
||||||
|
engine.close()
|
||||||
|
|
||||||
|
|
||||||
|
# --- a table being loaded for the first time shows up as "loading" ----------
|
||||||
|
|
||||||
|
|
||||||
|
def test_snapshot_surfaces_a_loading_table(tmp_path):
|
||||||
|
cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
|
||||||
|
snap = StatsCollector().snapshot(cache.connection, {"pending": "loading"})
|
||||||
|
assert "pending" in snap.tables
|
||||||
|
assert snap.tables["pending"].state == "loading"
|
||||||
|
assert snap.tables["pending"].rows == 0
|
||||||
|
cache.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_loading_state_visible_from_another_thread_during_load(tmp_path):
|
||||||
|
"""A first load in progress is observable as 'loading' from another thread."""
|
||||||
|
cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
|
||||||
|
started = threading.Event()
|
||||||
|
release = threading.Event()
|
||||||
|
|
||||||
|
class BlockingCursor:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self._rows = list(rows)
|
||||||
|
self._done = False
|
||||||
|
|
||||||
|
def fetchmany(self, size):
|
||||||
|
if self._done:
|
||||||
|
return []
|
||||||
|
started.set()
|
||||||
|
release.wait(5) # hold the load open until the test releases it
|
||||||
|
self._done = True
|
||||||
|
return self._rows
|
||||||
|
|
||||||
|
class BlockingSource:
|
||||||
|
def execute(self, sql):
|
||||||
|
return BlockingCursor([("1", "alice")])
|
||||||
|
|
||||||
|
loader = threading.Thread(
|
||||||
|
target=cache.load_table, args=("users", ["id", "name"], BlockingSource())
|
||||||
|
)
|
||||||
|
loader.start()
|
||||||
|
try:
|
||||||
|
assert started.wait(5), "load did not start"
|
||||||
|
# mid-load: not yet in _sqlmem_tables, but surfaced as loading
|
||||||
|
assert cache.get_states()["users"] == "loading"
|
||||||
|
snap = StatsCollector().snapshot(cache.connection, cache.get_states())
|
||||||
|
assert snap.tables["users"].state == "loading"
|
||||||
|
finally:
|
||||||
|
release.set()
|
||||||
|
loader.join(5)
|
||||||
|
assert not loader.is_alive()
|
||||||
|
assert cache.get_states()["users"] == "ready"
|
||||||
|
cache.close()
|
||||||
Reference in New Issue
Block a user