from contextlib import asynccontextmanager from typing import AsyncIterator from typing import List from unittest import mock from sqlalchemy.dialects import postgresql from sqlalchemy.sql import Executable from clean_python import Json from clean_python.sql import SQLProvider __all__ = ["FakeSQLDatabase", "assert_query_equal"] class FakeSQLDatabase(SQLProvider): def __init__(self): self.queries: List[List[Executable]] = [] self.result = mock.Mock(return_value=[]) async def execute(self, query: Executable) -> List[Json]: self.queries.append([query]) return self.result() @asynccontextmanager async def transaction(self) -> AsyncIterator["SQLProvider"]: x = FakeSQLTransaction(result=self.result) self.queries.append(x.queries) yield x class FakeSQLTransaction(SQLProvider): def __init__(self, result: mock.Mock): self.queries: List[Executable] = [] self.result = result async def execute(self, query: Executable) -> List[Json]: self.queries.append(query) return self.result() def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True): """There are two ways of 'binding' parameters (for testing!): literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range) literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes. """ assert isinstance(q, Executable) compiled = q.compile( compile_kwargs={"literal_binds": literal_binds}, dialect=postgresql.dialect(), ) if not literal_binds: actual = str(compiled) % compiled.params else: actual = str(compiled) actual = actual.replace("\n", "").replace(" ", " ") assert actual == expected