sql_provider.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from contextlib import asynccontextmanager
  4. from typing import AsyncIterator
  5. from typing import List
  6. from unittest import mock
  7. from sqlalchemy.dialects import postgresql
  8. from sqlalchemy.exc import DBAPIError
  9. from sqlalchemy.ext.asyncio import AsyncConnection
  10. from sqlalchemy.ext.asyncio import AsyncEngine
  11. from sqlalchemy.ext.asyncio import create_async_engine
  12. from sqlalchemy.sql import Executable
  13. from clean_python.base.domain.exceptions import Conflict
  14. from clean_python.base.infrastructure.gateway import Json
  15. __all__ = ["SQLProvider", "SQLDatabase", "FakeSQLDatabase", "assert_query_equal"]
  16. def is_serialization_error(e: DBAPIError) -> bool:
  17. return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")
  18. class SQLProvider(ABC):
  19. @abstractmethod
  20. async def execute(self, query: Executable) -> List[Json]:
  21. pass
  22. @asynccontextmanager
  23. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  24. raise NotImplementedError()
  25. yield
  26. class SQLDatabase(SQLProvider):
  27. engine: AsyncEngine
  28. def __init__(self, url: str, **kwargs):
  29. kwargs.setdefault("isolation_level", "READ COMMITTED")
  30. self.engine = create_async_engine(url, **kwargs)
  31. async def dispose(self) -> None:
  32. await self.engine.dispose()
  33. def dispose_sync(self) -> None:
  34. self.engine.sync_engine.dispose()
  35. async def execute(self, query: Executable) -> List[Json]:
  36. async with self.transaction() as transaction:
  37. return await transaction.execute(query)
  38. @asynccontextmanager
  39. async def transaction(self) -> AsyncIterator[SQLProvider]:
  40. async with self.engine.connect() as connection:
  41. async with connection.begin():
  42. yield SQLTransaction(connection)
  43. @asynccontextmanager
  44. async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
  45. async with self.engine.connect() as connection:
  46. async with connection.begin() as transaction:
  47. yield SQLTestTransaction(connection)
  48. await transaction.rollback()
  49. class SQLTransaction(SQLProvider):
  50. def __init__(self, connection: AsyncConnection):
  51. self.connection = connection
  52. async def execute(self, query: Executable) -> List[Json]:
  53. try:
  54. result = await self.connection.execute(query)
  55. except DBAPIError as e:
  56. if is_serialization_error(e):
  57. raise Conflict(str(e))
  58. else:
  59. raise e
  60. # _asdict() is a documented method of a NamedTuple
  61. # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
  62. return [x._asdict() for x in result.fetchall()]
  63. class SQLTestTransaction(SQLTransaction):
  64. @asynccontextmanager
  65. async def transaction(self) -> AsyncIterator[SQLProvider]:
  66. async with self.connection.begin_nested():
  67. yield self
  68. class FakeSQLDatabase(SQLProvider):
  69. def __init__(self):
  70. self.queries: List[List[Executable]] = []
  71. self.result = mock.Mock(return_value=[])
  72. async def execute(self, query: Executable) -> List[Json]:
  73. self.queries.append([query])
  74. return self.result()
  75. @asynccontextmanager
  76. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  77. x = FakeSQLTransaction(result=self.result)
  78. self.queries.append(x.queries)
  79. yield x
  80. class FakeSQLTransaction(SQLProvider):
  81. def __init__(self, result: mock.Mock):
  82. self.queries: List[Executable] = []
  83. self.result = result
  84. async def execute(self, query: Executable) -> List[Json]:
  85. self.queries.append(query)
  86. return self.result()
  87. def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):
  88. """There are two ways of 'binding' parameters (for testing!):
  89. literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)
  90. literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.
  91. """
  92. assert isinstance(q, Executable)
  93. compiled = q.compile(
  94. compile_kwargs={"literal_binds": literal_binds},
  95. dialect=postgresql.dialect(),
  96. )
  97. if not literal_binds:
  98. actual = str(compiled) % compiled.params
  99. else:
  100. actual = str(compiled)
  101. actual = actual.replace("\n", "").replace(" ", " ")
  102. assert actual == expected