Browse Source

Added optional bind params for SQLProvider.execute

Ben van Basten 2 years ago
parent
commit
cb2f4e3744
2 changed files with 15 additions and 6 deletions
  1. 2 1
      CHANGES.md
  2. 13 5
      clean_python/sql/sql_provider.py

+ 2 - 1
CHANGES.md

@@ -4,7 +4,8 @@
 ## 0.3.5 (unreleased)
 ---------------------
 
-- Nothing changed yet.
+- Added optional bind parameters for `execute` in `SQLProvider`,
+  `SQLDatabase` and `SQLTransaction`.
 
 
 ## 0.3.4 (2023-08-28)

+ 13 - 5
clean_python/sql/sql_provider.py

@@ -1,7 +1,9 @@
 from abc import ABC
 from abc import abstractmethod
 from contextlib import asynccontextmanager
+from typing import Any
 from typing import AsyncIterator
+from typing import Dict
 from typing import List
 
 from sqlalchemy import text
@@ -23,7 +25,9 @@ def is_serialization_error(e: DBAPIError) -> bool:
 
 class SQLProvider(ABC):
     @abstractmethod
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, bind_params: Dict[str, Any] = None
+    ) -> List[Json]:
         pass
 
     @asynccontextmanager
@@ -45,9 +49,11 @@ class SQLDatabase(SQLProvider):
     def dispose_sync(self) -> None:
         self.engine.sync_engine.dispose()
 
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, bind_params: Dict[str, Any] = None
+    ) -> List[Json]:
         async with self.transaction() as transaction:
-            return await transaction.execute(query)
+            return await transaction.execute(query, bind_params)
 
     @asynccontextmanager
     async def transaction(self) -> AsyncIterator[SQLProvider]:
@@ -81,9 +87,11 @@ class SQLTransaction(SQLProvider):
     def __init__(self, connection: AsyncConnection):
         self.connection = connection
 
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, bind_params: Dict[str, Any] = None
+    ) -> List[Json]:
         try:
-            result = await self.connection.execute(query)
+            result = await self.connection.execute(query, bind_params)
         except DBAPIError as e:
             if is_serialization_error(e):
                 raise Conflict(str(e))