from abc import ABC
from abc import abstractmethod
from contextlib import asynccontextmanager
from typing import Any
from typing import AsyncIterator
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence

from sqlalchemy import text
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import Executable

from clean_python import Conflict
from clean_python import Json

__all__ = ["SQLProvider", "SQLDatabase"]


def is_serialization_error(e: DBAPIError) -> bool:
    """Return True if *e* wraps an asyncpg ``SerializationError``.

    The check inspects the first argument of the wrapped driver exception
    (``e.orig``), which is expected to be a string starting with the repr of
    the asyncpg exception class.

    Returns False (instead of raising) when the wrapped exception is missing,
    has no arguments, or its first argument is not a string, so that the
    original database error is never masked by a failure in this helper.
    """
    args = getattr(e.orig, "args", None)
    if not args or not isinstance(args[0], str):
        return False
    return args[0].startswith(
        "<class 'asyncpg.exceptions.SerializationError'>"
    )


class SQLProvider(ABC):
    """Abstract interface for something that can execute SQL statements."""

    @abstractmethod
    async def execute(
        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Execute *query* with optional *bind_params* and return the rows."""
        pass

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator["SQLProvider"]:
        """Open a transaction; subclasses must override this."""
        raise NotImplementedError()
        # The unreachable ``yield`` turns this function into an async
        # generator, which ``asynccontextmanager`` requires.
        yield


class SQLDatabase(SQLProvider):
    """SQLProvider backed by a SQLAlchemy ``AsyncEngine``."""

    engine: AsyncEngine

    def __init__(self, url: str, **kwargs):
        """Create the engine for *url*.

        Extra keyword arguments are passed to ``create_async_engine``. The
        isolation level defaults to ``READ COMMITTED`` unless given.
        """
        kwargs.setdefault("isolation_level", "READ COMMITTED")
        self.engine = create_async_engine(url, **kwargs)

    async def dispose(self) -> None:
        """Dispose of the engine (async variant)."""
        await self.engine.dispose()

    def dispose_sync(self) -> None:
        """Dispose of the underlying sync engine (blocking variant)."""
        self.engine.sync_engine.dispose()

    async def execute(
        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Execute *query* inside its own transaction and return the rows."""
        async with self.transaction() as transaction:
            return await transaction.execute(query, bind_params)

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[SQLProvider]:
        """Yield a ``SQLTransaction`` bound to a fresh connection.

        The transaction is managed by ``connection.begin()``.
        """
        async with self.engine.connect() as connection:
            async with connection.begin():
                yield SQLTransaction(connection)

    @asynccontextmanager
    async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
        """Yield a ``SQLTransaction`` that is rolled back after the block.

        Intended for tests: once the ``async with`` body finishes, the
        transaction is explicitly rolled back instead of committed.
        """
        async with self.engine.connect() as connection:
            async with connection.begin() as transaction:
                yield SQLTransaction(connection)
                await transaction.rollback()

    async def _execute_autocommit(self, query: Executable) -> None:
        """Execute *query* on a temporary engine in AUTOCOMMIT mode.

        A separate engine is created for the statement, using the same URL as
        the main engine. It is always disposed afterwards so that its
        connection pool does not leak.
        """
        engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
        try:
            async with engine.connect() as connection:
                await connection.execute(query)
        finally:
            await engine.dispose()

    async def create_database(self, name: str) -> None:
        """Create the database *name*.

        NOTE(review): *name* is interpolated into the SQL unquoted; only pass
        trusted identifiers.
        """
        await self._execute_autocommit(text(f"CREATE DATABASE {name}"))

    async def create_extension(self, name: str) -> None:
        """Create the extension *name* if it does not exist yet.

        NOTE(review): *name* is interpolated into the SQL unquoted; only pass
        trusted identifiers.
        """
        await self._execute_autocommit(text(f"CREATE EXTENSION IF NOT EXISTS {name}"))

    async def drop_database(self, name: str) -> None:
        """Drop the database *name* if it exists.

        NOTE(review): *name* is interpolated into the SQL unquoted; only pass
        trusted identifiers.
        """
        await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))

    async def truncate_tables(self, names: Sequence[str]) -> None:
        """Truncate all tables in *names* with a single statement.

        Each name is wrapped in double quotes before being interpolated.
        """
        quoted = [f'"{x}"' for x in names]
        await self._execute_autocommit(text(f"TRUNCATE TABLE {', '.join(quoted)}"))


class SQLTransaction(SQLProvider):
    """SQLProvider that executes statements on one existing connection."""

    def __init__(self, connection: AsyncConnection):
        self.connection = connection

    async def execute(
        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Execute *query* on the connection and return the rows as dicts.

        Raises:
            Conflict: if the database reports a serialization error.
            DBAPIError: for any other driver-level error (re-raised as is).
        """
        try:
            result = await self.connection.execute(query, bind_params)
        except DBAPIError as e:
            if is_serialization_error(e):
                # Keep the driver error attached as the explicit cause.
                raise Conflict(str(e)) from e
            raise
        # _asdict() is a documented method of a NamedTuple
        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
        return [x._asdict() for x in result.fetchall()]

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[SQLProvider]:
        """Open a nested transaction (``begin_nested``) and yield *self*."""
        async with self.connection.begin_nested():
            yield self