Add support for query parameters, JOINs, SELECT * and three-part table names
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
import sqlmem.config as cfg
|
||||
|
||||
|
||||
+49
-6
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
@@ -215,16 +214,60 @@ def test_delete_raises_readonly(engine):
|
||||
engine.execute("DELETE FROM products WHERE id = '1'")
|
||||
|
||||
|
||||
def test_join_raises_unsupported(engine):
|
||||
def test_ambiguous_unqualified_join_column_raises(engine):
|
||||
with pytest.raises(UnsupportedQueryError):
|
||||
engine.execute(
|
||||
"SELECT p.name, o.qty FROM products p JOIN orders o ON p.id = o.product_id"
|
||||
"SELECT name FROM products p JOIN orders o ON p.id = o.product_id"
|
||||
)
|
||||
|
||||
|
||||
def test_select_star_raises_unsupported(engine):
|
||||
with pytest.raises(UnsupportedQueryError):
|
||||
engine.execute("SELECT * FROM products")
|
||||
# ---------------------------------------------------------------------------
|
||||
# R1 — parametrized queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_positional_param(engine):
|
||||
rows = engine.execute("SELECT id, name FROM products WHERE id = ?", ("1",))
|
||||
assert rows == [{"id": "1", "name": "Widget"}]
|
||||
|
||||
|
||||
def test_named_param(engine):
|
||||
rows = engine.execute("SELECT name FROM products WHERE id = :id", {"id": "2"})
|
||||
assert rows == [{"name": "Gadget"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# R2 — JOIN support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_join_two_tables(engine):
|
||||
rows = engine.execute(
|
||||
"SELECT p.name, o.qty FROM products p "
|
||||
"JOIN orders o ON p.id = o.product_id WHERE p.id = ?",
|
||||
("1",),
|
||||
)
|
||||
assert rows == [{"name": "Widget", "qty": "2"}]
|
||||
|
||||
|
||||
def test_join_caches_both_tables(engine):
|
||||
engine.execute(
|
||||
"SELECT p.name, o.qty FROM products p JOIN orders o ON p.id = o.product_id"
|
||||
)
|
||||
assert engine._cache.is_table_cached("products") is True
|
||||
assert engine._cache.is_table_cached("orders") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# R3 — SELECT *
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_select_star_returns_all_columns(engine):
|
||||
rows = engine.execute("SELECT * FROM products WHERE id = '1'")
|
||||
assert rows == [{"id": "1", "name": "Widget", "price": "9.99"}]
|
||||
|
||||
|
||||
def test_select_star_marks_table_full(engine):
|
||||
engine.execute("SELECT * FROM products")
|
||||
assert engine._cache.is_table_full("products") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
import sqlite3
|
||||
|
||||
import pytest
|
||||
|
||||
from sqlmem.cache import CacheManager
|
||||
from sqlmem.executor import QueryExecutor
|
||||
from sqlmem.parser import parse
|
||||
from sqlmem.registry import ColumnRegistry
|
||||
from sqlmem.stats import StatsCollector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def source_conn():
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE TABLE users (id TEXT, name TEXT, status TEXT);
|
||||
INSERT INTO users VALUES ('1', 'alice', 'active'), ('2', 'bob', 'inactive');
|
||||
CREATE TABLE orders (id TEXT, user_id TEXT, total TEXT, title TEXT);
|
||||
INSERT INTO orders VALUES ('10', '1', '99', 'first'), ('11', '2', '5', 'second');
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
yield conn
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def executor(tmp_path, source_conn):
|
||||
cache = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
|
||||
registry = ColumnRegistry(cache.connection)
|
||||
stats = StatsCollector()
|
||||
ex = QueryExecutor(cache, registry, source_conn, stats)
|
||||
yield ex
|
||||
cache.close()
|
||||
|
||||
|
||||
def run(executor, sql, params=None):
|
||||
return executor.execute(parse(sql, params))
|
||||
|
||||
|
||||
# --- R1: parameters ---------------------------------------------------------
|
||||
|
||||
|
||||
def test_param_filters_in_memory(executor):
|
||||
rows = run(executor, "SELECT id, name FROM users WHERE id = ?", ("1",))
|
||||
assert rows == [{"id": "1", "name": "alice"}]
|
||||
|
||||
|
||||
def test_param_no_match(executor):
|
||||
rows = run(executor, "SELECT name FROM users WHERE id = ?", ("999",))
|
||||
assert rows == []
|
||||
|
||||
|
||||
def test_named_params(executor):
|
||||
rows = run(executor, "SELECT name FROM users WHERE id = :id", {"id": "2"})
|
||||
assert rows == [{"name": "bob"}]
|
||||
|
||||
|
||||
# --- cache hit / miss / refetch --------------------------------------------
|
||||
|
||||
|
||||
def test_cache_hit_does_not_refetch(executor):
|
||||
run(executor, "SELECT name FROM users")
|
||||
run(executor, "SELECT name FROM users")
|
||||
assert executor._stats.hits == 1
|
||||
assert executor._stats.misses == 1
|
||||
|
||||
|
||||
def test_new_column_triggers_refetch(executor):
|
||||
run(executor, "SELECT name FROM users")
|
||||
run(executor, "SELECT name, status FROM users")
|
||||
assert executor._stats.misses == 1
|
||||
assert executor._stats.refetches == 1
|
||||
|
||||
|
||||
# --- R2: JOINs --------------------------------------------------------------
|
||||
|
||||
|
||||
def test_join_across_two_tables(executor):
|
||||
rows = run(
|
||||
executor,
|
||||
"SELECT u.name, o.title FROM users u "
|
||||
"JOIN orders o ON o.user_id = u.id WHERE u.id = ?",
|
||||
("1",),
|
||||
)
|
||||
assert rows == [{"name": "alice", "title": "first"}]
|
||||
|
||||
|
||||
def test_join_caches_each_table_independently(executor):
|
||||
run(
|
||||
executor,
|
||||
"SELECT u.name, o.title FROM users u JOIN orders o ON o.user_id = u.id",
|
||||
)
|
||||
# two distinct tables loaded → two misses
|
||||
assert executor._stats.misses == 2
|
||||
assert executor._cache.is_table_cached("users")
|
||||
assert executor._cache.is_table_cached("orders")
|
||||
|
||||
|
||||
# --- R3: SELECT * -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_select_star_returns_all_columns(executor):
|
||||
rows = run(executor, "SELECT * FROM users WHERE id = ?", ("1",))
|
||||
assert rows == [{"id": "1", "name": "alice", "status": "active"}]
|
||||
|
||||
|
||||
def test_select_star_marks_table_full_and_hits(executor):
|
||||
run(executor, "SELECT * FROM users")
|
||||
run(executor, "SELECT * FROM users")
|
||||
assert executor._cache.is_table_full("users")
|
||||
assert executor._stats.misses == 1
|
||||
assert executor._stats.hits == 1
|
||||
|
||||
|
||||
def test_column_query_after_star_is_a_hit(executor):
|
||||
run(executor, "SELECT * FROM users")
|
||||
run(executor, "SELECT name FROM users")
|
||||
# full table already cached → specific column is a hit, no refetch
|
||||
assert executor._stats.refetches == 0
|
||||
assert executor._stats.hits == 1
|
||||
+81
-9
@@ -6,16 +6,22 @@ from sqlmem.parser import parse
|
||||
|
||||
def test_simple_select():
|
||||
result = parse("SELECT name, email FROM users WHERE status = 'active'")
|
||||
assert result.table == "users"
|
||||
assert result.tables == ["users"]
|
||||
cols = result.columns_by_table["users"]
|
||||
# WHERE columns are also extracted — needed for in-memory SQLite filtering
|
||||
assert {"name", "email"}.issubset(set(result.columns))
|
||||
assert "status" in result.columns
|
||||
assert {"name", "email"}.issubset(set(cols))
|
||||
assert "status" in cols
|
||||
|
||||
|
||||
def test_multiple_columns():
|
||||
result = parse("SELECT a, b, c FROM orders")
|
||||
assert result.table == "orders"
|
||||
assert set(result.columns) == {"a", "b", "c"}
|
||||
assert result.tables == ["orders"]
|
||||
assert set(result.columns_by_table["orders"]) == {"a", "b", "c"}
|
||||
|
||||
|
||||
def test_columns_deduplicated_in_order():
|
||||
result = parse("SELECT a, a, b FROM t WHERE a > 1")
|
||||
assert result.columns_by_table["t"] == ["a", "b"]
|
||||
|
||||
|
||||
def test_insert_raises_readonly():
|
||||
@@ -33,11 +39,77 @@ def test_delete_raises_readonly():
|
||||
parse("DELETE FROM users WHERE id = 1")
|
||||
|
||||
|
||||
def test_wildcard_raises_unsupported():
|
||||
def test_select_without_from_raises():
|
||||
with pytest.raises(UnsupportedQueryError):
|
||||
parse("SELECT * FROM users")
|
||||
parse("SELECT 1")
|
||||
|
||||
|
||||
def test_join_raises_unsupported():
|
||||
# --- R1: parameters ---------------------------------------------------------
|
||||
|
||||
|
||||
def test_params_stored():
|
||||
result = parse("SELECT name FROM users WHERE id = ?", ("7189790",))
|
||||
assert result.params == ("7189790",)
|
||||
assert "?" in result.sqlite_sql
|
||||
|
||||
|
||||
def test_named_params_preserved():
|
||||
result = parse("SELECT name FROM users WHERE id = :id", {"id": 1})
|
||||
assert ":id" in result.sqlite_sql
|
||||
|
||||
|
||||
# --- R2: JOINs --------------------------------------------------------------
|
||||
|
||||
|
||||
def test_join_extracts_all_tables():
|
||||
result = parse(
|
||||
"SELECT a.id, b.title FROM users a "
|
||||
"JOIN orders b ON a.id = b.user_id WHERE a.id = ?",
|
||||
(1,),
|
||||
)
|
||||
assert set(result.tables) == {"users", "orders"}
|
||||
assert "id" in result.columns_by_table["users"]
|
||||
assert "title" in result.columns_by_table["orders"]
|
||||
# join + where columns resolved to their tables via alias
|
||||
assert "user_id" in result.columns_by_table["orders"]
|
||||
|
||||
|
||||
def test_join_unqualified_column_is_ambiguous():
|
||||
with pytest.raises(UnsupportedQueryError):
|
||||
parse("SELECT a.name, b.title FROM users a JOIN orders b ON a.id = b.user_id")
|
||||
parse("SELECT name FROM users a JOIN orders b ON a.id = b.user_id")
|
||||
|
||||
|
||||
# --- R3: SELECT * -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_wildcard_marks_table_full():
|
||||
result = parse("SELECT * FROM users")
|
||||
assert result.wildcard_tables == {"users"}
|
||||
assert result.columns_by_table == {}
|
||||
|
||||
|
||||
def test_qualified_wildcard_marks_only_that_table():
|
||||
result = parse(
|
||||
"SELECT u.*, o.total FROM users u JOIN orders o ON u.id = o.user_id"
|
||||
)
|
||||
assert "users" in result.wildcard_tables
|
||||
assert "orders" not in result.wildcard_tables
|
||||
assert "total" in result.columns_by_table["orders"]
|
||||
|
||||
|
||||
# --- R4: three-part names (MSSQL brackets) ----------------------------------
|
||||
|
||||
|
||||
def test_three_part_name_uses_base_table():
|
||||
result = parse(
|
||||
"SELECT [PRODUCT_PRODUCTNR], [PRAT_NAME] "
|
||||
"FROM [DP_PIM].[dbo].[VW_P_PRATVALUES] WHERE PRODUCT_PRODUCTNR = ?",
|
||||
("7189790",),
|
||||
)
|
||||
assert result.tables == ["VW_P_PRATVALUES"]
|
||||
cols = result.columns_by_table["VW_P_PRATVALUES"]
|
||||
assert {"PRODUCT_PRODUCTNR", "PRAT_NAME"}.issubset(set(cols))
|
||||
# in-memory SQL must drop the catalog/schema prefix
|
||||
assert "DP_PIM" not in result.sqlite_sql
|
||||
assert "dbo" not in result.sqlite_sql
|
||||
assert "VW_P_PRATVALUES" in result.sqlite_sql
|
||||
|
||||
Reference in New Issue
Block a user