From 530c2618cfaa85660fb6dbd2d9a274fc4e383d75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Doubravsk=C3=BD?= Date: Thu, 4 Jun 2026 18:25:47 +0200 Subject: [PATCH] Add support for query parameters, JOINs, SELECT * and three-part table names --- CHANGELOG.md | 18 ++++++ README.md | 32 +++++++-- project.md | 33 ++++++---- pyproject.toml | 2 +- src/sqlmem/cache.py | 41 +++++++++--- src/sqlmem/config.py | 3 + src/sqlmem/engine.py | 9 +-- src/sqlmem/executor.py | 66 +++++++++++++------ src/sqlmem/parser.py | 144 ++++++++++++++++++++++++++++++----------- tests/test_cache.py | 1 - tests/test_config.py | 1 - tests/test_engine.py | 55 ++++++++++++++-- tests/test_executor.py | 122 ++++++++++++++++++++++++++++++++++ tests/test_parser.py | 90 +++++++++++++++++++++++--- 14 files changed, 511 insertions(+), 106 deletions(-) create mode 100644 tests/test_executor.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 15685c9..9999268 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,24 @@ All notable changes to this project will be documented in this file. --- +## [1.2.0] - 2026-06-04 + +### Added +- **Parametrized queries (R1)** — `execute(sql, params)` accepts positional (`?` tuple/list) and named (`:name` dict) parameters; passed straight to SQLite during in-memory filtering. Cache loads still fetch the full table (parameters are not applied to source fetches). +- **JOIN support (R2)** — multi-table SELECTs are parsed into per-table column sets; each table is cached independently and the JOIN runs in the in-memory SQLite. Columns in a multi-table query must be qualified by table or alias. +- **`SELECT *` support (R3)** — wildcard (and `alias.*`) queries discover all columns from the source DB, cache the whole table, and mark it `is_full` so later column queries are guaranteed cache hits without re-fetch. +- **Three-part table names (R4)** — `[catalog].[schema].[table]` is parsed to its base name for caching; the in-memory query is rewritten to strip catalog/schema prefixes so it runs under SQLite. +- `SQLMEM_SQL_DIALECT` env var (default `tsql`) — sqlglot dialect used to parse incoming SQL; T-SQL also accepts ANSI SQL and MSSQL bracket quoting. +- `CacheManager.discover_columns()` and `CacheManager.is_table_full()`; `load_table()` gained a `full` flag. + +### Changed +- `pyproject.toml` — bumped version to `1.2.0` +- `parser.py` — `ParsedQuery.table: str` replaced by `tables: list[str]` plus `columns_by_table`, `sqlite_sql`, `params`, and `wildcard_tables`; SQL is parsed with the configured dialect and rendered to SQLite for execution. +- `executor.py` — loads each referenced table independently and applies query parameters during in-memory execution. +- `cache.py` — schema version bumped to `2`; `_sqlmem_tables` gained an `is_full` column (existing on-disk caches are discarded and rebuilt on load). + +--- + ## [1.1.0] - 2026-06-03 ### Added diff --git a/README.md b/README.md index ce21d3b..46e9e32 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ Application (SQLAlchemy) On the first SELECT for a table, SQLmem fetches the required rows from the database and stores them in an in-memory SQLite instance. Subsequent queries for the same columns hit the in-memory cache with no database round-trip. When a query requests a column not yet in cache, SQLmem re-fetches the table with the expanded column set. +Parametrized queries, JOINs and `SELECT *` are all supported. Each table referenced in a JOIN is cached independently; the JOIN itself runs in the in-memory SQLite. Query parameters are applied during in-memory filtering, so cache loads always fetch the full table regardless of the `WHERE` values. + ## Installation ```bash @@ -45,9 +47,25 @@ engine = CachingEngine(base_engine) results = engine.execute("SELECT id, name FROM users WHERE status = 'active'") for row in results: print(row["id"], row["name"]) + +# Positional parameters (?): +engine.execute("SELECT id, name FROM users WHERE id = ?", ("42",)) + +# Named parameters (:name): +engine.execute("SELECT id, name FROM users WHERE id = :id", {"id": "42"}) + +# JOINs — each table is cached independently: +engine.execute( + "SELECT u.name, o.total FROM users u " + "JOIN orders o ON o.user_id = u.id WHERE u.id = ?", + ("42",), +) + +# SELECT * — loads and caches the whole table: +engine.execute("SELECT * FROM users") ``` -`execute()` returns a list of dicts. Results are compatible with standard iteration patterns. +`execute()` returns a list of dicts. Parameters are passed straight through to SQLite, so positional (`?`) and named (`:name`) styles both work. Results are compatible with standard iteration patterns. ## Cache behaviour @@ -57,10 +75,12 @@ for row in results: Query 1: SELECT a, b FROM orders → cache miss → fetch orders(a, b) from DB Query 2: SELECT a, d FROM orders → new column d → re-fetch orders(a, b, d) Query 3: SELECT b FROM orders → cache hit, no DB query -Query 4: SELECT * FROM orders → UnsupportedQueryError (wildcard not supported) -Query 5: SELECT a FROM orders JOIN … → UnsupportedQueryError (JOIN not supported) +Query 4: SELECT * FROM orders → fetches all columns, marks the table fully cached +Query 5: SELECT a FROM orders → cache hit (table already full) ``` +**`SELECT *`** loads every column and marks the table as fully cached, so any later column query is a guaranteed cache hit with no re-fetch. + **Writes are blocked** — INSERT, UPDATE, and DELETE raise `ReadOnlyError`. SQLmem is a read-only cache. ## Persistence @@ -89,13 +109,14 @@ Set via environment variables or a `.env` file: | `SQLMEM_DEBUG` | `false` | `true` enables DEBUG-level logging | | `SQLMEM_CACHE_DB` | `cache.db` | Path to the on-disk persistence file | | `SQLMEM_BACKUP_INTERVAL` | `3600` | Backup interval in seconds | +| `SQLMEM_SQL_DIALECT` | `tsql` | sqlglot dialect used to parse incoming SQL (e.g. `tsql`, `postgres`, `mysql`) | ## Exceptions | Exception | When raised | |---|---| | `ReadOnlyError` | INSERT, UPDATE, or DELETE statement | -| `UnsupportedQueryError` | `SELECT *` or any JOIN | +| `UnsupportedQueryError` | non-SELECT statement, `SELECT` without `FROM`, or an unqualified column in a multi-table query | ```python from sqlmem import ReadOnlyError, UnsupportedQueryError @@ -118,7 +139,8 @@ Set `SQLMEM_DEBUG=true` in `.env` to make the default level DEBUG when no explic ## Limitations -- `SELECT *` and JOIN queries are not supported. +- In a multi-table (JOIN) query, every column must be qualified with its table or alias; unqualified columns raise `UnsupportedQueryError`. +- Tables are keyed by their base name — two tables with the same name in different schemas share one cache entry. - No distributed cache backend (Redis etc.). - No transactional consistency guarantees. - Write operations (INSERT/UPDATE/DELETE) are always blocked. diff --git a/project.md b/project.md index 1120387..eee965c 100644 --- a/project.md +++ b/project.md @@ -59,11 +59,12 @@ with engine.connect() as conn: ## Komponenty ### 1. SQL Parser -- Detekuje typ dotazu (SELECT / zápis). -- Extrahuje názvy tabulek z FROM a JOIN klauzulí. -- Extrahuje seznam požadovaných sloupců. -- Detekuje `SELECT *` (wildcard) a JOIN — vyhodí `UnsupportedQueryError`. -- Rozhoduje, zda je dotaz obsloužitelný z cache. +- Detekuje typ dotazu (SELECT / zápis); zápisy vyhodí `ReadOnlyError`. +- Extrahuje názvy tabulek z FROM a JOIN klauzulí (podpora více tabulek). +- Mapuje požadované sloupce na tabulky přes aliasy (`columns_by_table`). +- Detekuje `SELECT *` a `alias.*` → tabulka se načte celá (`wildcard_tables`). +- Parsuje přes dialekt `SQLMEM_SQL_DIALECT` (default `tsql`) a renderuje in-memory dotaz do SQLite (stripuje catalog/schema prefixy). +- Parametry (`?` / `:name`) předává beze změny do in-memory SQLite. ### 2. Column Registry @@ -71,12 +72,12 @@ Modul se **za běhu učí**, jaké sloupce z každé tabulky aplikace potřebuje **Logika při každém příchozím dotazu:** -1. Parser detekuje `SELECT *` nebo JOIN → vyhodí `UnsupportedQueryError` (není implementováno). -2. Parser extrahuje `(tabulka, sloupce)` z dotazu. -3. Registry provede **union** nově požadovaných sloupců s již známými. -4. Cache Manager zkontroluje, zda cache pro danou tabulku obsahuje všechny potřebné sloupce: +1. Parser extrahuje `(tabulka, sloupce)` pro každou tabulku v dotazu (i přes JOIN). +2. Registry provede **union** nově požadovaných sloupců s již známými. +3. Cache Manager zkontroluje, zda cache pro danou tabulku obsahuje všechny potřebné sloupce: - **Ano** → dotaz jde přímo do SQLite RAM (cache hit). - **Ne** → re-fetch tabulky z DB s rozšířenou sadou sloupců → přepíše cache → dotaz do SQLite RAM. +4. `SELECT *` načte celou tabulku a označí ji jako `is_full` → další dotazy na libovolný sloupec jsou cache hit. **Příklad akumulace sloupců:** @@ -84,8 +85,8 @@ Modul se **za běhu učí**, jaké sloupce z každé tabulky aplikace potřebuje Dotaz 1: SELECT A, B FROM T3 → Registry: T3 = {A, B} → fetch T3(A,B) z DB Dotaz 2: SELECT A, D FROM T3 → Registry: T3 = {A, B, D} → re-fetch T3(A,B,D) z DB Dotaz 3: SELECT B FROM T3 → cache hit, žádný DB dotaz -Dotaz 4: SELECT * FROM T3 → UnsupportedQueryError (wildcard není podporován) -Dotaz 5: SELECT A FROM T3 JOIN T4 ... → UnsupportedQueryError (JOIN není podporován) +Dotaz 4: SELECT * FROM T3 → full load všech sloupců, tabulka označena is_full +Dotaz 5: SELECT A FROM T3 JOIN T4 ON … → každá tabulka cachována zvlášť, JOIN běží v RAM ``` **Metadata tabulka `_sqlmem_columns`** (uložena v SQLite): @@ -184,10 +185,16 @@ SQLMEM_DEBUG=true # DEBUG level — podrobný výpis každého dotazu, cache o --- +## Hotové funkce (dříve TODO) + +- [x] **Parametrizované dotazy**: `execute(sql, params)` — poziční `?` i pojmenované `:name`. +- [x] **Podpora `SELECT *` (wildcard)**: Načte celou tabulku do cache, označí ji jako `is_full` — další dotazy na libovolný sloupec jsou vždy cache hit bez re-fetch. +- [x] **Podpora JOIN**: Parser extrahuje sloupce z každé joinované tabulky zvlášť, Column Registry je sleduje nezávisle. Cache Manager zajistí, že všechny potřebné tabulky jsou v paměti před spuštěním dotazu. +- [x] **Třídílné názvy tabulek**: `[catalog].[schema].[table]` se cachuje pod base name, in-memory dotaz prefix stripuje. + ## TODO — budoucí funkce -- **Podpora `SELECT *` (wildcard)**: Načte celou tabulku do cache, označí ji jako `full` — další dotazy na libovolný sloupec jsou vždy cache hit bez re-fetch. -- **Podpora JOIN**: Parser extrahuje sloupce z každé joinované tabulky zvlášť, Column Registry je sleduje nezávisle. Cache Manager zajistí, že všechny potřebné tabulky jsou v paměti před spuštěním dotazu. +- **TTL na úrovni tabulky**: automatické vypršení cache po nastaveném čase. --- diff --git a/pyproject.toml b/pyproject.toml index ca54a78..f8fc535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlmem" -version = "1.1.0" +version = "1.2.0" description = "" authors = [ {name = "jan.doubravsky@gmail.com"} diff --git a/src/sqlmem/cache.py b/src/sqlmem/cache.py index 00ba0c3..7b13a7a 100644 --- a/src/sqlmem/cache.py +++ b/src/sqlmem/cache.py @@ -9,7 +9,7 @@ from loguru import logger import sqlmem._meta as _meta -SCHEMA_VERSION = 1 +SCHEMA_VERSION = 2 class CacheManager: @@ -40,7 +40,8 @@ class CacheManager: CREATE TABLE IF NOT EXISTS _sqlmem_tables ( table_name TEXT PRIMARY KEY, last_refresh_at TEXT NOT NULL, - row_count INTEGER + row_count INTEGER, + is_full INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE IF NOT EXISTS _sqlmem_columns ( table_name TEXT NOT NULL, @@ -112,17 +113,18 @@ class CacheManager: logger.info("SIGTERM received — flushing cache to disk.") self._backup_to_disk() - def mark_table_refreshed(self, table: str, row_count: int) -> None: + def mark_table_refreshed(self, table: str, row_count: int, full: bool = False) -> None: with self._lock: self._mem_conn.execute( """ - INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count) - VALUES (?, ?, ?) + INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count, is_full) + VALUES (?, ?, ?, ?) ON CONFLICT(table_name) DO UPDATE SET last_refresh_at = excluded.last_refresh_at, - row_count = excluded.row_count + row_count = excluded.row_count, + is_full = excluded.is_full """, - (table, _now(), row_count), + (table, _now(), row_count, int(full)), ) self._mem_conn.commit() @@ -132,7 +134,28 @@ class CacheManager: ).fetchone() return row is not None - def load_table(self, table: str, columns: list[str], source_conn: sqlite3.Connection) -> None: + def is_table_full(self, table: str) -> bool: + """True if the whole table (all columns) is cached — a SELECT * cache hit.""" + row = self._mem_conn.execute( + "SELECT is_full FROM _sqlmem_tables WHERE table_name = ?", (table,) + ).fetchone() + return bool(row and row[0]) + + def discover_columns(self, table: str, source_conn: sqlite3.Connection) -> list[str]: + """Return all column names of *table* from the source DB without fetching rows.""" + logger.debug(f"Discovering columns of {table!r} from source DB") + cursor = source_conn.execute(f"SELECT * FROM {table} WHERE 1 = 0") + columns = [desc[0] for desc in cursor.description] + logger.debug(f"{table!r} has columns: {columns}") + return columns + + def load_table( + self, + table: str, + columns: list[str], + source_conn: sqlite3.Connection, + full: bool = False, + ) -> None: cols = ", ".join(columns) logger.info(f"Fetching {table!r} columns [{cols}] from source DB") rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall() @@ -145,7 +168,7 @@ class CacheManager: self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows) self._mem_conn.commit() - self.mark_table_refreshed(table, len(rows)) + self.mark_table_refreshed(table, len(rows), full) logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})") def close(self) -> None: diff --git a/src/sqlmem/config.py b/src/sqlmem/config.py index 25fc91f..625bde4 100644 --- a/src/sqlmem/config.py +++ b/src/sqlmem/config.py @@ -9,6 +9,9 @@ load_dotenv() DEBUG = os.getenv("SQLMEM_DEBUG", "false").lower() == "true" CACHE_DB_PATH = Path(os.getenv("SQLMEM_CACHE_DB", "cache.db")) BACKUP_INTERVAL_SECONDS = int(os.getenv("SQLMEM_BACKUP_INTERVAL", "3600")) +# Dialect used by sqlglot to parse incoming SQL. Defaults to T-SQL (SQL Server), +# which also accepts ANSI SQL. In-memory queries are always rendered to SQLite. +SQL_DIALECT = os.getenv("SQLMEM_SQL_DIALECT", "tsql") # Silent by default — callers opt in via add_sink(). logger.disable("sqlmem") diff --git a/src/sqlmem/engine.py b/src/sqlmem/engine.py index d1c813c..daacfbc 100644 --- a/src/sqlmem/engine.py +++ b/src/sqlmem/engine.py @@ -1,4 +1,5 @@ import sqlite3 +from typing import cast from loguru import logger from sqlalchemy.engine import Engine @@ -6,7 +7,7 @@ from sqlalchemy.engine import Engine from .cache import CacheManager from .config import BACKUP_INTERVAL_SECONDS, CACHE_DB_PATH from .executor import QueryExecutor -from .parser import parse +from .parser import Params, parse from .registry import ColumnRegistry from .stats import Stats, StatsCollector @@ -25,10 +26,10 @@ class CachingEngine: def stats(self) -> Stats: return self._stats.snapshot(self._cache.connection) - def execute(self, sql: str) -> list[dict]: - parsed = parse(sql) + def execute(self, sql: str, params: Params = None) -> list[dict]: + parsed = parse(sql, params) with self._source_engine.connect() as sa_conn: - raw_conn: sqlite3.Connection = sa_conn.connection.dbapi_connection + raw_conn = cast(sqlite3.Connection, sa_conn.connection.dbapi_connection) executor = QueryExecutor(self._cache, self._registry, raw_conn, self._stats) return executor.execute(parsed) diff --git a/src/sqlmem/executor.py b/src/sqlmem/executor.py index 7f08b2e..b0e03a5 100644 --- a/src/sqlmem/executor.py +++ b/src/sqlmem/executor.py @@ -22,33 +22,63 @@ class QueryExecutor: self._stats = stats def execute(self, parsed: ParsedQuery) -> list[dict]: - table = parsed.table - columns = parsed.columns + for table in parsed.tables: + self._ensure_table(table, parsed) + return self._run_in_memory(parsed) + def _ensure_table(self, table: str, parsed: ParsedQuery) -> None: + if table in parsed.wildcard_tables: + self._ensure_full(table) + else: + self._ensure_columns(table, parsed.columns_by_table[table]) + + def _ensure_full(self, table: str) -> None: + """Load every column of *table* (SELECT * / t.*), refetching unless already full.""" + if self._cache.is_table_cached(table) and self._cache.is_table_full(table): + logger.debug(f"Cache hit (full): {table!r}") + self._stats.record_hit() + return + + if self._cache.is_table_cached(table): + logger.warning(f"Re-fetching {table!r} in full — SELECT * requested.") + self._stats.record_refetch() + else: + self._stats.record_miss() + + columns = self._cache.discover_columns(table, self._source_conn) + self._cache.load_table(table, columns, self._source_conn, full=True) + self._registry.update(table, columns) + + def _ensure_columns(self, table: str, columns: list[str]) -> None: + """Load *table* with at least *columns*, refetching only when columns are missing.""" missing = self._registry.needs_refetch(table, columns) table_cached = self._cache.is_table_cached(table) - if missing or not table_cached: - if table_cached and missing: - logger.warning( - f"Re-fetching {table!r} — new columns requested: {missing}. " - f"Expanding cache from {self._registry.get_columns(table)} + {missing}" - ) - self._stats.record_refetch() - else: - self._stats.record_miss() - all_columns = list(self._registry.get_columns(table)) + missing - self._cache.load_table(table, all_columns, self._source_conn) - self._registry.update(table, all_columns) - else: + if not missing and table_cached: logger.debug(f"Cache hit: {table!r} columns={columns}") self._stats.record_hit() + return - return self._run_in_memory(parsed) + if table_cached and missing: + logger.warning( + f"Re-fetching {table!r} — new columns requested: {missing}. " + f"Expanding cache from {self._registry.get_columns(table)} + {missing}" + ) + self._stats.record_refetch() + else: + self._stats.record_miss() + + all_columns = list(self._registry.get_columns(table)) + missing + self._cache.load_table(table, all_columns, self._source_conn) + self._registry.update(table, all_columns) def _run_in_memory(self, parsed: ParsedQuery) -> list[dict]: - logger.debug(f"Executing in SQLite RAM: {parsed.original_sql!r}") - cursor = self._cache.connection.execute(parsed.original_sql) + logger.debug(f"Executing in SQLite RAM: {parsed.sqlite_sql!r} params={parsed.params!r}") + conn = self._cache.connection + if parsed.params is None: + cursor = conn.execute(parsed.sqlite_sql) + else: + cursor = conn.execute(parsed.sqlite_sql, parsed.params) col_names = [desc[0] for desc in cursor.description] rows = cursor.fetchall() return [dict(zip(col_names, row)) for row in rows] diff --git a/src/sqlmem/parser.py b/src/sqlmem/parser.py index 5066fcd..475d826 100644 --- a/src/sqlmem/parser.py +++ b/src/sqlmem/parser.py @@ -1,25 +1,34 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field import sqlglot import sqlglot.expressions as exp from loguru import logger +from .config import SQL_DIALECT from .exceptions import ReadOnlyError, UnsupportedQueryError WRITE_TYPES = (exp.Insert, exp.Update, exp.Delete) +SQLITE_DIALECT = "sqlite" + +# Parameters accepted by execute(): positional (tuple/list of ``?``) or named (dict of ``:name``). +Params = tuple | list | dict | None @dataclass class ParsedQuery: - table: str - columns: list[str] + tables: list[str] + columns_by_table: dict[str, list[str]] + sqlite_sql: str original_sql: str + params: Params = None + # Tables that must be loaded in full (SELECT * / t.* / referenced without explicit columns). + wildcard_tables: set[str] = field(default_factory=set) -def parse(sql: str) -> ParsedQuery: +def parse(sql: str, params: Params = None) -> ParsedQuery: logger.debug(f"Parsing SQL: {sql!r}") - statement = sqlglot.parse_one(sql) + statement = sqlglot.parse_one(sql, dialect=SQL_DIALECT) if isinstance(statement, WRITE_TYPES): raise ReadOnlyError( @@ -29,47 +38,104 @@ def parse(sql: str) -> ParsedQuery: if not isinstance(statement, exp.Select): raise UnsupportedQueryError(f"Only SELECT statements are supported, got: {sql!r}") - _check_joins(statement) - _check_wildcard(statement) + tables, alias_map = _extract_tables(statement) + if not tables: + raise UnsupportedQueryError("SELECT without FROM is not supported.") - table = _extract_table(statement) - columns = _extract_columns(statement) + wildcard_tables = _extract_wildcards(statement, tables, alias_map) + columns_by_table = _extract_columns(statement, tables, alias_map, wildcard_tables) - logger.debug(f"Parsed → table={table!r}, columns={columns}") - return ParsedQuery(table=table, columns=columns, original_sql=sql) + # A table that appears in FROM/JOIN but contributes no explicit column must + # still be present for the in-memory query — load it in full. + for table in tables: + if table not in wildcard_tables and not columns_by_table.get(table): + wildcard_tables.add(table) + columns_by_table.pop(table, None) + + sqlite_sql = _to_sqlite(statement) + + logger.debug( + f"Parsed → tables={tables}, columns={columns_by_table}, " + f"wildcard={wildcard_tables}, params={params!r}" + ) + return ParsedQuery( + tables=tables, + columns_by_table=columns_by_table, + sqlite_sql=sqlite_sql, + original_sql=sql, + params=params, + wildcard_tables=wildcard_tables, + ) -def _check_joins(statement: exp.Select) -> None: - if statement.find(exp.Join): - raise UnsupportedQueryError("JOIN is not supported yet. Use simple single-table SELECT.") +def _extract_tables(statement: exp.Select) -> tuple[list[str], dict[str, str]]: + """Return real table names (first-seen order) and an alias→real-name map.""" + real_names: list[str] = [] + alias_map: dict[str, str] = {} + for table in statement.find_all(exp.Table): + name = table.name + if name not in real_names: + real_names.append(name) + alias_map[name] = name + if table.alias: + alias_map[table.alias] = name + return real_names, alias_map -def _check_wildcard(statement: exp.Select) -> None: +def _extract_wildcards( + statement: exp.Select, tables: list[str], alias_map: dict[str, str] +) -> set[str]: + """Detect ``SELECT *`` (all tables) and ``alias.*`` (one table) in the projection.""" + wildcard: set[str] = set() + for projection in statement.expressions: + if isinstance(projection, exp.Star): + return set(tables) + if isinstance(projection, exp.Column) and isinstance(projection.this, exp.Star): + qualifier = projection.table + wildcard.add(alias_map.get(qualifier, qualifier)) + return wildcard + + +def _extract_columns( + statement: exp.Select, + tables: list[str], + alias_map: dict[str, str], + wildcard_tables: set[str], +) -> dict[str, list[str]]: + """Map each table to the deduplicated columns referenced anywhere in the query.""" + single = tables[0] if len(tables) == 1 else None + columns: dict[str, list[str]] = {} + seen: dict[str, set[str]] = {} + for col in statement.find_all(exp.Column): if isinstance(col.this, exp.Star): - raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.") - if statement.find(exp.Star): - raise UnsupportedQueryError("SELECT * is not supported yet. Specify columns explicitly.") + continue + qualifier = col.table + if qualifier: + table = alias_map.get(qualifier, qualifier) + elif single is not None: + table = single + else: + raise UnsupportedQueryError( + f"Unqualified column {col.name!r} is ambiguous in a multi-table query; " + "qualify it with its table or alias." + ) + if table in wildcard_tables: + continue + bucket = seen.setdefault(table, set()) + if col.name not in bucket: + bucket.add(col.name) + columns.setdefault(table, []).append(col.name) - -def _extract_table(statement: exp.Select) -> str: - from_clause = statement.find(exp.From) - if not from_clause: - raise UnsupportedQueryError("SELECT without FROM is not supported.") - table = from_clause.find(exp.Table) - if not table: - raise UnsupportedQueryError("Could not extract table name from query.") - return table.name - - -def _extract_columns(statement: exp.Select) -> list[str]: - seen: set[str] = set() - columns: list[str] = [] - for col in statement.find_all(exp.Column): - name = col.name - if name not in seen: - seen.add(name) - columns.append(name) - if not columns: - raise UnsupportedQueryError("Could not extract column names from query.") return columns + + +def _to_sqlite(statement: exp.Select) -> str: + """Render the statement as SQLite SQL, stripping catalog/schema prefixes. + + Mutates *statement* in place; callers must extract metadata beforehand. + """ + for table in statement.find_all(exp.Table): + table.set("db", None) + table.set("catalog", None) + return statement.sql(dialect=SQLITE_DIALECT) diff --git a/tests/test_cache.py b/tests/test_cache.py index dba7043..63dbe7d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,5 +1,4 @@ import sqlite3 -from pathlib import Path import pytest diff --git a/tests/test_config.py b/tests/test_config.py index 42cc474..2ed07db 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,5 @@ import importlib -import pytest import sqlmem.config as cfg diff --git a/tests/test_engine.py b/tests/test_engine.py index f90a9a7..301e94e 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,5 +1,4 @@ import sqlite3 -from pathlib import Path import pytest from sqlalchemy import create_engine @@ -215,16 +214,60 @@ def test_delete_raises_readonly(engine): engine.execute("DELETE FROM products WHERE id = '1'") -def test_join_raises_unsupported(engine): +def test_ambiguous_unqualified_join_column_raises(engine): with pytest.raises(UnsupportedQueryError): engine.execute( - "SELECT p.name, o.qty FROM products p JOIN orders o ON p.id = o.product_id" + "SELECT name FROM products p JOIN orders o ON p.id = o.product_id" ) -def test_select_star_raises_unsupported(engine): - with pytest.raises(UnsupportedQueryError): - engine.execute("SELECT * FROM products") +# --------------------------------------------------------------------------- +# R1 — parametrized queries +# --------------------------------------------------------------------------- + +def test_positional_param(engine): + rows = engine.execute("SELECT id, name FROM products WHERE id = ?", ("1",)) + assert rows == [{"id": "1", "name": "Widget"}] + + +def test_named_param(engine): + rows = engine.execute("SELECT name FROM products WHERE id = :id", {"id": "2"}) + assert rows == [{"name": "Gadget"}] + + +# --------------------------------------------------------------------------- +# R2 — JOIN support +# --------------------------------------------------------------------------- + +def test_join_two_tables(engine): + rows = engine.execute( + "SELECT p.name, o.qty FROM products p " + "JOIN orders o ON p.id = o.product_id WHERE p.id = ?", + ("1",), + ) + assert rows == [{"name": "Widget", "qty": "2"}] + + +def test_join_caches_both_tables(engine): + engine.execute( + "SELECT p.name, o.qty FROM products p JOIN orders o ON p.id = o.product_id" + ) + assert engine._cache.is_table_cached("products") is True + assert engine._cache.is_table_cached("orders") is True + + +# --------------------------------------------------------------------------- +# R3 — SELECT * +# --------------------------------------------------------------------------- + +def test_select_star_returns_all_columns(engine): + rows = engine.execute("SELECT * FROM products WHERE id = '1'") + assert rows == [{"id": "1", "name": "Widget", "price": "9.99"}] + + +def test_select_star_marks_table_full(engine): + engine.execute("SELECT * FROM products") + assert engine._cache.is_table_full("products") is True # --------------------------------------------------------------------------- diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..04f243a --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,122 @@ +import sqlite3 + +import pytest + +from sqlmem.cache import CacheManager +from sqlmem.executor import QueryExecutor +from sqlmem.parser import parse +from sqlmem.registry import ColumnRegistry +from sqlmem.stats import StatsCollector + + +@pytest.fixture +def source_conn(): + conn = sqlite3.connect(":memory:") + conn.executescript( + """ + CREATE TABLE users (id TEXT, name TEXT, status TEXT); + INSERT INTO users VALUES ('1', 'alice', 'active'), ('2', 'bob', 'inactive'); + CREATE TABLE orders (id TEXT, user_id TEXT, total TEXT, title TEXT); + INSERT INTO orders VALUES ('10', '1', '99', 'first'), ('11', '2', '5', 'second'); + """ + ) + conn.commit() + yield conn + conn.close() + + +@pytest.fixture +def executor(tmp_path, source_conn): + cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999) + registry = ColumnRegistry(cache.connection) + stats = StatsCollector() + ex = QueryExecutor(cache, registry, source_conn, stats) + yield ex + cache.close() + + +def run(executor, sql, params=None): + return executor.execute(parse(sql, params)) + + +# --- R1: parameters --------------------------------------------------------- + + +def test_param_filters_in_memory(executor): + rows = run(executor, "SELECT id, name FROM users WHERE id = ?", ("1",)) + assert rows == [{"id": "1", "name": "alice"}] + + +def test_param_no_match(executor): + rows = run(executor, "SELECT name FROM users WHERE id = ?", ("999",)) + assert rows == [] + + +def test_named_params(executor): + rows = run(executor, "SELECT name FROM users WHERE id = :id", {"id": "2"}) + assert rows == [{"name": "bob"}] + + +# --- cache hit / miss / refetch -------------------------------------------- + + +def test_cache_hit_does_not_refetch(executor): + run(executor, "SELECT name FROM users") + run(executor, "SELECT name FROM users") + assert executor._stats.hits == 1 + assert executor._stats.misses == 1 + + +def test_new_column_triggers_refetch(executor): + run(executor, "SELECT name FROM users") + run(executor, "SELECT name, status FROM users") + assert executor._stats.misses == 1 + assert executor._stats.refetches == 1 + + +# --- R2: JOINs -------------------------------------------------------------- + + +def test_join_across_two_tables(executor): + rows = run( + executor, + "SELECT u.name, o.title FROM users u " + "JOIN orders o ON o.user_id = u.id WHERE u.id = ?", + ("1",), + ) + assert rows == [{"name": "alice", "title": "first"}] + + +def test_join_caches_each_table_independently(executor): + run( + executor, + "SELECT u.name, o.title FROM users u JOIN orders o ON o.user_id = u.id", + ) + # two distinct tables loaded → two misses + assert executor._stats.misses == 2 + assert executor._cache.is_table_cached("users") + assert executor._cache.is_table_cached("orders") + + +# --- R3: SELECT * ----------------------------------------------------------- + + +def test_select_star_returns_all_columns(executor): + rows = run(executor, "SELECT * FROM users WHERE id = ?", ("1",)) + assert rows == [{"id": "1", "name": "alice", "status": "active"}] + + +def test_select_star_marks_table_full_and_hits(executor): + run(executor, "SELECT * FROM users") + run(executor, "SELECT * FROM users") + assert executor._cache.is_table_full("users") + assert executor._stats.misses == 1 + assert executor._stats.hits == 1 + + +def test_column_query_after_star_is_a_hit(executor): + run(executor, "SELECT * FROM users") + run(executor, "SELECT name FROM users") + # full table already cached → specific column is a hit, no refetch + assert executor._stats.refetches == 0 + assert executor._stats.hits == 1 diff --git a/tests/test_parser.py b/tests/test_parser.py index 3770fc0..b6b3cfd 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -6,16 +6,22 @@ from sqlmem.parser import parse def test_simple_select(): result = parse("SELECT name, email FROM users WHERE status = 'active'") - assert result.table == "users" + assert result.tables == ["users"] + cols = result.columns_by_table["users"] # WHERE columns are also extracted — needed for in-memory SQLite filtering - assert {"name", "email"}.issubset(set(result.columns)) - assert "status" in result.columns + assert {"name", "email"}.issubset(set(cols)) + assert "status" in cols def test_multiple_columns(): result = parse("SELECT a, b, c FROM orders") - assert result.table == "orders" - assert set(result.columns) == {"a", "b", "c"} + assert result.tables == ["orders"] + assert set(result.columns_by_table["orders"]) == {"a", "b", "c"} + + +def test_columns_deduplicated_in_order(): + result = parse("SELECT a, a, b FROM t WHERE a > 1") + assert result.columns_by_table["t"] == ["a", "b"] def test_insert_raises_readonly(): @@ -33,11 +39,77 @@ def test_delete_raises_readonly(): parse("DELETE FROM users WHERE id = 1") -def test_wildcard_raises_unsupported(): +def test_select_without_from_raises(): with pytest.raises(UnsupportedQueryError): - parse("SELECT * FROM users") + parse("SELECT 1") -def test_join_raises_unsupported(): +# --- R1: parameters --------------------------------------------------------- + + +def test_params_stored(): + result = parse("SELECT name FROM users WHERE id = ?", ("7189790",)) + assert result.params == ("7189790",) + assert "?" in result.sqlite_sql + + +def test_named_params_preserved(): + result = parse("SELECT name FROM users WHERE id = :id", {"id": 1}) + assert ":id" in result.sqlite_sql + + +# --- R2: JOINs -------------------------------------------------------------- + + +def test_join_extracts_all_tables(): + result = parse( + "SELECT a.id, b.title FROM users a " + "JOIN orders b ON a.id = b.user_id WHERE a.id = ?", + (1,), + ) + assert set(result.tables) == {"users", "orders"} + assert "id" in result.columns_by_table["users"] + assert "title" in result.columns_by_table["orders"] + # join + where columns resolved to their tables via alias + assert "user_id" in result.columns_by_table["orders"] + + +def test_join_unqualified_column_is_ambiguous(): with pytest.raises(UnsupportedQueryError): - parse("SELECT a.name, b.title FROM users a JOIN orders b ON a.id = b.user_id") + parse("SELECT name FROM users a JOIN orders b ON a.id = b.user_id") + + +# --- R3: SELECT * ----------------------------------------------------------- + + +def test_wildcard_marks_table_full(): + result = parse("SELECT * FROM users") + assert result.wildcard_tables == {"users"} + assert result.columns_by_table == {} + + +def test_qualified_wildcard_marks_only_that_table(): + result = parse( + "SELECT u.*, o.total FROM users u JOIN orders o ON u.id = o.user_id" + ) + assert "users" in result.wildcard_tables + assert "orders" not in result.wildcard_tables + assert "total" in result.columns_by_table["orders"] + + +# --- R4: three-part names (MSSQL brackets) ---------------------------------- + + +def test_three_part_name_uses_base_table(): + result = parse( + "SELECT [PRODUCT_PRODUCTNR], [PRAT_NAME] " + "FROM [DP_PIM].[dbo].[VW_P_PRATVALUES] WHERE PRODUCT_PRODUCTNR = ?", + ("7189790",), + ) + assert result.tables == ["VW_P_PRATVALUES"] + cols = result.columns_by_table["VW_P_PRATVALUES"] + assert {"PRODUCT_PRODUCTNR", "PRAT_NAME"}.issubset(set(cols)) + # in-memory SQL must drop the catalog/schema prefix + assert "DP_PIM" not in result.sqlite_sql + assert "dbo" not in result.sqlite_sql + assert "VW_P_PRATVALUES" in result.sqlite_sql