| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 | from abc import ABC, abstractmethodfrom contextlib import asynccontextmanagerfrom typing import AsyncIterator, Listfrom unittest import mockfrom sqlalchemy.dialects import postgresqlfrom sqlalchemy.exc import DBAPIErrorfrom sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_enginefrom sqlalchemy.sql import Executablefrom .exceptions import Conflictfrom .gateway import Json__all__ = ["SQLProvider", "SQLDatabase", "FakeSQLDatabase", "assert_query_equal"]def is_serialization_error(e: DBAPIError) -> bool:    return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")class SQLProvider(ABC):    @abstractmethod    async def execute(self, query: Executable) -> List[Json]:        pass    @asynccontextmanager    async def transaction(self) -> AsyncIterator["SQLProvider"]:        raise NotImplementedError()        yieldclass SQLDatabase(SQLProvider):    engine: AsyncEngine    def __init__(self, url: str, **kwargs):        kwargs.setdefault("isolation_level", "READ COMMITTED")        self.engine = create_async_engine(url, **kwargs)    async def dispose(self) -> None:        await self.engine.dispose()    def dispose_sync(self) -> None:        self.engine.sync_engine.dispose()    async def execute(self, query: Executable) -> List[Json]:        async with self.transaction() as transaction:            return await transaction.execute(query)    @asynccontextmanager    async def transaction(self) -> AsyncIterator[SQLProvider]:        async with self.engine.connect() as connection:            async with connection.begin():                yield SQLTransaction(connection)    @asynccontextmanager    async def testing_transaction(self) -> AsyncIterator[SQLProvider]:        async with self.engine.connect() as connection:            async with connection.begin() as transaction:                yield SQLTestTransaction(connection)                await transaction.rollback()class SQLTransaction(SQLProvider):    def __init__(self, connection: AsyncConnection):        self.connection = connection    async def execute(self, query: Executable) -> List[Json]:        try:            result = await self.connection.execute(query)        except DBAPIError as e:            if is_serialization_error(e):                raise Conflict(str(e))            else:                raise e        # _asdict() is a documented method of a NamedTuple        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict        return [x._asdict() for x in result.fetchall()]class SQLTestTransaction(SQLTransaction):    @asynccontextmanager    async def transaction(self) -> AsyncIterator[SQLProvider]:        async with self.connection.begin_nested():            yield selfclass FakeSQLDatabase(SQLProvider):    def __init__(self):        self.queries: List[List[Executable]] = []        self.result = mock.Mock(return_value=[])    async def execute(self, query: Executable) -> List[Json]:        self.queries.append([query])        return self.result()    @asynccontextmanager    async def transaction(self) -> AsyncIterator["SQLProvider"]:        x = FakeSQLTransaction(result=self.result)        self.queries.append(x.queries)        yield xclass FakeSQLTransaction(SQLProvider):    def __init__(self, result: mock.Mock):        self.queries: List[Executable] = []        self.result = result    async def execute(self, query: Executable) -> List[Json]:        self.queries.append(query)        return self.result()def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):    """There are two ways of 'binding' parameters (for testing!):    literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)    literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.    """    assert isinstance(q, Executable)    compiled = q.compile(        compile_kwargs={"literal_binds": literal_binds},        dialect=postgresql.dialect(),    )    if not literal_binds:        actual = str(compiled) % compiled.params    else:        actual = str(compiled)    actual = actual.replace("\n", "").replace("  ", " ")    assert actual == expected
 |