testing.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from contextlib import asynccontextmanager
  2. from typing import AsyncIterator
  3. from typing import List
  4. from unittest import mock
  5. from sqlalchemy.dialects import postgresql
  6. from sqlalchemy.sql import Executable
  7. from clean_python import Json
  8. from clean_python.sql import SQLProvider
  9. __all__ = ["FakeSQLDatabase", "assert_query_equal"]
  10. class FakeSQLDatabase(SQLProvider):
  11. def __init__(self):
  12. self.queries: List[List[Executable]] = []
  13. self.result = mock.Mock(return_value=[])
  14. async def execute(self, query: Executable) -> List[Json]:
  15. self.queries.append([query])
  16. return self.result()
  17. @asynccontextmanager
  18. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  19. x = FakeSQLTransaction(result=self.result)
  20. self.queries.append(x.queries)
  21. yield x
  22. class FakeSQLTransaction(SQLProvider):
  23. def __init__(self, result: mock.Mock):
  24. self.queries: List[Executable] = []
  25. self.result = result
  26. async def execute(self, query: Executable) -> List[Json]:
  27. self.queries.append(query)
  28. return self.result()
  29. def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):
  30. """There are two ways of 'binding' parameters (for testing!):
  31. literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)
  32. literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.
  33. """
  34. assert isinstance(q, Executable)
  35. compiled = q.compile(
  36. compile_kwargs={"literal_binds": literal_binds},
  37. dialect=postgresql.dialect(),
  38. )
  39. if not literal_binds:
  40. actual = str(compiled) % compiled.params
  41. else:
  42. actual = str(compiled)
  43. actual = actual.replace("\n", "").replace(" ", " ")
  44. assert actual == expected