sql_provider.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from contextlib import asynccontextmanager
  4. from typing import AsyncIterator
  5. from typing import List
  6. from sqlalchemy import text
  7. from sqlalchemy.exc import DBAPIError
  8. from sqlalchemy.ext.asyncio import AsyncConnection
  9. from sqlalchemy.ext.asyncio import AsyncEngine
  10. from sqlalchemy.ext.asyncio import create_async_engine
  11. from sqlalchemy.sql import Executable
  12. from clean_python import Conflict
  13. from clean_python import Json
  # Names exported by ``from sql_provider import *``. SQLTransaction is not
  # listed; instances of it are handed out by SQLDatabase.transaction() and
  # SQLDatabase.testing_transaction() (presumably intentional — confirm).
  __all__ = [
      "SQLProvider",
      "SQLDatabase",
  ]
  def is_serialization_error(e: DBAPIError) -> bool:
      """Return True if *e* wraps an asyncpg ``SerializationError``.

      SQLAlchemy's asyncpg adapter stores the original driver exception in
      ``e.orig``; its first argument is a string that starts with the class
      repr of the asyncpg exception. That prefix is what is matched here.

      The check is defensive: if ``orig`` is missing, has no ``args``, or its
      first argument is not a string (other drivers, e.g. pymysql, put an int
      error code there), this returns False. It does not raise, so callers
      that use it inside an ``except`` handler do not mask the real error.

      Args:
          e: the DBAPIError raised by SQLAlchemy.

      Returns:
          True for a serialization failure reported via asyncpg, else False.
      """
      orig = getattr(e, "orig", None)
      args = getattr(orig, "args", None) or ()
      if not args:
          return False
      first = args[0]
      return isinstance(first, str) and first.startswith(
          "<class 'asyncpg.exceptions.SerializationError'>"
      )
  class SQLProvider(ABC):
      """Abstract interface for an object that executes SQLAlchemy statements.

      Subclasses must implement :meth:`execute`. The base :meth:`transaction`
      raises NotImplementedError; subclasses that support transactions
      override it.
      """

      @abstractmethod
      async def execute(self, query: Executable) -> List[Json]:
          """Execute *query* and return the result rows as a list of ``Json``."""

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator["SQLProvider"]:
          """Open a transaction scope that yields a provider bound to it.

          The base implementation raises NotImplementedError when the context
          is entered.
          """
          raise NotImplementedError()
          # Never reached: the ``yield`` only exists so this function is an
          # async generator, which ``asynccontextmanager`` requires.
          yield
  class SQLDatabase(SQLProvider):
      """SQLProvider backed by an SQLAlchemy ``AsyncEngine``.

      Each :meth:`execute` call checks out a connection from the engine's pool
      and runs the statement in its own transaction. Unless the caller passes
      ``isolation_level`` explicitly, the engine uses ``READ COMMITTED``.
      """

      engine: AsyncEngine

      def __init__(self, url: str, **kwargs):
          """Create the underlying async engine.

          Args:
              url: database URL passed to ``create_async_engine``.
              **kwargs: extra keyword arguments for ``create_async_engine``;
                  ``isolation_level`` defaults to ``"READ COMMITTED"``.
          """
          kwargs.setdefault("isolation_level", "READ COMMITTED")
          self.engine = create_async_engine(url, **kwargs)

      async def dispose(self) -> None:
          """Dispose of the engine's connection pool (from async code)."""
          await self.engine.dispose()

      def dispose_sync(self) -> None:
          """Dispose of the engine's connection pool from synchronous code."""
          self.engine.sync_engine.dispose()

      async def execute(self, query: Executable) -> List[Json]:
          """Run *query* in a fresh transaction and return its rows."""
          async with self.transaction() as transaction:
              return await transaction.execute(query)

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator[SQLProvider]:
          """Check out a connection, begin a transaction and yield a provider.

          Uses ``connection.begin()``: the transaction commits when the block
          exits normally and rolls back when it raises.
          """
          async with self.engine.connect() as connection:
              async with connection.begin():
                  yield SQLTransaction(connection)

      @asynccontextmanager
      async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
          """Like :meth:`transaction`, but always roll back on normal exit.

          Anything written inside the block is discarded, which makes it
          suitable for test isolation.
          """
          async with self.engine.connect() as connection:
              async with connection.begin() as transaction:
                  yield SQLTransaction(connection)
                  await transaction.rollback()

      async def _execute_autocommit(self, query: Executable) -> None:
          """Execute *query* on a short-lived engine in AUTOCOMMIT mode.

          Some statements, e.g. ``CREATE DATABASE`` on PostgreSQL, cannot run
          inside a transaction block. A separate engine is used so the pool of
          ``self.engine`` keeps its configured isolation level. That engine is
          always disposed afterwards; the previous version never disposed it
          and leaked its pooled connection on every call.
          """
          engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
          try:
              async with engine.connect() as connection:
                  await connection.execute(query)
          finally:
              await engine.dispose()

      async def create_database(self, name: str) -> None:
          """Create database *name*.

          The name is quoted with the dialect's identifier preparer. Plain
          lower-case names are emitted unchanged. Names containing special
          characters, reserved words or upper-case letters get quoted, instead
          of producing broken (or injectable) SQL.
          """
          quoted = self.engine.dialect.identifier_preparer.quote(name)
          await self._execute_autocommit(text(f"CREATE DATABASE {quoted}"))

      async def drop_database(self, name: str) -> None:
          """Drop database *name* if it exists (quoted like create_database)."""
          quoted = self.engine.dialect.identifier_preparer.quote(name)
          await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {quoted}"))
  class SQLTransaction(SQLProvider):
      """SQLProvider bound to one connection that already has a transaction open."""

      def __init__(self, connection: AsyncConnection):
          """Wrap *connection*; the caller owns its lifetime and outer transaction."""
          self.connection = connection

      async def execute(self, query: Executable) -> List[Json]:
          """Execute *query* on the bound connection and return rows as dicts.

          Statements that produce no rows (e.g. UPDATE/DELETE without
          RETURNING, or DDL) return an empty list. Previously
          ``fetchall()`` raised ``ResourceClosedError`` for them.

          Raises:
              Conflict: the driver reported a serialization failure
                  (see ``is_serialization_error``).
              DBAPIError: any other driver error, re-raised unchanged.
          """
          try:
              result = await self.connection.execute(query)
          except DBAPIError as e:
              if is_serialization_error(e):
                  # Chain so the original driver error stays in the traceback.
                  raise Conflict(str(e)) from e
              raise
          if not result.returns_rows:
              return []
          # _asdict() is a documented method of a NamedTuple
          # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
          return [x._asdict() for x in result.fetchall()]

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator[SQLProvider]:
          """Open a nested transaction (SAVEPOINT) on the same connection.

          Yields this same provider. On error, only the work done inside the
          nested block is rolled back.
          """
          async with self.connection.begin_nested():
              yield self