| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 | from abc import ABCfrom abc import abstractmethodfrom contextlib import asynccontextmanagerfrom typing import AsyncIteratorfrom typing import Listfrom sqlalchemy import textfrom sqlalchemy.exc import DBAPIErrorfrom sqlalchemy.ext.asyncio import AsyncConnectionfrom sqlalchemy.ext.asyncio import AsyncEnginefrom sqlalchemy.ext.asyncio import create_async_enginefrom sqlalchemy.sql import Executablefrom clean_python import Conflictfrom clean_python import Json__all__ = ["SQLProvider", "SQLDatabase"]def is_serialization_error(e: DBAPIError) -> bool:    return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")class SQLProvider(ABC):    @abstractmethod    async def execute(self, query: Executable) -> List[Json]:        pass    @asynccontextmanager    async def transaction(self) -> AsyncIterator["SQLProvider"]:        raise NotImplementedError()        yieldclass SQLDatabase(SQLProvider):    engine: AsyncEngine    def __init__(self, url: str, **kwargs):        kwargs.setdefault("isolation_level", "READ COMMITTED")        self.engine = create_async_engine(url, **kwargs)    async def dispose(self) -> None:        await self.engine.dispose()    def dispose_sync(self) -> None:        self.engine.sync_engine.dispose()    async def execute(self, query: Executable) -> List[Json]:        async with self.transaction() as transaction:            return await transaction.execute(query)    @asynccontextmanager    async def transaction(self) -> AsyncIterator[SQLProvider]:        async with self.engine.connect() as connection:            async with connection.begin():                yield SQLTransaction(connection)    @asynccontextmanager    async def testing_transaction(self) -> AsyncIterator[SQLProvider]:        async with self.engine.connect() as connection:            async with connection.begin() as transaction:                yield SQLTransaction(connection)                await transaction.rollback()    async def _execute_autocommit(self, query: Executable) -> None:        engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")        async with engine.connect() as connection:            await connection.execute(query)    async def create_database(self, name: str) -> None:        await self._execute_autocommit(text(f"CREATE DATABASE {name}"))    async def drop_database(self, name: str) -> None:        await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))class SQLTransaction(SQLProvider):    def __init__(self, connection: AsyncConnection):        self.connection = connection    async def execute(self, query: Executable) -> List[Json]:        try:            result = await self.connection.execute(query)        except DBAPIError as e:            if is_serialization_error(e):                raise Conflict(str(e))            else:                raise e        # _asdict() is a documented method of a NamedTuple        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict        return [x._asdict() for x in result.fetchall()]    @asynccontextmanager    async def transaction(self) -> AsyncIterator[SQLProvider]:        async with self.connection.begin_nested():            yield self
 |