sql_provider.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from contextlib import asynccontextmanager
  4. from typing import Any
  5. from typing import AsyncIterator
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from sqlalchemy import text
  10. from sqlalchemy.exc import DBAPIError
  11. from sqlalchemy.ext.asyncio import AsyncConnection
  12. from sqlalchemy.ext.asyncio import AsyncEngine
  13. from sqlalchemy.ext.asyncio import create_async_engine
  14. from sqlalchemy.sql import Executable
  15. from clean_python import Conflict
  16. from clean_python import Json
  17. __all__ = ["SQLProvider", "SQLDatabase"]
  18. def is_serialization_error(e: DBAPIError) -> bool:
  19. return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")
  20. class SQLProvider(ABC):
  21. @abstractmethod
  22. async def execute(
  23. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  24. ) -> List[Json]:
  25. pass
  26. @asynccontextmanager
  27. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  28. raise NotImplementedError()
  29. yield
  30. class SQLDatabase(SQLProvider):
  31. engine: AsyncEngine
  32. def __init__(self, url: str, **kwargs):
  33. kwargs.setdefault("isolation_level", "READ COMMITTED")
  34. self.engine = create_async_engine(url, **kwargs)
  35. async def dispose(self) -> None:
  36. await self.engine.dispose()
  37. def dispose_sync(self) -> None:
  38. self.engine.sync_engine.dispose()
  39. async def execute(
  40. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  41. ) -> List[Json]:
  42. async with self.transaction() as transaction:
  43. return await transaction.execute(query, bind_params)
  44. @asynccontextmanager
  45. async def transaction(self) -> AsyncIterator[SQLProvider]:
  46. async with self.engine.connect() as connection:
  47. async with connection.begin():
  48. yield SQLTransaction(connection)
  49. @asynccontextmanager
  50. async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
  51. async with self.engine.connect() as connection:
  52. async with connection.begin() as transaction:
  53. yield SQLTransaction(connection)
  54. await transaction.rollback()
  55. async def _execute_autocommit(self, query: Executable) -> None:
  56. engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
  57. async with engine.connect() as connection:
  58. await connection.execute(query)
  59. async def create_database(self, name: str) -> None:
  60. await self._execute_autocommit(text(f"CREATE DATABASE {name}"))
  61. async def create_extension(self, name: str) -> None:
  62. await self._execute_autocommit(text(f"CREATE EXTENSION IF NOT EXISTS {name}"))
  63. async def drop_database(self, name: str) -> None:
  64. await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))
  65. class SQLTransaction(SQLProvider):
  66. def __init__(self, connection: AsyncConnection):
  67. self.connection = connection
  68. async def execute(
  69. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  70. ) -> List[Json]:
  71. try:
  72. result = await self.connection.execute(query, bind_params)
  73. except DBAPIError as e:
  74. if is_serialization_error(e):
  75. raise Conflict(str(e))
  76. else:
  77. raise e
  78. # _asdict() is a documented method of a NamedTuple
  79. # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
  80. return [x._asdict() for x in result.fetchall()]
  81. @asynccontextmanager
  82. async def transaction(self) -> AsyncIterator[SQLProvider]:
  83. async with self.connection.begin_nested():
  84. yield self