sql_provider.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from contextlib import asynccontextmanager
  4. from typing import Any
  5. from typing import AsyncIterator
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from typing import Sequence
  10. from sqlalchemy import text
  11. from sqlalchemy.exc import DBAPIError
  12. from sqlalchemy.ext.asyncio import AsyncConnection
  13. from sqlalchemy.ext.asyncio import AsyncEngine
  14. from sqlalchemy.ext.asyncio import create_async_engine
  15. from sqlalchemy.sql import Executable
  16. from clean_python import Conflict
  17. from clean_python import Json
  18. __all__ = ["SQLProvider", "SQLDatabase"]
  19. def is_serialization_error(e: DBAPIError) -> bool:
  20. return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")
  21. class SQLProvider(ABC):
  22. @abstractmethod
  23. async def execute(
  24. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  25. ) -> List[Json]:
  26. pass
  27. @asynccontextmanager
  28. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  29. raise NotImplementedError()
  30. yield
  31. class SQLDatabase(SQLProvider):
  32. engine: AsyncEngine
  33. def __init__(self, url: str, **kwargs):
  34. kwargs.setdefault("isolation_level", "READ COMMITTED")
  35. self.engine = create_async_engine(url, **kwargs)
  36. async def dispose(self) -> None:
  37. await self.engine.dispose()
  38. def dispose_sync(self) -> None:
  39. self.engine.sync_engine.dispose()
  40. async def execute(
  41. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  42. ) -> List[Json]:
  43. async with self.transaction() as transaction:
  44. return await transaction.execute(query, bind_params)
  45. @asynccontextmanager
  46. async def transaction(self) -> AsyncIterator[SQLProvider]:
  47. async with self.engine.connect() as connection:
  48. async with connection.begin():
  49. yield SQLTransaction(connection)
  50. @asynccontextmanager
  51. async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
  52. async with self.engine.connect() as connection:
  53. async with connection.begin() as transaction:
  54. yield SQLTransaction(connection)
  55. await transaction.rollback()
  56. async def _execute_autocommit(self, query: Executable) -> None:
  57. engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
  58. async with engine.connect() as connection:
  59. await connection.execute(query)
  60. async def create_database(self, name: str) -> None:
  61. await self._execute_autocommit(text(f"CREATE DATABASE {name}"))
  62. async def create_extension(self, name: str) -> None:
  63. await self._execute_autocommit(text(f"CREATE EXTENSION IF NOT EXISTS {name}"))
  64. async def drop_database(self, name: str) -> None:
  65. await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))
  66. async def truncate_tables(self, names: Sequence[str]) -> None:
  67. quoted = [f'"{x}"' for x in names]
  68. await self._execute_autocommit(text(f"TRUNCATE TABLE {', '.join(quoted)}"))
  69. class SQLTransaction(SQLProvider):
  70. def __init__(self, connection: AsyncConnection):
  71. self.connection = connection
  72. async def execute(
  73. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  74. ) -> List[Json]:
  75. try:
  76. result = await self.connection.execute(query, bind_params)
  77. except DBAPIError as e:
  78. if is_serialization_error(e):
  79. raise Conflict(str(e))
  80. else:
  81. raise e
  82. # _asdict() is a documented method of a NamedTuple
  83. # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
  84. return [x._asdict() for x in result.fetchall()]
  85. @asynccontextmanager
  86. async def transaction(self) -> AsyncIterator[SQLProvider]:
  87. async with self.connection.begin_nested():
  88. yield self