sql_provider.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from contextlib import asynccontextmanager
  4. from typing import AsyncIterator
  5. from typing import List
  6. from sqlalchemy.exc import DBAPIError
  7. from sqlalchemy.ext.asyncio import AsyncConnection
  8. from sqlalchemy.ext.asyncio import AsyncEngine
  9. from sqlalchemy.ext.asyncio import create_async_engine
  10. from sqlalchemy.sql import Executable
  11. from clean_python import Conflict
  12. from clean_python import Json
  # Names exported by ``from ... import *``. SQLTransaction is not listed;
  # instances of it are handed out by SQLDatabase.transaction().
  __all__ = [
      "SQLProvider",
      "SQLDatabase",
  ]
  def is_serialization_error(e: "DBAPIError") -> bool:
      """Return True if *e* wraps a database serialization failure.

      Two signals are recognised on the wrapped driver exception ``e.orig``:

      * a SQLSTATE of ``"40001"`` exposed as ``sqlstate`` or ``pgcode``; and
      * a first argument that is a string starting with
        ``"<class 'asyncpg.exceptions.SerializationError'>"``.

      A missing ``orig``, empty ``args`` or a non-string first argument
      yields False rather than raising. This helper is called from inside
      an ``except`` block, and raising there would mask the original error.
      """
      orig = getattr(e, "orig", None)
      if orig is None:
          return False
      # SQLSTATE 40001 is PostgreSQL's ``serialization_failure``; drivers
      # that expose the code can be matched without parsing messages.
      if any(getattr(orig, attr, None) == "40001" for attr in ("sqlstate", "pgcode")):
          return True
      args = getattr(orig, "args", ())
      if not args or not isinstance(args[0], str):
          return False
      return args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")
  class SQLProvider(ABC):
      """Abstract interface for an object that can run SQL statements.

      ``execute`` returns the result as a list of ``Json`` rows.
      ``transaction`` is an async context manager yielding a provider whose
      statements run inside that transaction scope.
      """

      @abstractmethod
      async def execute(self, query: Executable) -> List[Json]:
          """Run *query* and return the resulting rows."""
          ...

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator["SQLProvider"]:
          """Open a transaction scope and yield a provider bound to it.

          The base implementation raises NotImplementedError as soon as the
          context is entered; subclasses override it.
          """
          raise NotImplementedError()
          # Never reached, but the ``yield`` turns this function into an async
          # generator, which is what ``asynccontextmanager`` expects to wrap.
          yield
  class SQLDatabase(SQLProvider):
      """SQLProvider backed by an SQLAlchemy ``AsyncEngine``.

      Every ``execute`` call runs in its own transaction. Use
      ``transaction()`` to run several statements in one transaction.
      """

      engine: AsyncEngine

      def __init__(self, url: str, **kwargs):
          # Extra keyword arguments are passed to create_async_engine.
          # isolation_level falls back to READ COMMITTED unless the caller
          # supplies one.
          engine_options = {"isolation_level": "READ COMMITTED", **kwargs}
          self.engine = create_async_engine(url, **engine_options)

      async def dispose(self) -> None:
          """Dispose of the engine from async code."""
          await self.engine.dispose()

      def dispose_sync(self) -> None:
          """Dispose of the engine from synchronous code via ``sync_engine``."""
          self.engine.sync_engine.dispose()

      async def execute(self, query: Executable) -> List[Json]:
          """Run *query* in a fresh transaction and return its rows."""
          async with self.transaction() as tx:
              rows = await tx.execute(query)
          return rows

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator[SQLProvider]:
          """Connect, begin a transaction and yield a SQLTransaction for it.

          The transaction and the connection are both closed through their
          ``async with`` scopes when the block exits.
          """
          async with self.engine.connect() as connection, connection.begin():
              yield SQLTransaction(connection)
  class SQLTransaction(SQLProvider):
      """SQLProvider bound to one connection that has an open transaction.

      ``SQLDatabase.transaction()`` yields instances of this class. Every
      statement runs on ``self.connection`` and therefore inside that
      transaction.
      """

      def __init__(self, connection: AsyncConnection):
          self.connection = connection

      async def execute(self, query: Executable) -> List[Json]:
          """Run *query* on this transaction's connection.

          Returns:
              One dict per result row (built with ``Row._asdict()``). For
              statements that produce no rows, such as an UPDATE or DELETE
              without RETURNING, an empty list is returned.

          Raises:
              Conflict: if the driver error is a serialization failure.
              DBAPIError: any other driver error, re-raised unchanged.
          """
          try:
              result = await self.connection.execute(query)
          except DBAPIError as e:
              if is_serialization_error(e):
                  # Chain the cause so the driver error stays in tracebacks.
                  raise Conflict(str(e)) from e
              raise
          if not result.returns_rows:
              # fetchall() raises ResourceClosedError on row-less results;
              # there is nothing to return in that case.
              return []
          # _asdict() is a documented method of a NamedTuple
          # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
          return [row._asdict() for row in result.fetchall()]

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator[SQLProvider]:
          """Open a nested transaction (SAVEPOINT) and yield this provider.

          The savepoint scope is managed by ``begin_nested()``'s
          ``async with`` block; statements still run on the same connection.
          """
          async with self.connection.begin_nested():
              yield self