from abc import ABC
from abc import abstractmethod
from contextlib import asynccontextmanager
from typing import AsyncIterator
from typing import List
from unittest import mock

from sqlalchemy.dialects import postgresql
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import Executable

from clean_python.base.domain.exceptions import Conflict
from clean_python.base.infrastructure.gateway import Json

__all__ = ["SQLProvider", "SQLDatabase", "FakeSQLDatabase", "assert_query_equal"]

# Text that starts the first argument of the driver-level exception when the
# underlying error is an asyncpg SerializationError.
_SERIALIZATION_ERROR_PREFIX = "<class 'asyncpg.exceptions.SerializationError'>"


def is_serialization_error(e: DBAPIError) -> bool:
    """Return True if *e* wraps an asyncpg ``SerializationError``.

    The check looks at the first argument of the original (driver-level)
    exception ``e.orig``. If that exception has no arguments, or its first
    argument is not a string, this returns False instead of raising, so that
    callers inspecting *e* inside an ``except`` block do not mask it.
    """
    args = getattr(e.orig, "args", None)
    if not args:
        return False
    first = args[0]
    return isinstance(first, str) and first.startswith(_SERIALIZATION_ERROR_PREFIX)


class SQLProvider(ABC):
    """Interface for something that can execute SQLAlchemy queries."""

    @abstractmethod
    async def execute(self, query: Executable) -> List[Json]:
        """Execute *query* and return the resulting rows as a list of dicts."""
        pass

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator["SQLProvider"]:
        """Open a transaction and yield a provider that executes within it.

        The base implementation raises NotImplementedError; providers that
        support transactions override it.
        """
        raise NotImplementedError()
        yield  # unreachable; makes this an async generator for asynccontextmanager


class SQLDatabase(SQLProvider):
    """SQLProvider backed by an SQLAlchemy ``AsyncEngine``."""

    engine: AsyncEngine

    def __init__(self, url: str, **kwargs):
        """Create the async engine for *url*.

        Extra keyword arguments are passed to ``create_async_engine``; the
        isolation level defaults to READ COMMITTED unless given explicitly.
        """
        kwargs.setdefault("isolation_level", "READ COMMITTED")
        self.engine = create_async_engine(url, **kwargs)

    async def dispose(self) -> None:
        """Dispose of the engine's connection pool."""
        await self.engine.dispose()

    def dispose_sync(self) -> None:
        """Dispose of the connection pool via the underlying sync engine."""
        self.engine.sync_engine.dispose()

    async def execute(self, query: Executable) -> List[Json]:
        """Execute *query* in its own transaction and return the rows."""
        async with self.transaction() as transaction:
            return await transaction.execute(query)

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[SQLProvider]:
        """Yield an SQLTransaction on a fresh connection.

        The ``connection.begin()`` context commits on normal exit and rolls
        back if the body raises.
        """
        async with self.engine.connect() as connection:
            async with connection.begin():
                yield SQLTransaction(connection)

    @asynccontextmanager
    async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
        """Yield an SQLTestTransaction whose work is always rolled back.

        Nested ``transaction()`` calls on the yielded provider use savepoints,
        so the whole body can be discarded at the end.
        """
        async with self.engine.connect() as connection:
            async with connection.begin() as transaction:
                yield SQLTestTransaction(connection)
                # Explicit rollback on normal exit; on an exception the
                # begin() context manager rolls back instead.
                await transaction.rollback()


class SQLTransaction(SQLProvider):
    """SQLProvider that executes queries on an already-open connection."""

    def __init__(self, connection: AsyncConnection):
        self.connection = connection

    async def execute(self, query: Executable) -> List[Json]:
        """Execute *query* on the connection and return the rows as dicts.

        Statements that do not return rows (e.g. DDL, or DML without
        RETURNING) yield an empty list.

        Raises:
            Conflict: if the database reports a serialization error.
            DBAPIError: any other database error, re-raised unchanged.
        """
        try:
            result = await self.connection.execute(query)
        except DBAPIError as e:
            if is_serialization_error(e):
                raise Conflict(str(e)) from e
            raise
        # fetchall() raises ResourceClosedError on results without rows.
        if not result.returns_rows:
            return []
        # _asdict() is a documented method of a NamedTuple
        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
        return [x._asdict() for x in result.fetchall()]


class SQLTestTransaction(SQLTransaction):
    """SQLTransaction whose nested transactions are savepoints."""

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[SQLProvider]:
        """Open a nested transaction (savepoint) and yield this provider."""
        async with self.connection.begin_nested():
            yield self


class FakeSQLDatabase(SQLProvider):
    """In-memory stand-in for SQLDatabase that records executed queries.

    ``queries`` holds one list per unit of work: a one-element list for each
    direct ``execute`` call, and for each ``transaction()`` the (shared) list
    of queries executed within it. ``result`` is a Mock whose return value is
    returned by every ``execute`` call (``[]`` by default).
    """

    def __init__(self):
        self.queries: List[List[Executable]] = []
        self.result = mock.Mock(return_value=[])

    async def execute(self, query: Executable) -> List[Json]:
        self.queries.append([query])
        return self.result()

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator["SQLProvider"]:
        x = FakeSQLTransaction(result=self.result)
        self.queries.append(x.queries)
        yield x


class FakeSQLTransaction(SQLProvider):
    """Fake transaction that records queries and returns ``result()``."""

    def __init__(self, result: mock.Mock):
        self.queries: List[Executable] = []
        self.result = result

    async def execute(self, query: Executable) -> List[Json]:
        self.queries.append(query)
        return self.result()


def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):
    """Assert that *q*, compiled for PostgreSQL, equals *expected*.

    There are two ways of 'binding' parameters (for testing!):

    literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)
    literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.

    Before comparing, newlines are removed and each pair of consecutive spaces
    is replaced by a single space (one pass).
    """
    assert isinstance(
        q, Executable
    ), f"expected an Executable, got {type(q).__name__}"
    compiled = q.compile(
        compile_kwargs={"literal_binds": literal_binds},
        dialect=postgresql.dialect(),
    )
    if literal_binds:
        actual = str(compiled)
    else:
        actual = str(compiled) % compiled.params
    actual = actual.replace("\n", "").replace("  ", " ")
    # pytest does not rewrite asserts in library modules by default, so the
    # message must carry both values for a useful failure report.
    assert actual == expected, (
        f"query mismatch:\n  actual:   {actual!r}\n  expected: {expected!r}"
    )