Add initial SQLmem package structure with SQL parser, cache manager, column registry, and tests
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from .engine import CachingEngine
|
||||
from .exceptions import ReadOnlyError, UnsupportedQueryError
|
||||
|
||||
__all__ = ["CachingEngine", "ReadOnlyError", "UnsupportedQueryError"]
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,158 @@
|
||||
import atexit
|
||||
import signal
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
import sqlmem._meta as _meta
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
class CacheManager:
|
||||
def __init__(self, db_path: Path, backup_interval: int) -> None:
|
||||
self._db_path = db_path
|
||||
self._backup_interval = backup_interval
|
||||
self._mem_conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
self._lock = threading.Lock()
|
||||
self._closed = False
|
||||
|
||||
self._ensure_meta_tables()
|
||||
self._load_from_disk()
|
||||
self._start_backup_thread()
|
||||
|
||||
atexit.register(self._backup_to_disk)
|
||||
signal.signal(signal.SIGTERM, self._on_sigterm)
|
||||
|
||||
@property
|
||||
def connection(self) -> sqlite3.Connection:
|
||||
return self._mem_conn
|
||||
|
||||
def _ensure_meta_tables(self) -> None:
|
||||
self._mem_conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_meta (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_tables (
|
||||
table_name TEXT PRIMARY KEY,
|
||||
last_refresh_at TEXT NOT NULL,
|
||||
row_count INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_columns (
|
||||
table_name TEXT NOT NULL,
|
||||
column_name TEXT NOT NULL,
|
||||
PRIMARY KEY (table_name, column_name)
|
||||
);
|
||||
""")
|
||||
self._mem_conn.execute(
|
||||
"INSERT OR IGNORE INTO _sqlmem_meta (key, value) VALUES (?, ?)",
|
||||
("app_version", _meta.__version__),
|
||||
)
|
||||
self._mem_conn.execute(
|
||||
"INSERT OR IGNORE INTO _sqlmem_meta (key, value) VALUES (?, ?)",
|
||||
("schema_version", str(SCHEMA_VERSION)),
|
||||
)
|
||||
self._mem_conn.execute(
|
||||
"INSERT OR IGNORE INTO _sqlmem_meta (key, value) VALUES (?, ?)",
|
||||
("created_at", _now()),
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
|
||||
def _load_from_disk(self) -> None:
|
||||
if not self._db_path.exists():
|
||||
logger.info(f"No cache file found at {self._db_path}, starting fresh.")
|
||||
return
|
||||
|
||||
logger.info(f"Loading cache from {self._db_path}")
|
||||
disk_conn = sqlite3.connect(self._db_path)
|
||||
try:
|
||||
schema_version = disk_conn.execute(
|
||||
"SELECT value FROM _sqlmem_meta WHERE key = 'schema_version'"
|
||||
).fetchone()
|
||||
if schema_version is None or int(schema_version[0]) != SCHEMA_VERSION:
|
||||
logger.warning("Cache schema version mismatch — discarding cache file, starting fresh.")
|
||||
disk_conn.close()
|
||||
return
|
||||
|
||||
disk_conn.backup(self._mem_conn)
|
||||
logger.info("Cache loaded from disk successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load cache from disk: {e} — starting fresh.")
|
||||
finally:
|
||||
disk_conn.close()
|
||||
|
||||
def _backup_to_disk(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
logger.info(f"Backing up cache to {self._db_path}")
|
||||
try:
|
||||
with self._lock:
|
||||
disk_conn = sqlite3.connect(self._db_path)
|
||||
self._mem_conn.backup(disk_conn)
|
||||
disk_conn.close()
|
||||
logger.info("Cache backup complete.")
|
||||
except Exception as e:
|
||||
logger.error(f"Cache backup failed: {e}")
|
||||
|
||||
def _start_backup_thread(self) -> None:
|
||||
def loop() -> None:
|
||||
event = threading.Event()
|
||||
while not event.wait(self._backup_interval):
|
||||
self._backup_to_disk()
|
||||
|
||||
t = threading.Thread(target=loop, daemon=True, name="sqlmem-backup")
|
||||
t.start()
|
||||
logger.debug(f"Backup thread started (interval={self._backup_interval}s)")
|
||||
|
||||
def _on_sigterm(self, signum, frame) -> None:
|
||||
logger.info("SIGTERM received — flushing cache to disk.")
|
||||
self._backup_to_disk()
|
||||
|
||||
def mark_table_refreshed(self, table: str, row_count: int) -> None:
|
||||
with self._lock:
|
||||
self._mem_conn.execute(
|
||||
"""
|
||||
INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(table_name) DO UPDATE SET
|
||||
last_refresh_at = excluded.last_refresh_at,
|
||||
row_count = excluded.row_count
|
||||
""",
|
||||
(table, _now(), row_count),
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
|
||||
def is_table_cached(self, table: str) -> bool:
|
||||
row = self._mem_conn.execute(
|
||||
"SELECT 1 FROM _sqlmem_tables WHERE table_name = ?", (table,)
|
||||
).fetchone()
|
||||
return row is not None
|
||||
|
||||
def load_table(self, table: str, columns: list[str], source_conn: sqlite3.Connection) -> None:
|
||||
cols = ", ".join(columns)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
|
||||
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
|
||||
|
||||
with self._lock:
|
||||
self._mem_conn.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
col_defs = ", ".join(f"{c} TEXT" for c in columns)
|
||||
self._mem_conn.execute(f"CREATE TABLE {table} ({col_defs})")
|
||||
placeholders = ", ".join("?" * len(columns))
|
||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
|
||||
self._mem_conn.commit()
|
||||
|
||||
self.mark_table_refreshed(table, len(rows))
|
||||
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
|
||||
|
||||
def close(self) -> None:
|
||||
self._backup_to_disk()
|
||||
self._closed = True
|
||||
self._mem_conn.close()
|
||||
|
||||
|
||||
def _now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
load_dotenv()
|
||||
|
||||
DEBUG = os.getenv("SQLMEM_DEBUG", "false").lower() == "true"
|
||||
CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
|
||||
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
|
||||
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sink=lambda msg: print(msg, end=""),
|
||||
level="DEBUG" if DEBUG else "INFO",
|
||||
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||
colorize=True,
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
import sqlite3
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from .cache import CacheManager
|
||||
from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH
|
||||
from .executor import QueryExecutor
|
||||
from .parser import parse
|
||||
from .registry import ColumnRegistry
|
||||
|
||||
|
||||
class CachingEngine:
|
||||
"""Transparent SQLAlchemy-compatible cache layer."""
|
||||
|
||||
def __init__(self, source_engine: Engine) -> None:
|
||||
self._source_engine = source_engine
|
||||
self._cache = CacheManager(CACHE_DB_PATH, BACKUP_INTERVAL_SECONDS)
|
||||
self._registry = ColumnRegistry(self._cache.connection)
|
||||
logger.info("CachingEngine initialized.")
|
||||
|
||||
def execute(self, sql: str) -> list[dict]:
|
||||
parsed = parse(sql)
|
||||
with self._source_engine.connect() as sa_conn:
|
||||
raw_conn: sqlite3.Connection = sa_conn.connection.dbapi_connection
|
||||
executor = QueryExecutor(self._cache, self._registry, raw_conn)
|
||||
return executor.execute(parsed)
|
||||
|
||||
def invalidate(self, table: str) -> None:
|
||||
logger.info(f"Manually invalidating cache for table {table!r}")
|
||||
with self._cache._lock:
|
||||
self._cache.connection.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
self._cache.connection.execute(
|
||||
"DELETE FROM _sqlmem_tables WHERE table_name = ?", (table,)
|
||||
)
|
||||
self._cache.connection.execute(
|
||||
"DELETE FROM _sqlmem_columns WHERE table_name = ?", (table,)
|
||||
)
|
||||
self._cache.connection.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
self._cache.close()
|
||||
logger.info("CachingEngine closed.")
|
||||
@@ -0,0 +1,6 @@
|
||||
class ReadOnlyError(Exception):
|
||||
"""Raised when a write operation (INSERT/UPDATE/DELETE) is attempted."""
|
||||
|
||||
|
||||
class UnsupportedQueryError(Exception):
|
||||
"""Raised when a query uses unsupported features (JOIN, SELECT *)."""
|
||||
@@ -0,0 +1,42 @@
|
||||
import sqlite3
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .cache import CacheManager
|
||||
from .parser import ParsedQuery
|
||||
from .registry import ColumnRegistry
|
||||
|
||||
|
||||
class QueryExecutor:
|
||||
def __init__(self, cache: CacheManager, registry: ColumnRegistry, source_conn: sqlite3.Connection) -> None:
|
||||
self._cache = cache
|
||||
self._registry = registry
|
||||
self._source_conn = source_conn
|
||||
|
||||
def execute(self, parsed: ParsedQuery) -> list[dict]:
|
||||
table = parsed.table
|
||||
columns = parsed.columns
|
||||
|
||||
missing = self._registry.needs_refetch(table, columns)
|
||||
table_cached = self._cache.is_table_cached(table)
|
||||
|
||||
if missing or not table_cached:
|
||||
if table_cached and missing:
|
||||
logger.warning(
|
||||
f"Re-fetching {table!r} — new columns requested: {missing}. "
|
||||
f"Expanding cache from {self._registry.get_columns(table)} + {missing}"
|
||||
)
|
||||
all_columns = list(self._registry.get_columns(table)) + missing
|
||||
self._cache.load_table(table, all_columns, self._source_conn)
|
||||
self._registry.update(table, all_columns)
|
||||
else:
|
||||
logger.debug(f"Cache hit: {table!r} columns={columns}")
|
||||
|
||||
return self._run_in_memory(parsed)
|
||||
|
||||
def _run_in_memory(self, parsed: ParsedQuery) -> list[dict]:
|
||||
logger.debug(f"Executing in SQLite RAM: {parsed.original_sql!r}")
|
||||
cursor = self._cache.connection.execute(parsed.original_sql)
|
||||
col_names = [desc[0] for desc in cursor.description]
|
||||
rows = cursor.fetchall()
|
||||
return [dict(zip(col_names, row)) for row in rows]
|
||||
@@ -0,0 +1,71 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import sqlglot
|
||||
import sqlglot.expressions as exp
|
||||
from loguru import logger
|
||||
|
||||
from .exceptions import ReadOnlyError, UnsupportedQueryError
|
||||
|
||||
WRITE_TYPES = (exp.Insert, exp.Update, exp.Delete)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedQuery:
|
||||
table: str
|
||||
columns: list[str]
|
||||
original_sql: str
|
||||
|
||||
|
||||
def parse(sql: str) -> ParsedQuery:
|
||||
logger.debug(f"Parsing SQL: {sql!r}")
|
||||
|
||||
statement = sqlglot.parse_one(sql)
|
||||
|
||||
if isinstance(statement, WRITE_TYPES):
|
||||
raise ReadOnlyError(
|
||||
f"Write operations are not allowed. Attempted: {type(statement).__name__.upper()}"
|
||||
)
|
||||
|
||||
if not isinstance(statement, exp.Select):
|
||||
raise UnsupportedQueryError(f"Only SELECT statements are supported, got: {sql!r}")
|
||||
|
||||
_check_joins(statement)
|
||||
_check_wildcard(statement)
|
||||
|
||||
table = _extract_table(statement)
|
||||
columns = _extract_columns(statement)
|
||||
|
||||
logger.debug(f"Parsed → table={table!r}, columns={columns}")
|
||||
return ParsedQuery(table=table, columns=columns, original_sql=sql)
|
||||
|
||||
|
||||
def _check_joins(statement: exp.Select) -> None:
|
||||
if statement.find(exp.Join):
|
||||
raise UnsupportedQueryError("JOIN is not supported yet. Use simple single-table SELECT.")
|
||||
|
||||
|
||||
def _check_wildcard(statement: exp.Select) -> None:
|
||||
for col in statement.find_all(exp.Column):
|
||||
if isinstance(col.this, exp.Star):
|
||||
raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.")
|
||||
if statement.find(exp.Star):
|
||||
raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.")
|
||||
|
||||
|
||||
def _extract_table(statement: exp.Select) -> str:
|
||||
from_clause = statement.find(exp.From)
|
||||
if not from_clause:
|
||||
raise UnsupportedQueryError("SELECT without FROM is not supported.")
|
||||
table = from_clause.find(exp.Table)
|
||||
if not table:
|
||||
raise UnsupportedQueryError("Could not extract table name from query.")
|
||||
return table.name
|
||||
|
||||
|
||||
def _extract_columns(statement: exp.Select) -> list[str]:
|
||||
columns = []
|
||||
for col in statement.find_all(exp.Column):
|
||||
columns.append(col.name)
|
||||
if not columns:
|
||||
raise UnsupportedQueryError("Could not extract column names from query.")
|
||||
return columns
|
||||
@@ -0,0 +1,48 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ColumnRegistry:
|
||||
"""Tracks which columns per table have been requested and are held in cache."""
|
||||
|
||||
def __init__(self, mem_conn: sqlite3.Connection) -> None:
|
||||
self._conn = mem_conn
|
||||
self._lock = Lock()
|
||||
self._ensure_table()
|
||||
|
||||
def _ensure_table(self) -> None:
|
||||
self._conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_columns (
|
||||
table_name TEXT NOT NULL,
|
||||
column_name TEXT NOT NULL,
|
||||
PRIMARY KEY (table_name, column_name)
|
||||
)
|
||||
""")
|
||||
self._conn.commit()
|
||||
|
||||
def get_columns(self, table: str) -> set[str]:
|
||||
rows = self._conn.execute(
|
||||
"SELECT column_name FROM _sqlmem_columns WHERE table_name = ?", (table,)
|
||||
).fetchall()
|
||||
return {row[0] for row in rows}
|
||||
|
||||
def needs_refetch(self, table: str, requested: list[str]) -> list[str]:
|
||||
"""Returns columns that are requested but not yet in registry (missing columns)."""
|
||||
known = self.get_columns(table)
|
||||
missing = [c for c in requested if c not in known]
|
||||
return missing
|
||||
|
||||
def update(self, table: str, columns: list[str]) -> None:
|
||||
with self._lock:
|
||||
existing = self.get_columns(table)
|
||||
new_columns = [c for c in columns if c not in existing]
|
||||
if not new_columns:
|
||||
return
|
||||
self._conn.executemany(
|
||||
"INSERT OR IGNORE INTO _sqlmem_columns (table_name, column_name) VALUES (?, ?)",
|
||||
[(table, col) for col in new_columns],
|
||||
)
|
||||
self._conn.commit()
|
||||
logger.info(f"Registry updated: {table!r} now tracks columns {self.get_columns(table)}")
|
||||
Reference in New Issue
Block a user