|
|
@@ -5,6 +5,7 @@ from typing import Any
|
|
|
from typing import AsyncIterator
|
|
|
from typing import Dict
|
|
|
from typing import List
|
|
|
+from typing import Optional
|
|
|
|
|
|
from sqlalchemy import text
|
|
|
from sqlalchemy.exc import DBAPIError
|
|
|
@@ -26,7 +27,7 @@ def is_serialization_error(e: DBAPIError) -> bool:
|
|
|
class SQLProvider(ABC):
|
|
|
@abstractmethod
|
|
|
async def execute(
|
|
|
- self, query: Executable, bind_params: Dict[str, Any] = None
|
|
|
+ self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
) -> List[Json]:
|
|
|
pass
|
|
|
|
|
|
@@ -50,7 +51,7 @@ class SQLDatabase(SQLProvider):
|
|
|
self.engine.sync_engine.dispose()
|
|
|
|
|
|
async def execute(
|
|
|
- self, query: Executable, bind_params: Dict[str, Any] = None
|
|
|
+ self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
) -> List[Json]:
|
|
|
async with self.transaction() as transaction:
|
|
|
return await transaction.execute(query, bind_params)
|
|
|
@@ -88,7 +89,7 @@ class SQLTransaction(SQLProvider):
|
|
|
self.connection = connection
|
|
|
|
|
|
async def execute(
|
|
|
- self, query: Executable, bind_params: Dict[str, Any] = None
|
|
|
+ self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
) -> List[Json]:
|
|
|
try:
|
|
|
result = await self.connection.execute(query, bind_params)
|