|
@@ -3,19 +3,17 @@ from abc import abstractmethod
|
|
from contextlib import asynccontextmanager
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncIterator
|
|
from typing import AsyncIterator
|
|
from typing import List
|
|
from typing import List
|
|
-from unittest import mock
|
|
|
|
|
|
|
|
-from sqlalchemy.dialects import postgresql
|
|
|
|
from sqlalchemy.exc import DBAPIError
|
|
from sqlalchemy.exc import DBAPIError
|
|
from sqlalchemy.ext.asyncio import AsyncConnection
|
|
from sqlalchemy.ext.asyncio import AsyncConnection
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlalchemy.sql import Executable
|
|
from sqlalchemy.sql import Executable
|
|
|
|
|
|
-from clean_python.base.domain.exceptions import Conflict
|
|
|
|
-from clean_python.base.infrastructure.gateway import Json
|
|
|
|
|
|
+from clean_python import Conflict
|
|
|
|
+from clean_python import Json
|
|
|
|
|
|
-__all__ = ["SQLProvider", "SQLDatabase", "FakeSQLDatabase", "assert_query_equal"]
|
|
|
|
|
|
+__all__ = ["SQLProvider", "SQLDatabase"]
|
|
|
|
|
|
|
|
|
|
def is_serialization_error(e: DBAPIError) -> bool:
|
|
def is_serialization_error(e: DBAPIError) -> bool:
|
|
@@ -56,13 +54,6 @@ class SQLDatabase(SQLProvider):
|
|
async with connection.begin():
|
|
async with connection.begin():
|
|
yield SQLTransaction(connection)
|
|
yield SQLTransaction(connection)
|
|
|
|
|
|
- @asynccontextmanager
|
|
|
|
- async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
|
|
|
|
- async with self.engine.connect() as connection:
|
|
|
|
- async with connection.begin() as transaction:
|
|
|
|
- yield SQLTestTransaction(connection)
|
|
|
|
- await transaction.rollback()
|
|
|
|
-
|
|
|
|
|
|
|
|
class SQLTransaction(SQLProvider):
|
|
class SQLTransaction(SQLProvider):
|
|
def __init__(self, connection: AsyncConnection):
|
|
def __init__(self, connection: AsyncConnection):
|
|
@@ -80,54 +71,7 @@ class SQLTransaction(SQLProvider):
|
|
# https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
|
|
# https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
|
|
return [x._asdict() for x in result.fetchall()]
|
|
return [x._asdict() for x in result.fetchall()]
|
|
|
|
|
|
-
|
|
|
|
-class SQLTestTransaction(SQLTransaction):
|
|
|
|
@asynccontextmanager
|
|
@asynccontextmanager
|
|
async def transaction(self) -> AsyncIterator[SQLProvider]:
|
|
async def transaction(self) -> AsyncIterator[SQLProvider]:
|
|
async with self.connection.begin_nested():
|
|
async with self.connection.begin_nested():
|
|
yield self
|
|
yield self
|
|
-
|
|
|
|
-
|
|
|
|
-class FakeSQLDatabase(SQLProvider):
|
|
|
|
- def __init__(self):
|
|
|
|
- self.queries: List[List[Executable]] = []
|
|
|
|
- self.result = mock.Mock(return_value=[])
|
|
|
|
-
|
|
|
|
- async def execute(self, query: Executable) -> List[Json]:
|
|
|
|
- self.queries.append([query])
|
|
|
|
- return self.result()
|
|
|
|
-
|
|
|
|
- @asynccontextmanager
|
|
|
|
- async def transaction(self) -> AsyncIterator["SQLProvider"]:
|
|
|
|
- x = FakeSQLTransaction(result=self.result)
|
|
|
|
- self.queries.append(x.queries)
|
|
|
|
- yield x
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-class FakeSQLTransaction(SQLProvider):
|
|
|
|
- def __init__(self, result: mock.Mock):
|
|
|
|
- self.queries: List[Executable] = []
|
|
|
|
- self.result = result
|
|
|
|
-
|
|
|
|
- async def execute(self, query: Executable) -> List[Json]:
|
|
|
|
- self.queries.append(query)
|
|
|
|
- return self.result()
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):
|
|
|
|
- """There are two ways of 'binding' parameters (for testing!):
|
|
|
|
-
|
|
|
|
- literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)
|
|
|
|
- literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.
|
|
|
|
- """
|
|
|
|
- assert isinstance(q, Executable)
|
|
|
|
- compiled = q.compile(
|
|
|
|
- compile_kwargs={"literal_binds": literal_binds},
|
|
|
|
- dialect=postgresql.dialect(),
|
|
|
|
- )
|
|
|
|
- if not literal_binds:
|
|
|
|
- actual = str(compiled) % compiled.params
|
|
|
|
- else:
|
|
|
|
- actual = str(compiled)
|
|
|
|
- actual = actual.replace("\n", "").replace(" ", " ")
|
|
|
|
- assert actual == expected
|
|
|