Add support for query parameters, JOINs, SELECT * and three-part table names

This commit is contained in:
Jan Doubravský
2026-06-04 18:25:47 +02:00
parent b044ca43f8
commit 530c2618cf
14 changed files with 511 additions and 106 deletions
+32 -9
View File
@@ -9,7 +9,7 @@ from loguru import logger
import sqlmem._meta as _meta
SCHEMA_VERSION = 1
SCHEMA_VERSION = 2
class CacheManager:
@@ -40,7 +40,8 @@ class CacheManager:
CREATE TABLE IF NOT EXISTS _sqlmem_tables (
table_name TEXT PRIMARY KEY,
last_refresh_at TEXT NOT NULL,
row_count INTEGER
row_count INTEGER,
is_full INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS _sqlmem_columns (
table_name TEXT NOT NULL,
@@ -112,17 +113,18 @@ class CacheManager:
logger.info("SIGTERM received — flushing cache to disk.")
self._backup_to_disk()
def mark_table_refreshed(self, table: str, row_count: int) -> None:
def mark_table_refreshed(self, table: str, row_count: int, full: bool = False) -> None:
with self._lock:
self._mem_conn.execute(
"""
INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count)
VALUES (?, ?, ?)
INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count, is_full)
VALUES (?, ?, ?, ?)
ON CONFLICT(table_name) DO UPDATE SET
last_refresh_at = excluded.last_refresh_at,
row_count = excluded.row_count
row_count = excluded.row_count,
is_full = excluded.is_full
""",
(table, _now(), row_count),
(table, _now(), row_count, int(full)),
)
self._mem_conn.commit()
@@ -132,7 +134,28 @@ class CacheManager:
).fetchone()
return row is not None
def load_table(self, table: str, columns: list[str], source_conn: sqlite3.Connection) -> None:
def is_table_full(self, table: str) -> bool:
"""True if the whole table (all columns) is cached — a SELECT * cache hit."""
row = self._mem_conn.execute(
"SELECT is_full FROM _sqlmem_tables WHERE table_name = ?", (table,)
).fetchone()
return bool(row and row[0])
def discover_columns(self, table: str, source_conn: sqlite3.Connection) -> list[str]:
"""Return all column names of *table* from the source DB without fetching rows."""
logger.debug(f"Discovering columns of {table!r} from source DB")
cursor = source_conn.execute(f"SELECT * FROM {table} WHERE 1 = 0")
columns = [desc[0] for desc in cursor.description]
logger.debug(f"{table!r} has columns: {columns}")
return columns
def load_table(
self,
table: str,
columns: list[str],
source_conn: sqlite3.Connection,
full: bool = False,
) -> None:
cols = ", ".join(columns)
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
@@ -145,7 +168,7 @@ class CacheManager:
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
self._mem_conn.commit()
self.mark_table_refreshed(table, len(rows))
self.mark_table_refreshed(table, len(rows), full)
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
def close(self) -> None:
+3
View File
@@ -9,6 +9,9 @@ load_dotenv()
DEBUG = os.getenv("SQLMEM_DEBUG", "false").lower() == "true"
CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
# Silent by default — callers opt in via add_sink().
logger.disable("sqlmem")
+5 -4
View File
@@ -1,4 +1,5 @@
import sqlite3
from typing import cast
from loguru import logger
from sqlalchemy.engine import Engine
@@ -6,7 +7,7 @@ from sqlalchemy.engine import Engine
from .cache import CacheManager
from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH
from .executor import QueryExecutor
from .parser import parse
from .parser import Params, parse
from .registry import ColumnRegistry
from .stats import Stats, StatsCollector
@@ -25,10 +26,10 @@ class CachingEngine:
def stats(self) -> Stats:
return self._stats.snapshot(self._cache.connection)
def execute(self, sql: str) -> list[dict]:
parsed = parse(sql)
def execute(self, sql: str, params: Params = None) -> list[dict]:
parsed = parse(sql, params)
with self._source_engine.connect() as sa_conn:
raw_conn: sqlite3.Connection = sa_conn.connection.dbapi_connection
raw_conn = cast(sqlite3.Connection, sa_conn.connection.dbapi_connection)
executor = QueryExecutor(self._cache, self._registry, raw_conn, self._stats)
return executor.execute(parsed)
+48 -18
View File
@@ -22,33 +22,63 @@ class QueryExecutor:
self._stats = stats
def execute(self, parsed: ParsedQuery) -> list[dict]:
table = parsed.table
columns = parsed.columns
for table in parsed.tables:
self._ensure_table(table, parsed)
return self._run_in_memory(parsed)
def _ensure_table(self, table: str, parsed: ParsedQuery) -> None:
if table in parsed.wildcard_tables:
self._ensure_full(table)
else:
self._ensure_columns(table, parsed.columns_by_table[table])
def _ensure_full(self, table: str) -> None:
"""Load every column of *table* (SELECT * / t.*), refetching unless already full."""
if self._cache.is_table_cached(table) and self._cache.is_table_full(table):
logger.debug(f"Cache hit (full): {table!r}")
self._stats.record_hit()
return
if self._cache.is_table_cached(table):
logger.warning(f"Re-fetching {table!r} in full — SELECT * requested.")
self._stats.record_refetch()
else:
self._stats.record_miss()
columns = self._cache.discover_columns(table, self._source_conn)
self._cache.load_table(table, columns, self._source_conn, full=True)
self._registry.update(table, columns)
def _ensure_columns(self, table: str, columns: list[str]) -> None:
"""Load *table* with at least *columns*, refetching only when columns are missing."""
missing = self._registry.needs_refetch(table, columns)
table_cached = self._cache.is_table_cached(table)
if missing or not table_cached:
if table_cached and missing:
logger.warning(
f"Re-fetching {table!r} — new columns requested: {missing}. "
f"Expanding cache from {self._registry.get_columns(table)} + {missing}"
)
self._stats.record_refetch()
else:
self._stats.record_miss()
all_columns = list(self._registry.get_columns(table)) + missing
self._cache.load_table(table, all_columns, self._source_conn)
self._registry.update(table, all_columns)
else:
if not missing and table_cached:
logger.debug(f"Cache hit: {table!r} columns={columns}")
self._stats.record_hit()
return
return self._run_in_memory(parsed)
if table_cached and missing:
logger.warning(
f"Re-fetching {table!r} — new columns requested: {missing}. "
f"Expanding cache from {self._registry.get_columns(table)} + {missing}"
)
self._stats.record_refetch()
else:
self._stats.record_miss()
all_columns = list(self._registry.get_columns(table)) + missing
self._cache.load_table(table, all_columns, self._source_conn)
self._registry.update(table, all_columns)
def _run_in_memory(self, parsed: ParsedQuery) -> list[dict]:
logger.debug(f"Executing in SQLite RAM: {parsed.original_sql!r}")
cursor = self._cache.connection.execute(parsed.original_sql)
logger.debug(f"Executing in SQLite RAM: {parsed.sqlite_sql!r} params={parsed.params!r}")
conn = self._cache.connection
if parsed.params is None:
cursor = conn.execute(parsed.sqlite_sql)
else:
cursor = conn.execute(parsed.sqlite_sql, parsed.params)
col_names = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
return [dict(zip(col_names, row)) for row in rows]
+105 -39
View File
@@ -1,25 +1,34 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
import sqlglot
import sqlglot.expressions as exp
from loguru import logger
from .config import SQL_DIALECT
from .exceptions import ReadOnlyError, UnsupportedQueryError
WRITE_TYPES = (exp.Insert, exp.Update, exp.Delete)
SQLITE_DIALECT = "sqlite"
# Parameters accepted by execute(): positional (tuple/list of ``?``) or named (dict of ``:name``).
Params = tuple | list | dict | None
@dataclass
class ParsedQuery:
table: str
columns: list[str]
tables: list[str]
columns_by_table: dict[str, list[str]]
sqlite_sql: str
original_sql: str
params: Params = None
# Tables that must be loaded in full (SELECT * / t.* / referenced without explicit columns).
wildcard_tables: set[str] = field(default_factory=set)
def parse(sql: str) -> ParsedQuery:
def parse(sql: str, params: Params = None) -> ParsedQuery:
logger.debug(f"Parsing SQL: {sql!r}")
statement = sqlglot.parse_one(sql)
statement = sqlglot.parse_one(sql, dialect=SQL_DIALECT)
if isinstance(statement, WRITE_TYPES):
raise ReadOnlyError(
@@ -29,47 +38,104 @@ def parse(sql: str) -> ParsedQuery:
if not isinstance(statement, exp.Select):
raise UnsupportedQueryError(f"Only SELECT statements are supported, got: {sql!r}")
_check_joins(statement)
_check_wildcard(statement)
tables, alias_map = _extract_tables(statement)
if not tables:
raise UnsupportedQueryError("SELECT without FROM is not supported.")
table = _extract_table(statement)
columns = _extract_columns(statement)
wildcard_tables = _extract_wildcards(statement, tables, alias_map)
columns_by_table = _extract_columns(statement, tables, alias_map, wildcard_tables)
logger.debug(f"Parsed → table={table!r}, columns={columns}")
return ParsedQuery(table=table, columns=columns, original_sql=sql)
# A table that appears in FROM/JOIN but contributes no explicit column must
# still be present for the in-memory query — load it in full.
for table in tables:
if table not in wildcard_tables and not columns_by_table.get(table):
wildcard_tables.add(table)
columns_by_table.pop(table, None)
sqlite_sql = _to_sqlite(statement)
logger.debug(
f"Parsed → tables={tables}, columns={columns_by_table}, "
f"wildcard={wildcard_tables}, params={params!r}"
)
return ParsedQuery(
tables=tables,
columns_by_table=columns_by_table,
sqlite_sql=sqlite_sql,
original_sql=sql,
params=params,
wildcard_tables=wildcard_tables,
)
def _check_joins(statement: exp.Select) -> None:
if statement.find(exp.Join):
raise UnsupportedQueryError("JOIN is not supported yet. Use simple single-table SELECT.")
def _extract_tables(statement: exp.Select) -> tuple[list[str], dict[str, str]]:
"""Return real table names (first-seen order) and an alias→real-name map."""
real_names: list[str] = []
alias_map: dict[str, str] = {}
for table in statement.find_all(exp.Table):
name = table.name
if name not in real_names:
real_names.append(name)
alias_map[name] = name
if table.alias:
alias_map[table.alias] = name
return real_names, alias_map
def _check_wildcard(statement: exp.Select) -> None:
def _extract_wildcards(
statement: exp.Select, tables: list[str], alias_map: dict[str, str]
) -> set[str]:
"""Detect ``SELECT *`` (all tables) and ``alias.*`` (one table) in the projection."""
wildcard: set[str] = set()
for projection in statement.expressions:
if isinstance(projection, exp.Star):
return set(tables)
if isinstance(projection, exp.Column) and isinstance(projection.this, exp.Star):
qualifier = projection.table
wildcard.add(alias_map.get(qualifier, qualifier))
return wildcard
def _extract_columns(
statement: exp.Select,
tables: list[str],
alias_map: dict[str, str],
wildcard_tables: set[str],
) -> dict[str, list[str]]:
"""Map each table to the deduplicated columns referenced anywhere in the query."""
single = tables[0] if len(tables) == 1 else None
columns: dict[str, list[str]] = {}
seen: dict[str, set[str]] = {}
for col in statement.find_all(exp.Column):
if isinstance(col.this, exp.Star):
raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.")
if statement.find(exp.Star):
raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.")
continue
qualifier = col.table
if qualifier:
table = alias_map.get(qualifier, qualifier)
elif single is not None:
table = single
else:
raise UnsupportedQueryError(
f"Unqualified column {col.name!r} is ambiguous in a multi-table query; "
"qualify it with its table or alias."
)
if table in wildcard_tables:
continue
bucket = seen.setdefault(table, set())
if col.name not in bucket:
bucket.add(col.name)
columns.setdefault(table, []).append(col.name)
def _extract_table(statement: exp.Select) -> str:
from_clause = statement.find(exp.From)
if not from_clause:
raise UnsupportedQueryError("SELECT without FROM is not supported.")
table = from_clause.find(exp.Table)
if not table:
raise UnsupportedQueryError("Could not extract table name from query.")
return table.name
def _extract_columns(statement: exp.Select) -> list[str]:
seen: set[str] = set()
columns: list[str] = []
for col in statement.find_all(exp.Column):
name = col.name
if name not in seen:
seen.add(name)
columns.append(name)
if not columns:
raise UnsupportedQueryError("Could not extract column names from query.")
return columns
def _to_sqlite(statement: exp.Select) -> str:
"""Render the statement as SQLite SQL, stripping catalog/schema prefixes.
Mutates *statement* in place; callers must extract metadata beforehand.
"""
for table in statement.find_all(exp.Table):
table.set("db", None)
table.set("catalog", None)
return statement.sql(dialect=SQLITE_DIALECT)