Explorar o código

Add SQLDatabase.truncate_tables (#14)

Casper van der Wel hai 1 ano
pai
achega
ef76b33080
Modificáronse 2 ficheiros con 16 adicións e 3 borrados
  1. 5 0
      clean_python/sql/sql_provider.py
  2. 11 3
      integration_tests/test_sql_database.py

+ 5 - 0
clean_python/sql/sql_provider.py

@@ -6,6 +6,7 @@ from typing import AsyncIterator
 from typing import Dict
 from typing import List
 from typing import Optional
+from typing import Sequence
 
 from sqlalchemy import text
 from sqlalchemy.exc import DBAPIError
@@ -83,6 +84,10 @@ class SQLDatabase(SQLProvider):
     async def drop_database(self, name: str) -> None:
         await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))
 
+    async def truncate_tables(self, names: Sequence[str]) -> None:
+        quoted = [f'"{x}"' for x in names]
+        await self._execute_autocommit(text(f"TRUNCATE TABLE {', '.join(quoted)}"))
+
 
 class SQLTransaction(SQLProvider):
     def __init__(self, connection: AsyncConnection):

+ 11 - 3
integration_tests/test_sql_database.py

@@ -59,10 +59,10 @@ async def database(postgres_url):
 
 
 @pytest.fixture
-async def database_with_cleanup(database):
-    await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id"))
+async def database_with_cleanup(database: SQLDatabase):
+    await database.truncate_tables(["test_model"])
     yield database
-    await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id"))
+    await database.truncate_tables(["test_model"])
 
 
 @pytest.fixture
@@ -367,3 +367,11 @@ async def test_count(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
 async def test_exists(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
     actual = await sql_gateway.exists(filters)
     assert actual == expected
+
+
+async def test_truncate(database: SQLDatabase, obj):
+    gateway = TstSQLGateway(database)
+    await gateway.add(obj)
+    assert await gateway.exists([])
+    await database.truncate_tables(["test_model"])
+    assert not await gateway.exists([])