"""Tests for sqlmem value coercion (``to_sqlite`` / ``coerce_params``) and its use by ``CacheManager``."""
import datetime
|
|
import decimal
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
from sqlmem._coerce import coerce_params, to_sqlite
|
|
from sqlmem.cache import CacheManager
|
|
|
|
|
|
class _FakeCursor:
    """Minimal cursor stand-in that hands back a fixed collection of rows."""

    def __init__(self, rows):
        """Remember ``rows`` so that ``fetchall`` can return them later."""
        self._rows = rows
        # This fake carries no column metadata.
        self.description = None

    def fetchall(self):
        """Return the rows supplied at construction, unchanged."""
        return self._rows
|
|
|
|
|
|
class FakeSource:
    """Stand-in for a pyodbc connection that returns non-sqlite-native types."""

    def __init__(self, rows):
        """Store the canned ``rows`` that every ``execute`` call will serve."""
        self._rows = rows

    def execute(self, sql, *args):
        """Return a ``_FakeCursor`` over the canned rows.

        ``sql`` and ``args`` are accepted but ignored: the same rows come
        back whatever statement is passed.
        """
        return _FakeCursor(self._rows)
|
|
|
|
|
|
@pytest.fixture
def cache(tmp_path):
    """Yield a ``CacheManager`` whose database file lives under ``tmp_path``.

    The manager is closed once the test using it has finished.
    """
    # NOTE(review): backup_interval=9999 presumably keeps periodic backups
    # from firing during a test run - confirm against CacheManager.
    c = CacheManager(db_path=tmp_path / "cache.db", backup_interval=9999)
    yield c
    c.close()
|
|
|
|
|
|
# --- to_sqlite / coerce_params unit tests -----------------------------------
|
|
|
|
|
|
def test_decimal_to_str():
    """``to_sqlite`` turns a ``Decimal`` into its string form."""
    assert to_sqlite(decimal.Decimal("9.99")) == "9.99"
|
|
|
|
|
|
def test_decimal_keeps_precision():
    """A ``Decimal`` with many significant digits comes back with every digit intact."""
    assert to_sqlite(decimal.Decimal("123456789.123456789")) == "123456789.123456789"
|
|
|
|
|
|
def test_datetime_to_iso():
    """``to_sqlite`` renders a ``datetime`` as an ISO 8601 string with a ``T`` separator."""
    assert to_sqlite(datetime.datetime(2026, 6, 1, 10, 0, 0)) == "2026-06-01T10:00:00"
|
|
|
|
|
|
def test_date_to_iso():
    """``to_sqlite`` renders a ``date`` as ``YYYY-MM-DD``."""
    assert to_sqlite(datetime.date(2026, 6, 1)) == "2026-06-01"
|
|
|
|
|
|
def test_time_to_iso():
    """``to_sqlite`` renders a ``time`` as ``HH:MM:SS``."""
    assert to_sqlite(datetime.time(10, 30, 0)) == "10:30:00"
|
|
|
|
|
|
def test_uuid_to_str():
    """``to_sqlite`` turns a ``UUID`` into its canonical string form."""
    # A fixed value keeps the test deterministic and lets a failure be
    # reproduced exactly; the expected text is spelled out as well as
    # compared against ``str(u)``.
    u = uuid.UUID("12345678-1234-5678-1234-567812345678")
    result = to_sqlite(u)
    assert result == str(u)
    assert result == "12345678-1234-5678-1234-567812345678"
|
|
|
|
|
|
def test_bytearray_to_bytes():
    """A ``bytearray`` comes back from ``to_sqlite`` equal to the matching ``bytes``."""
    assert to_sqlite(bytearray(b"abc")) == b"abc"
|
|
|
|
|
|
@pytest.mark.parametrize("value", [1, 1.5, "text", None, b"blob", True])
def test_native_values_pass_through(value):
    """Values of these built-in types come back from ``to_sqlite`` equal to the input."""
    # NOTE(review): this compares by equality only. Since ``True == 1`` in
    # Python, the ``True`` case does not pin the returned type - tighten
    # once the intended bool handling of ``to_sqlite`` is confirmed.
    assert to_sqlite(value) == value
|
|
|
|
|
|
def test_coerce_params_tuple():
    """``coerce_params`` converts each item of a tuple and returns a tuple."""
    assert coerce_params((decimal.Decimal("1.5"), "x")) == ("1.5", "x")
|
|
|
|
|
|
def test_coerce_params_dict():
    """``coerce_params`` converts the values of a dict and keeps its keys."""
    assert coerce_params({"p": decimal.Decimal("2")}) == {"p": "2"}
|
|
|
|
|
|
def test_coerce_params_none():
    """``coerce_params`` returns ``None`` when given ``None``."""
    assert coerce_params(None) is None
|
|
|
|
|
|
# --- integration: values reach the cache through coercion -------------------
|
|
|
|
|
|
def test_load_table_coerces_decimal_and_datetime(cache):
    """Rows loaded with ``Decimal`` and ``datetime`` values read back as strings."""
    rows = [("1", decimal.Decimal("9.99"), datetime.datetime(2026, 6, 1, 10, 0, 0))]
    cache.load_table("t", ["id", "price", "changed"], FakeSource(rows))

    # execute_in_memory returns a pair; only the second element (the rows)
    # is checked here.
    _, out = cache.execute_in_memory("SELECT id, price, changed FROM t")

    assert out == [("1", "9.99", "2026-06-01T10:00:00")]
|
|
|
|
|
|
def test_decimal_where_param_matches_text_value(cache):
    """A ``Decimal`` query parameter matches a price that was loaded as the string ``"9.99"``."""
    cache.load_table("t", ["price"], FakeSource([("9.99",)]))

    _, out = cache.execute_in_memory(
        "SELECT price FROM t WHERE price = ?", (decimal.Decimal("9.99"),)
    )

    assert out == [("9.99",)]
|
|
|
|
|
|
def test_upsert_rows_coerces_decimal(cache):
    """A ``Decimal`` written through ``upsert_rows`` reads back as its string form."""
    cache.load_table("t", ["id", "price"], FakeSource([("1", "0")]))
    cache.create_unique_index("t", ["id"])

    # Upsert onto the existing id "1", replacing the price loaded above.
    cache.upsert_rows("t", ["id", "price"], [("1", decimal.Decimal("12.50"))])

    _, out = cache.execute_in_memory("SELECT price FROM t WHERE id = '1'")
    assert out == [("12.50",)]
|