testing.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from contextlib import asynccontextmanager
  2. from typing import Any
  3. from typing import AsyncIterator
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from unittest import mock
  8. from sqlalchemy.dialects import postgresql
  9. from sqlalchemy.sql import Executable
  10. from clean_python import Json
  11. from clean_python.sql import SQLProvider
  12. __all__ = ["FakeSQLDatabase", "assert_query_equal"]
  13. class FakeSQLDatabase(SQLProvider):
  14. def __init__(self):
  15. self.queries: List[List[Executable]] = []
  16. self.result = mock.Mock(return_value=[])
  17. async def execute(
  18. self, query: Executable, _: Optional[Dict[str, Any]] = None
  19. ) -> List[Json]:
  20. self.queries.append([query])
  21. return self.result()
  22. @asynccontextmanager
  23. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  24. x = FakeSQLTransaction(result=self.result)
  25. self.queries.append(x.queries)
  26. yield x
  27. class FakeSQLTransaction(SQLProvider):
  28. def __init__(self, result: mock.Mock):
  29. self.queries: List[Executable] = []
  30. self.result = result
  31. async def execute(
  32. self, query: Executable, _: Optional[Dict[str, Any]] = None
  33. ) -> List[Json]:
  34. self.queries.append(query)
  35. return self.result()
  36. def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):
  37. """There are two ways of 'binding' parameters (for testing!):
  38. literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)
  39. literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.
  40. """
  41. assert isinstance(q, Executable)
  42. compiled = q.compile(
  43. compile_kwargs={"literal_binds": literal_binds},
  44. dialect=postgresql.dialect(),
  45. )
  46. if not literal_binds:
  47. actual = str(compiled) % compiled.params
  48. else:
  49. actual = str(compiled)
  50. actual = actual.replace("\n", "").replace(" ", " ")
  51. assert actual == expected