Add support for query parameters, JOINs, SELECT * and three-part table names
This commit is contained in:
+32
-9
@@ -9,7 +9,7 @@ from loguru import logger
|
||||
|
||||
import sqlmem._meta as _meta
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
SCHEMA_VERSION = 2
|
||||
|
||||
|
||||
class CacheManager:
|
||||
@@ -40,7 +40,8 @@ class CacheManager:
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_tables (
|
||||
table_name TEXT PRIMARY KEY,
|
||||
last_refresh_at TEXT NOT NULL,
|
||||
row_count INTEGER
|
||||
row_count INTEGER,
|
||||
is_full INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS _sqlmem_columns (
|
||||
table_name TEXT NOT NULL,
|
||||
@@ -112,17 +113,18 @@ class CacheManager:
|
||||
logger.info("SIGTERM received — flushing cache to disk.")
|
||||
self._backup_to_disk()
|
||||
|
||||
def mark_table_refreshed(self, table: str, row_count: int) -> None:
|
||||
def mark_table_refreshed(self, table: str, row_count: int, full: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._mem_conn.execute(
|
||||
"""
|
||||
INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count)
|
||||
VALUES (?, ?, ?)
|
||||
INSERT INTO _sqlmem_tables (table_name, last_refresh_at, row_count, is_full)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(table_name) DO UPDATE SET
|
||||
last_refresh_at = excluded.last_refresh_at,
|
||||
row_count = excluded.row_count
|
||||
row_count = excluded.row_count,
|
||||
is_full = excluded.is_full
|
||||
""",
|
||||
(table, _now(), row_count),
|
||||
(table, _now(), row_count, int(full)),
|
||||
)
|
||||
self._mem_conn.commit()
|
||||
|
||||
@@ -132,7 +134,28 @@ class CacheManager:
|
||||
).fetchone()
|
||||
return row is not None
|
||||
|
||||
def load_table(self, table: str, columns: list[str], source_conn: sqlite3.Connection) -> None:
|
||||
def is_table_full(self, table: str) -> bool:
|
||||
"""True if the whole table (all columns) is cached — a SELECT * cache hit."""
|
||||
row = self._mem_conn.execute(
|
||||
"SELECT is_full FROM _sqlmem_tables WHERE table_name = ?", (table,)
|
||||
).fetchone()
|
||||
return bool(row and row[0])
|
||||
|
||||
def discover_columns(self, table: str, source_conn: sqlite3.Connection) -> list[str]:
|
||||
"""Return all column names of *table* from the source DB without fetching rows."""
|
||||
logger.debug(f"Discovering columns of {table!r} from source DB")
|
||||
cursor = source_conn.execute(f"SELECT * FROM {table} WHERE 1 = 0")
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
logger.debug(f"{table!r} has columns: {columns}")
|
||||
return columns
|
||||
|
||||
def load_table(
|
||||
self,
|
||||
table: str,
|
||||
columns: list[str],
|
||||
source_conn: sqlite3.Connection,
|
||||
full: bool = False,
|
||||
) -> None:
|
||||
cols = ", ".join(columns)
|
||||
logger.info(f"Fetching {table!r} columns [{cols}] from source DB")
|
||||
rows = source_conn.execute(f"SELECT {cols} FROM {table}").fetchall()
|
||||
@@ -145,7 +168,7 @@ class CacheManager:
|
||||
self._mem_conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
|
||||
self._mem_conn.commit()
|
||||
|
||||
self.mark_table_refreshed(table, len(rows))
|
||||
self.mark_table_refreshed(table, len(rows), full)
|
||||
logger.info(f"Table {table!r} cached ({len(rows)} rows, columns: {columns})")
|
||||
|
||||
def close(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user