Browse Source

Merge pull request #12 from nens/bind_params

Added optional bind params for SQLProvider.execute
Ben van Basten 2 năm trước cách đây
mục cha
commit
f3d891a30f
3 tập tin đã thay đổi với 25 bổ sung8 xóa
  1. 2 1
      CHANGES.md
  2. 14 5
      clean_python/sql/sql_provider.py
  3. 9 2
      clean_python/sql/testing.py

+ 2 - 1
CHANGES.md

@@ -4,7 +4,8 @@
 0.4.1 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Added optional bind parameters for `execute` in `SQLProvider`,
+  `SQLDatabase` and `SQLTransaction`.
 
 
 0.4.0 (2023-08-29)

+ 14 - 5
clean_python/sql/sql_provider.py

@@ -1,8 +1,11 @@
 from abc import ABC
 from abc import abstractmethod
 from contextlib import asynccontextmanager
+from typing import Any
 from typing import AsyncIterator
+from typing import Dict
 from typing import List
+from typing import Optional
 
 from sqlalchemy import text
 from sqlalchemy.exc import DBAPIError
@@ -23,7 +26,9 @@ def is_serialization_error(e: DBAPIError) -> bool:
 
 class SQLProvider(ABC):
     @abstractmethod
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
+    ) -> List[Json]:
         pass
 
     @asynccontextmanager
@@ -45,9 +50,11 @@ class SQLDatabase(SQLProvider):
     def dispose_sync(self) -> None:
         self.engine.sync_engine.dispose()
 
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
+    ) -> List[Json]:
         async with self.transaction() as transaction:
-            return await transaction.execute(query)
+            return await transaction.execute(query, bind_params)
 
     @asynccontextmanager
     async def transaction(self) -> AsyncIterator[SQLProvider]:
@@ -81,9 +88,11 @@ class SQLTransaction(SQLProvider):
     def __init__(self, connection: AsyncConnection):
         self.connection = connection
 
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
+    ) -> List[Json]:
         try:
-            result = await self.connection.execute(query)
+            result = await self.connection.execute(query, bind_params)
         except DBAPIError as e:
             if is_serialization_error(e):
                 raise Conflict(str(e))

+ 9 - 2
clean_python/sql/testing.py

@@ -1,6 +1,9 @@
 from contextlib import asynccontextmanager
+from typing import Any
 from typing import AsyncIterator
+from typing import Dict
 from typing import List
+from typing import Optional
 from unittest import mock
 
 from sqlalchemy.dialects import postgresql
@@ -17,7 +20,9 @@ class FakeSQLDatabase(SQLProvider):
         self.queries: List[List[Executable]] = []
         self.result = mock.Mock(return_value=[])
 
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, _: Optional[Dict[str, Any]] = None
+    ) -> List[Json]:
         self.queries.append([query])
         return self.result()
 
@@ -33,7 +38,9 @@ class FakeSQLTransaction(SQLProvider):
         self.queries: List[Executable] = []
         self.result = result
 
-    async def execute(self, query: Executable) -> List[Json]:
+    async def execute(
+        self, query: Executable, _: Optional[Dict[str, Any]] = None
+    ) -> List[Json]:
         self.queries.append(query)
         return self.result()