Add support for query parameters, JOINs, SELECT * and three-part table names
This commit is contained in:
+32
-9
@@ -9,7 +9,7 @@ from loguru import logger
|
||||
|
||||
import sqlmem._meta as _meta
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
SCHEMA_VERSION = 2
|
||||
|
||||
|
||||
class CacheManager:
|
||||
@@ -40,7 +40,8 @@ class CacheManager:
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_tables (
|
||||
table_name TEXT PRIMARY KEY,
|
||||
last_refresh_at TEXT NOT NULL,
|
||||
row_count INTEGER
|
||||
row_count INTEGER,
|
||||
is_full INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_columns (
|
||||
table_name TEXT NOT NULL,
|
||||
@@ -112,17 +113,18 @@ class CacheManager:
|
||||
logger.info("SIGTERM received — flushing cache to disk.")
|
||||
self._backup_to_disk()
|
||||
|
||||
def mark_table_refreshed(self, table: str, row_count: int) -> None:
|
||||
def mark_table_refreshed(self, table: str, row_count: int, full: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._mem_conn.execute(
|
||||
"""
|
||||
INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count)
|
||||
VALUES (?, ?, ?)
|
||||
INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count, is_full)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(table_name) DO UPDATE SET
|
||||
last_refresh_at = excluded.last_refresh_at,
|
||||
row_count = excluded.row_count
|
||||
row_count = excluded.row_count,
|
||||
is_full = excluded.is_full
|
||||
""",
|
||||
(table, _now(), row_count),
|
||||
(table, _now(), row_count, int(full)),
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
|
||||
@@ -132,7 +134,28 @@ class CacheManager:
|
||||
).fetchone()
|
||||
return row is not None
|
||||
|
||||
def load_table(self, table: str, columns: list[str], source_conn: sqlite3.Connection) -> None:
|
||||
def is_table_full(self, table: str) -> bool:
|
||||
"""True if the whole table (all columns) is cached — a SELECT * cache hit."""
|
||||
row = self._mem_conn.execute(
|
||||
"SELECT is_full FROM _sqlmem_tables WHERE table_name = ?", (table,)
|
||||
).fetchone()
|
||||
return bool(row and row[0])
|
||||
|
||||
def discover_columns(self, table: str, source_conn: sqlite3.Connection) -> list[str]:
|
||||
"""Return all column names of *table* from the source DB without fetching rows."""
|
||||
logger.debug(f"Discovering columns of {table!r} from source DB")
|
||||
cursor = source_conn.execute(f"SELECT * FROM {table} WHERE 1 = 0")
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
logger.debug(f"{table!r} has columns: {columns}")
|
||||
return columns
|
||||
|
||||
def load_table(
|
||||
self,
|
||||
table: str,
|
||||
columns: list[str],
|
||||
source_conn: sqlite3.Connection,
|
||||
full: bool = False,
|
||||
) -> None:
|
||||
cols = ", ".join(columns)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
|
||||
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
|
||||
@@ -145,7 +168,7 @@ class CacheManager:
|
||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
|
||||
self._mem_conn.commit()
|
||||
|
||||
self.mark_table_refreshed(table, len(rows))
|
||||
self.mark_table_refreshed(table, len(rows), full)
|
||||
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
|
||||
|
||||
def close(self) -> None:
|
||||
|
||||
@@ -9,6 +9,9 @@ load_dotenv()
|
||||
DEBUG = os.getenv("SQLMEM_DEBUG", "false").lower() == "true"
|
||||
CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db"))
|
||||
BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600"))
|
||||
# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server),
|
||||
# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite.
|
||||
SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql")
|
||||
|
||||
# Silent by default — callers opt in via add_sink().
|
||||
logger.disable("sqlmem")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sqlite3
|
||||
from typing import cast
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy.engine import Engine
|
||||
@@ -6,7 +7,7 @@ from sqlalchemy.engine import Engine
|
||||
from .cache import CacheManager
|
||||
from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH
|
||||
from .executor import QueryExecutor
|
||||
from .parser import parse
|
||||
from .parser import Params, parse
|
||||
from .registry import ColumnRegistry
|
||||
from .stats import Stats, StatsCollector
|
||||
|
||||
@@ -25,10 +26,10 @@ class CachingEngine:
|
||||
def stats(self) -> Stats:
|
||||
return self._stats.snapshot(self._cache.connection)
|
||||
|
||||
def execute(self, sql: str) -> list[dict]:
|
||||
parsed = parse(sql)
|
||||
def execute(self, sql: str, params: Params = None) -> list[dict]:
|
||||
parsed = parse(sql, params)
|
||||
with self._source_engine.connect() as sa_conn:
|
||||
raw_conn: sqlite3.Connection = sa_conn.connection.dbapi_connection
|
||||
raw_conn = cast(sqlite3.Connection, sa_conn.connection.dbapi_connection)
|
||||
executor = QueryExecutor(self._cache, self._registry, raw_conn, self._stats)
|
||||
return executor.execute(parsed)
|
||||
|
||||
|
||||
+48
-18
@@ -22,33 +22,63 @@ class QueryExecutor:
|
||||
self._stats = stats
|
||||
|
||||
def execute(self, parsed: ParsedQuery) -> list[dict]:
|
||||
table = parsed.table
|
||||
columns = parsed.columns
|
||||
for table in parsed.tables:
|
||||
self._ensure_table(table, parsed)
|
||||
return self._run_in_memory(parsed)
|
||||
|
||||
def _ensure_table(self, table: str, parsed: ParsedQuery) -> None:
|
||||
if table in parsed.wildcard_tables:
|
||||
self._ensure_full(table)
|
||||
else:
|
||||
self._ensure_columns(table, parsed.columns_by_table[table])
|
||||
|
||||
def _ensure_full(self, table: str) -> None:
|
||||
"""Load every column of *table* (SELECT * / t.*), refetching unless already full."""
|
||||
if self._cache.is_table_cached(table) and self._cache.is_table_full(table):
|
||||
logger.debug(f"Cache hit (full): {table!r}")
|
||||
self._stats.record_hit()
|
||||
return
|
||||
|
||||
if self._cache.is_table_cached(table):
|
||||
logger.warning(f"Re-fetching {table!r} in full — SELECT * requested.")
|
||||
self._stats.record_refetch()
|
||||
else:
|
||||
self._stats.record_miss()
|
||||
|
||||
columns = self._cache.discover_columns(table, self._source_conn)
|
||||
self._cache.load_table(table, columns, self._source_conn, full=True)
|
||||
self._registry.update(table, columns)
|
||||
|
||||
def _ensure_columns(self, table: str, columns: list[str]) -> None:
|
||||
"""Load *table* with at least *columns*, refetching only when columns are missing."""
|
||||
missing = self._registry.needs_refetch(table, columns)
|
||||
table_cached = self._cache.is_table_cached(table)
|
||||
|
||||
if missing or not table_cached:
|
||||
if table_cached and missing:
|
||||
logger.warning(
|
||||
f"Re-fetching {table!r} — new columns requested: {missing}. "
|
||||
f"Expanding cache from {self._registry.get_columns(table)} + {missing}"
|
||||
)
|
||||
self._stats.record_refetch()
|
||||
else:
|
||||
self._stats.record_miss()
|
||||
all_columns = list(self._registry.get_columns(table)) + missing
|
||||
self._cache.load_table(table, all_columns, self._source_conn)
|
||||
self._registry.update(table, all_columns)
|
||||
else:
|
||||
if not missing and table_cached:
|
||||
logger.debug(f"Cache hit: {table!r} columns={columns}")
|
||||
self._stats.record_hit()
|
||||
return
|
||||
|
||||
return self._run_in_memory(parsed)
|
||||
if table_cached and missing:
|
||||
logger.warning(
|
||||
f"Re-fetching {table!r} — new columns requested: {missing}. "
|
||||
f"Expanding cache from {self._registry.get_columns(table)} + {missing}"
|
||||
)
|
||||
self._stats.record_refetch()
|
||||
else:
|
||||
self._stats.record_miss()
|
||||
|
||||
all_columns = list(self._registry.get_columns(table)) + missing
|
||||
self._cache.load_table(table, all_columns, self._source_conn)
|
||||
self._registry.update(table, all_columns)
|
||||
|
||||
def _run_in_memory(self, parsed: ParsedQuery) -> list[dict]:
|
||||
logger.debug(f"Executing in SQLite RAM: {parsed.original_sql!r}")
|
||||
cursor = self._cache.connection.execute(parsed.original_sql)
|
||||
logger.debug(f"Executing in SQLite RAM: {parsed.sqlite_sql!r} params={parsed.params!r}")
|
||||
conn = self._cache.connection
|
||||
if parsed.params is None:
|
||||
cursor = conn.execute(parsed.sqlite_sql)
|
||||
else:
|
||||
cursor = conn.execute(parsed.sqlite_sql, parsed.params)
|
||||
col_names = [desc[0] for desc in cursor.description]
|
||||
rows = cursor.fetchall()
|
||||
return [dict(zip(col_names, row)) for row in rows]
|
||||
|
||||
+105
-39
@@ -1,25 +1,34 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import sqlglot
|
||||
import sqlglot.expressions as exp
|
||||
from loguru import logger
|
||||
|
||||
from .config import SQL_DIALECT
|
||||
from .exceptions import ReadOnlyError, UnsupportedQueryError
|
||||
|
||||
WRITE_TYPES = (exp.Insert, exp.Update, exp.Delete)
|
||||
SQLITE_DIALECT = "sqlite"
|
||||
|
||||
# Parameters accepted by execute(): positional (tuple/list of ``?``) or named (dict of ``:name``).
|
||||
Params = tuple | list | dict | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedQuery:
|
||||
table: str
|
||||
columns: list[str]
|
||||
tables: list[str]
|
||||
columns_by_table: dict[str, list[str]]
|
||||
sqlite_sql: str
|
||||
original_sql: str
|
||||
params: Params = None
|
||||
# Tables that must be loaded in full (SELECT * / t.* / referenced without explicit columns).
|
||||
wildcard_tables: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
def parse(sql: str) -> ParsedQuery:
|
||||
def parse(sql: str, params: Params = None) -> ParsedQuery:
|
||||
logger.debug(f"Parsing SQL: {sql!r}")
|
||||
|
||||
statement = sqlglot.parse_one(sql)
|
||||
statement = sqlglot.parse_one(sql, dialect=SQL_DIALECT)
|
||||
|
||||
if isinstance(statement, WRITE_TYPES):
|
||||
raise ReadOnlyError(
|
||||
@@ -29,47 +38,104 @@ def parse(sql: str) -> ParsedQuery:
|
||||
if not isinstance(statement, exp.Select):
|
||||
raise UnsupportedQueryError(f"Only SELECT statements are supported, got: {sql!r}")
|
||||
|
||||
_check_joins(statement)
|
||||
_check_wildcard(statement)
|
||||
tables, alias_map = _extract_tables(statement)
|
||||
if not tables:
|
||||
raise UnsupportedQueryError("SELECT without FROM is not supported.")
|
||||
|
||||
table = _extract_table(statement)
|
||||
columns = _extract_columns(statement)
|
||||
wildcard_tables = _extract_wildcards(statement, tables, alias_map)
|
||||
columns_by_table = _extract_columns(statement, tables, alias_map, wildcard_tables)
|
||||
|
||||
logger.debug(f"Parsed → table={table!r}, columns={columns}")
|
||||
return ParsedQuery(table=table, columns=columns, original_sql=sql)
|
||||
# A table that appears in FROM/JOIN but contributes no explicit column must
|
||||
# still be present for the in-memory query — load it in full.
|
||||
for table in tables:
|
||||
if table not in wildcard_tables and not columns_by_table.get(table):
|
||||
wildcard_tables.add(table)
|
||||
columns_by_table.pop(table, None)
|
||||
|
||||
sqlite_sql = _to_sqlite(statement)
|
||||
|
||||
logger.debug(
|
||||
f"Parsed → tables={tables}, columns={columns_by_table}, "
|
||||
f"wildcard={wildcard_tables}, params={params!r}"
|
||||
)
|
||||
return ParsedQuery(
|
||||
tables=tables,
|
||||
columns_by_table=columns_by_table,
|
||||
sqlite_sql=sqlite_sql,
|
||||
original_sql=sql,
|
||||
params=params,
|
||||
wildcard_tables=wildcard_tables,
|
||||
)
|
||||
|
||||
|
||||
def _check_joins(statement: exp.Select) -> None:
|
||||
if statement.find(exp.Join):
|
||||
raise UnsupportedQueryError("JOIN is not supported yet. Use simple single-table SELECT.")
|
||||
def _extract_tables(statement: exp.Select) -> tuple[list[str], dict[str, str]]:
|
||||
"""Return real table names (first-seen order) and an alias→real-name map."""
|
||||
real_names: list[str] = []
|
||||
alias_map: dict[str, str] = {}
|
||||
for table in statement.find_all(exp.Table):
|
||||
name = table.name
|
||||
if name not in real_names:
|
||||
real_names.append(name)
|
||||
alias_map[name] = name
|
||||
if table.alias:
|
||||
alias_map[table.alias] = name
|
||||
return real_names, alias_map
|
||||
|
||||
|
||||
def _check_wildcard(statement: exp.Select) -> None:
|
||||
def _extract_wildcards(
|
||||
statement: exp.Select, tables: list[str], alias_map: dict[str, str]
|
||||
) -> set[str]:
|
||||
"""Detect ``SELECT *`` (all tables) and ``alias.*`` (one table) in the projection."""
|
||||
wildcard: set[str] = set()
|
||||
for projection in statement.expressions:
|
||||
if isinstance(projection, exp.Star):
|
||||
return set(tables)
|
||||
if isinstance(projection, exp.Column) and isinstance(projection.this, exp.Star):
|
||||
qualifier = projection.table
|
||||
wildcard.add(alias_map.get(qualifier, qualifier))
|
||||
return wildcard
|
||||
|
||||
|
||||
def _extract_columns(
|
||||
statement: exp.Select,
|
||||
tables: list[str],
|
||||
alias_map: dict[str, str],
|
||||
wildcard_tables: set[str],
|
||||
) -> dict[str, list[str]]:
|
||||
"""Map each table to the deduplicated columns referenced anywhere in the query."""
|
||||
single = tables[0] if len(tables) == 1 else None
|
||||
columns: dict[str, list[str]] = {}
|
||||
seen: dict[str, set[str]] = {}
|
||||
|
||||
for col in statement.find_all(exp.Column):
|
||||
if isinstance(col.this, exp.Star):
|
||||
raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.")
|
||||
if statement.find(exp.Star):
|
||||
raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.")
|
||||
continue
|
||||
qualifier = col.table
|
||||
if qualifier:
|
||||
table = alias_map.get(qualifier, qualifier)
|
||||
elif single is not None:
|
||||
table = single
|
||||
else:
|
||||
raise UnsupportedQueryError(
|
||||
f"Unqualified column {col.name!r} is ambiguous in a multi-table query; "
|
||||
"qualify it with its table or alias."
|
||||
)
|
||||
if table in wildcard_tables:
|
||||
continue
|
||||
bucket = seen.setdefault(table, set())
|
||||
if col.name not in bucket:
|
||||
bucket.add(col.name)
|
||||
columns.setdefault(table, []).append(col.name)
|
||||
|
||||
|
||||
def _extract_table(statement: exp.Select) -> str:
|
||||
from_clause = statement.find(exp.From)
|
||||
if not from_clause:
|
||||
raise UnsupportedQueryError("SELECT without FROM is not supported.")
|
||||
table = from_clause.find(exp.Table)
|
||||
if not table:
|
||||
raise UnsupportedQueryError("Could not extract table name from query.")
|
||||
return table.name
|
||||
|
||||
|
||||
def _extract_columns(statement: exp.Select) -> list[str]:
|
||||
seen: set[str] = set()
|
||||
columns: list[str] = []
|
||||
for col in statement.find_all(exp.Column):
|
||||
name = col.name
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
columns.append(name)
|
||||
if not columns:
|
||||
raise UnsupportedQueryError("Could not extract column names from query.")
|
||||
return columns
|
||||
|
||||
|
||||
def _to_sqlite(statement: exp.Select) -> str:
|
||||
"""Render the statement as SQLite SQL, stripping catalog/schema prefixes.
|
||||
|
||||
Mutates *statement* in place; callers must extract metadata beforehand.
|
||||
"""
|
||||
for table in statement.find_all(exp.Table):
|
||||
table.set("db", None)
|
||||
table.set("catalog", None)
|
||||
return statement.sql(dialect=SQLITE_DIALECT)
|
||||
|
||||
Reference in New Issue
Block a user