|
|
@@ -1,8 +1,11 @@
|
|
|
from abc import ABC
|
|
|
from abc import abstractmethod
|
|
|
from contextlib import asynccontextmanager
|
|
|
+from typing import Any
|
|
|
from typing import AsyncIterator
|
|
|
+from typing import Dict
|
|
|
from typing import List
|
|
|
+from typing import Optional
|
|
|
|
|
|
from sqlalchemy import text
|
|
|
from sqlalchemy.exc import DBAPIError
|
|
|
@@ -23,7 +26,9 @@ def is_serialization_error(e: DBAPIError) -> bool:
|
|
|
|
|
|
class SQLProvider(ABC):
|
|
|
@abstractmethod
|
|
|
- async def execute(self, query: Executable) -> List[Json]:
|
|
|
+ async def execute(
|
|
|
+ self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
+ ) -> List[Json]:
|
|
|
pass
|
|
|
|
|
|
@asynccontextmanager
|
|
|
@@ -45,9 +50,11 @@ class SQLDatabase(SQLProvider):
|
|
|
def dispose_sync(self) -> None:
|
|
|
self.engine.sync_engine.dispose()
|
|
|
|
|
|
- async def execute(self, query: Executable) -> List[Json]:
|
|
|
+ async def execute(
|
|
|
+ self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
+ ) -> List[Json]:
|
|
|
async with self.transaction() as transaction:
|
|
|
- return await transaction.execute(query)
|
|
|
+ return await transaction.execute(query, bind_params)
|
|
|
|
|
|
@asynccontextmanager
|
|
|
async def transaction(self) -> AsyncIterator[SQLProvider]:
|
|
|
@@ -81,9 +88,11 @@ class SQLTransaction(SQLProvider):
|
|
|
def __init__(self, connection: AsyncConnection):
|
|
|
self.connection = connection
|
|
|
|
|
|
- async def execute(self, query: Executable) -> List[Json]:
|
|
|
+ async def execute(
|
|
|
+ self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
+ ) -> List[Json]:
|
|
|
try:
|
|
|
- result = await self.connection.execute(query)
|
|
|
+ result = await self.connection.execute(query, bind_params)
|
|
|
except DBAPIError as e:
|
|
|
if is_serialization_error(e):
|
|
|
raise Conflict(str(e))
|