"""Async SQL providers built on SQLAlchemy's asyncio extension.

``SQLDatabase`` owns an ``AsyncEngine`` and hands out ``SQLTransaction``
objects; both implement the ``SQLProvider`` interface so calling code can run
queries without caring whether it is already inside a transaction.
"""
import re
from abc import ABC
from abc import abstractmethod
from contextlib import asynccontextmanager
from typing import Any
from typing import AsyncIterator
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence

from sqlalchemy import text
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import Executable

from clean_python import AlreadyExists
from clean_python import Conflict
from clean_python import Json

__all__ = ["SQLProvider", "SQLDatabase"]


# Extracts the offending key/value pair from the DETAIL line of a
# unique-violation error message.
UNIQUE_VIOLATION_DETAIL_REGEX = re.compile(
    r"DETAIL:\s*Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
)


def maybe_raise_conflict(e: DBAPIError) -> None:
    """Raise ``Conflict`` if ``e`` wraps a serialization failure.

    Returns ``None`` (without raising) for every other error, including
    driver errors that do not expose a ``pgcode`` attribute.

    Raises:
        Conflict: when the wrapped error has pgcode ``40001``.
    """
    # https://www.postgresql.org/docs/current/errcodes-appendix.html
    # getattr: a driver error without ``pgcode`` must not mask the original
    # DBAPIError with an AttributeError.
    if getattr(e.orig, "pgcode", None) == "40001":  # serialization_failure
        raise Conflict("could not execute query due to concurrent update") from e


def maybe_raise_already_exists(e: DBAPIError) -> None:
    """Raise ``AlreadyExists`` if ``e`` wraps a unique violation.

    The key and value are parsed from the last line of the driver's error
    message; when that line does not have the expected ``DETAIL: Key (..)=(..)
    already exists`` shape, a bare ``AlreadyExists`` is raised instead.

    Raises:
        AlreadyExists: when the wrapped error has pgcode ``23505``.
    """
    # https://www.postgresql.org/docs/current/errcodes-appendix.html
    if getattr(e.orig, "pgcode", None) != "23505":  # unique_violation
        return
    # Be tolerant of errors that carry no message (or a non-string one); in
    # that case we simply cannot report which key collided.
    args = getattr(e.orig, "args", ())
    detail_line = str(args[0]).split("\n")[-1] if args else ""
    match = UNIQUE_VIOLATION_DETAIL_REGEX.match(detail_line)
    if match:
        raise AlreadyExists(key=match["key"], value=match["value"]) from e
    raise AlreadyExists() from e


class SQLProvider(ABC):
    """Interface for objects that can execute SQL and open transactions."""

    @abstractmethod
    async def execute(
        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Execute ``query`` with ``bind_params`` and return the result rows."""
        pass

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator["SQLProvider"]:
        """Open a transaction; subclasses are expected to override this."""
        raise NotImplementedError()
        # The unreachable ``yield`` keeps this function an async generator,
        # which is what ``asynccontextmanager`` requires.
        yield


class SQLDatabase(SQLProvider):
    """``SQLProvider`` that owns an ``AsyncEngine`` created from a URL."""

    engine: AsyncEngine

    def __init__(self, url: str, **kwargs):
        """Create the engine for ``url``.

        Extra keyword arguments are forwarded to ``create_async_engine``.
        ``isolation_level`` defaults to ``"REPEATABLE READ"`` unless given.
        """
        kwargs.setdefault("isolation_level", "REPEATABLE READ")
        self.engine = create_async_engine(url, **kwargs)

    async def dispose(self) -> None:
        """Dispose of the engine from async code."""
        await self.engine.dispose()

    def dispose_sync(self) -> None:
        """Dispose of the underlying sync engine from non-async code."""
        self.engine.sync_engine.dispose()

    async def execute(
        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Execute ``query`` inside its own transaction and return the rows."""
        async with self.transaction() as transaction:
            return await transaction.execute(query, bind_params)

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[SQLProvider]:
        """Yield a ``SQLTransaction`` bound to a fresh connection.

        The transaction is managed by ``connection.begin()``'s context
        manager.
        """
        async with self.engine.connect() as connection:
            async with connection.begin():
                yield SQLTransaction(connection)

    @asynccontextmanager
    async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
        """Yield a ``SQLTransaction`` whose work is rolled back afterwards."""
        async with self.engine.connect() as connection:
            async with connection.begin() as transaction:
                yield SQLTransaction(connection)
                # Explicit rollback so nothing done in the block persists.
                await transaction.rollback()

    async def _execute_autocommit(self, query: Executable) -> None:
        """Run ``query`` on a temporary engine in AUTOCOMMIT isolation."""
        engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
        try:
            async with engine.connect() as connection:
                await connection.execute(query)
        finally:
            # The engine is created per call; dispose of it so its connection
            # pool does not outlive this method.
            await engine.dispose()

    # NOTE(review): the helpers below interpolate ``name``/``names`` directly
    # into the SQL text. They must only be called with trusted identifiers.

    async def create_database(self, name: str) -> None:
        """Issue ``CREATE DATABASE`` for ``name``."""
        await self._execute_autocommit(text(f"CREATE DATABASE {name}"))

    async def create_extension(self, name: str) -> None:
        """Issue ``CREATE EXTENSION IF NOT EXISTS`` for ``name``."""
        await self._execute_autocommit(text(f"CREATE EXTENSION IF NOT EXISTS {name}"))

    async def drop_database(self, name: str) -> None:
        """Issue ``DROP DATABASE IF EXISTS`` for ``name``."""
        await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))

    async def truncate_tables(self, names: Sequence[str]) -> None:
        """Truncate all tables in ``names`` with a single statement."""
        quoted = [f'"{x}"' for x in names]
        await self._execute_autocommit(text(f"TRUNCATE TABLE {', '.join(quoted)}"))


class SQLTransaction(SQLProvider):
    """``SQLProvider`` bound to one connection that is already open."""

    def __init__(self, connection: AsyncConnection):
        self.connection = connection

    async def execute(
        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
    ) -> List[Json]:
        """Execute ``query`` on the connection and return rows as dicts.

        Raises:
            Conflict: on a serialization failure.
            AlreadyExists: on a unique violation.
            DBAPIError: any other driver error, re-raised unchanged.
        """
        try:
            result = await self.connection.execute(query, bind_params)
        except DBAPIError as e:
            maybe_raise_conflict(e)
            maybe_raise_already_exists(e)
            raise
        # _asdict() is a documented method of a NamedTuple
        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
        return [x._asdict() for x in result.fetchall()]

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[SQLProvider]:
        """Open a nested transaction on the same connection and yield self."""
        async with self.connection.begin_nested():
            yield self