sql_provider.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
from abc import ABC
from abc import abstractmethod
from contextlib import asynccontextmanager
from typing import Any
from typing import AsyncIterator
from typing import Dict
from typing import List
from typing import Optional

from sqlalchemy import text
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import Executable

from clean_python import Conflict
from clean_python import Json
  16. __all__ = ["SQLProvider", "SQLDatabase"]
  17. def is_serialization_error(e: DBAPIError) -> bool:
  18. return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")
  19. class SQLProvider(ABC):
  20. @abstractmethod
  21. async def execute(
  22. self, query: Executable, bind_params: Dict[str, Any] = None
  23. ) -> List[Json]:
  24. pass
  25. @asynccontextmanager
  26. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  27. raise NotImplementedError()
  28. yield
  29. class SQLDatabase(SQLProvider):
  30. engine: AsyncEngine
  31. def __init__(self, url: str, **kwargs):
  32. kwargs.setdefault("isolation_level", "READ COMMITTED")
  33. self.engine = create_async_engine(url, **kwargs)
  34. async def dispose(self) -> None:
  35. await self.engine.dispose()
  36. def dispose_sync(self) -> None:
  37. self.engine.sync_engine.dispose()
  38. async def execute(
  39. self, query: Executable, bind_params: Dict[str, Any] = None
  40. ) -> List[Json]:
  41. async with self.transaction() as transaction:
  42. return await transaction.execute(query, bind_params)
  43. @asynccontextmanager
  44. async def transaction(self) -> AsyncIterator[SQLProvider]:
  45. async with self.engine.connect() as connection:
  46. async with connection.begin():
  47. yield SQLTransaction(connection)
  48. @asynccontextmanager
  49. async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
  50. async with self.engine.connect() as connection:
  51. async with connection.begin() as transaction:
  52. yield SQLTransaction(connection)
  53. await transaction.rollback()
  54. async def _execute_autocommit(self, query: Executable) -> None:
  55. engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
  56. async with engine.connect() as connection:
  57. await connection.execute(query)
  58. async def create_database(self, name: str) -> None:
  59. await self._execute_autocommit(text(f"CREATE DATABASE {name}"))
  60. async def create_extension(self, name: str) -> None:
  61. await self._execute_autocommit(text(f"CREATE EXTENSION IF NOT EXISTS {name}"))
  62. async def drop_database(self, name: str) -> None:
  63. await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))
  64. class SQLTransaction(SQLProvider):
  65. def __init__(self, connection: AsyncConnection):
  66. self.connection = connection
  67. async def execute(
  68. self, query: Executable, bind_params: Dict[str, Any] = None
  69. ) -> List[Json]:
  70. try:
  71. result = await self.connection.execute(query, bind_params)
  72. except DBAPIError as e:
  73. if is_serialization_error(e):
  74. raise Conflict(str(e))
  75. else:
  76. raise e
  77. # _asdict() is a documented method of a NamedTuple
  78. # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
  79. return [x._asdict() for x in result.fetchall()]
  80. @asynccontextmanager
  81. async def transaction(self) -> AsyncIterator[SQLProvider]:
  82. async with self.connection.begin_nested():
  83. yield self