12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from contextlib import asynccontextmanager
- from typing import Any
- from typing import AsyncIterator
- from typing import Dict
- from typing import List
- from typing import Optional
- from unittest import mock
- from sqlalchemy.dialects import postgresql
- from sqlalchemy.sql import Executable
- from clean_python import Json
- from clean_python.sql import SQLProvider
# Public API of this module. FakeSQLTransaction is not listed: instances of it
# are created and yielded by FakeSQLDatabase.transaction().
__all__ = [
    "FakeSQLDatabase",
    "assert_query_equal",
]
class FakeSQLDatabase(SQLProvider):
    """In-memory ``SQLProvider`` that records queries instead of executing them.

    ``queries`` holds one list per unit of work:

    - each direct :meth:`execute` call adds a single-element list;
    - each :meth:`transaction` adds the transaction's own query list, which
      collects every query executed inside it.

    ``result`` is a ``mock.Mock`` that is called once per executed query. Its
    return value (``[]`` unless reconfigured by the test) is what ``execute``
    returns, both here and inside transactions.
    """

    def __init__(self):
        self.result = mock.Mock(return_value=[])
        self.queries: List[List[Executable]] = []

    async def execute(
        self, query: Executable, _: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Record *query* as a batch of its own and return the canned result.

        The bind-parameter argument is accepted for interface compatibility
        but is ignored.
        """
        batch: List[Executable] = [query]
        self.queries.append(batch)
        return self.result()

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator["SQLProvider"]:
        """Yield a fake transaction whose queries form one entry of ``queries``.

        The transaction shares this database's ``result`` mock. Its query list
        is registered before being yielded, so queries executed inside the
        ``async with`` block appear in ``self.queries`` right away. No commit
        or rollback is simulated; recorded queries are never removed.
        """
        fake_transaction = FakeSQLTransaction(result=self.result)
        recorded = fake_transaction.queries
        self.queries.append(recorded)
        yield fake_transaction
class FakeSQLTransaction(SQLProvider):
    """Fake transaction created by ``FakeSQLDatabase.transaction``.

    Every executed query is appended to ``queries``. The ``result`` mock passed
    in at construction is called once per query, and its return value is
    handed back to the caller.
    """

    def __init__(self, result: mock.Mock):
        self.result = result
        self.queries: List[Executable] = []

    async def execute(
        self, query: Executable, _: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Record *query* and return the result of calling the shared mock.

        The bind-parameter argument is accepted but ignored.
        """
        recorded = self.queries
        recorded.append(query)
        return self.result()
def assert_query_equal(
    q: Executable, expected: str, literal_binds: bool = True
) -> None:
    """Assert that *q*, compiled for PostgreSQL, renders as the SQL text *expected*.

    There are two ways of 'binding' parameters (for testing!):

    - ``literal_binds=True``: use the built-in sqlalchemy way, which fails on
      some datatypes (Range).
    - ``literal_binds=False``: do it yourself using ``%``; there is no
      'mogrify', so don't expect quotes around the substituted values.

    The compiled SQL is normalised before comparison: newlines are removed and
    every run of spaces is collapsed to a single space, so *expected* can be
    written on one line. This normalisation also applies to whitespace inside
    rendered string literals.

    Args:
        q: the statement to compile.
        expected: the expected single-line SQL text.
        literal_binds: which binding strategy to use (see above).

    Raises:
        AssertionError: if *q* is not an ``Executable`` or the rendered SQL
            does not equal *expected*.
    """
    assert isinstance(q, Executable)
    compiled = q.compile(
        compile_kwargs={"literal_binds": literal_binds},
        dialect=postgresql.dialect(),
    )
    if not literal_binds:
        # Placeholders are filled in with Python %-formatting from the compiled
        # params (assumes the dialect's %-style paramstyle, e.g. %(name)s).
        actual = str(compiled) % compiled.params
    else:
        actual = str(compiled)
    # Drop the compiler's line breaks, then collapse every run of spaces
    # (a single replace pass would leave e.g. 3 spaces as 2).
    actual = actual.replace("\n", "")
    while "  " in actual:
        actual = actual.replace("  ", " ")
    assert actual == expected
|