|
@@ -1,5 +1,6 @@
|
|
# -*- coding: utf-8 -*-
|
|
# -*- coding: utf-8 -*-
|
|
# (c) Nelen & Schuurmans
|
|
# (c) Nelen & Schuurmans
|
|
|
|
+import asyncio
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
from datetime import timezone
|
|
from datetime import timezone
|
|
from unittest import mock
|
|
from unittest import mock
|
|
@@ -14,6 +15,7 @@ from sqlalchemy import MetaData
|
|
from sqlalchemy import Table
|
|
from sqlalchemy import Table
|
|
from sqlalchemy import Text
|
|
from sqlalchemy import Text
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
+from sqlalchemy.pool import NullPool
|
|
from sqlalchemy.sql import text
|
|
from sqlalchemy.sql import text
|
|
|
|
|
|
from clean_python import AlreadyExists
|
|
from clean_python import AlreadyExists
|
|
@@ -42,12 +44,21 @@ insert_query = text(
|
|
"VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
|
|
"VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
|
|
"RETURNING id"
|
|
"RETURNING id"
|
|
)
|
|
)
|
|
|
|
+update_query = text("UPDATE test_model SET t='bar' WHERE id=:id RETURNING t")
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
@pytest.fixture(scope="session")
|
|
-async def database(postgres_url):
|
|
|
|
- dburl = f"postgresql+asyncpg://{postgres_url}"
|
|
|
|
- dbname = "cleanpython_test"
|
|
|
|
|
|
+def dburl(postgres_url):
|
|
|
|
+ return f"postgresql+asyncpg://{postgres_url}"
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@pytest.fixture(scope="session")
|
|
|
|
+def dbname():
|
|
|
|
+ return "cleanpython_test"
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@pytest.fixture(scope="session")
|
|
|
|
+async def database(dburl, dbname):
|
|
root_provider = SQLDatabase(f"{dburl}/")
|
|
root_provider = SQLDatabase(f"{dburl}/")
|
|
await root_provider.drop_database(dbname)
|
|
await root_provider.drop_database(dbname)
|
|
await root_provider.create_database(dbname)
|
|
await root_provider.create_database(dbname)
|
|
@@ -71,6 +82,12 @@ async def transaction_with_cleanup(database_with_cleanup):
|
|
yield trans
|
|
yield trans
|
|
|
|
|
|
|
|
|
|
|
|
+@pytest.fixture
|
|
|
|
+async def record_id(database_with_cleanup: SQLDatabase) -> int:
|
|
|
|
+ record = await database_with_cleanup.execute(insert_query)
|
|
|
|
+ return record[0]["id"]
|
|
|
|
+
|
|
|
|
+
|
|
async def test_execute(database_with_cleanup):
|
|
async def test_execute(database_with_cleanup):
|
|
db = database_with_cleanup
|
|
db = database_with_cleanup
|
|
await db.execute(insert_query)
|
|
await db.execute(insert_query)
|
|
@@ -128,6 +145,61 @@ async def test_testing_transaction_rollback(database_with_cleanup):
|
|
assert await database_with_cleanup.execute(count_query) == [{"count": 0}]
|
|
assert await database_with_cleanup.execute(count_query) == [{"count": 0}]
|
|
|
|
|
|
|
|
|
|
|
|
+@pytest.fixture
|
|
|
|
+async def database_no_pool_no_cache(dburl, dbname):
|
|
|
|
+ db = SQLDatabase(
|
|
|
|
+ f"{dburl}/{dbname}?prepared_statement_cache_size=0", poolclass=NullPool
|
|
|
|
+ )
|
|
|
|
+ yield db
|
|
|
|
+ await db.truncate_tables(["test_model"])
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def test_handle_serialization_error(
|
|
|
|
+ database_no_pool_no_cache: SQLDatabase, record_id: int
|
|
|
|
+):
|
|
|
|
+ """Typical 'lost update' situation will result in a Conflict error
|
|
|
|
+
|
|
|
|
+ 1> BEGIN
|
|
|
|
+ 1> UPDATE ... WHERE id=1
|
|
|
|
+ 2> BEGIN
|
|
|
|
+ 2> UPDATE ... WHERE id=1 # transaction 2 will wait until transaction 1 finishes
|
|
|
|
+ 1> COMMIT
|
|
|
|
+ 2> will raise SerializationError
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ async def update(sleep_before=0.0, sleep_after=0.0):
|
|
|
|
+ await asyncio.sleep(sleep_before)
|
|
|
|
+ async with database_no_pool_no_cache.transaction() as trans:
|
|
|
|
+ res = await trans.execute(update_query, bind_params={"id": record_id})
|
|
|
|
+ await asyncio.sleep(sleep_after)
|
|
|
|
+ return res
|
|
|
|
+
|
|
|
|
+ res1, res2 = await asyncio.gather(
|
|
|
|
+ update(sleep_after=0.02), update(sleep_before=0.01), return_exceptions=True
|
|
|
|
+ )
|
|
|
|
+ assert res1 == [{"t": "bar"}]
|
|
|
|
+ assert isinstance(res2, Conflict)
|
|
|
|
+ assert str(res2) == "could not execute query due to concurrent update"
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def test_handle_integrity_error(
|
|
|
|
+ database_with_cleanup: SQLDatabase, record_id: int
|
|
|
|
+):
|
|
|
|
+ """Insert a record with an id that already exists"""
|
|
|
|
+ insert_query_with_id = text(
|
|
|
|
+ "INSERT INTO test_model (id, t, f, b, updated_at) "
|
|
|
|
+ "VALUES (:id, 'foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
|
|
|
|
+ "RETURNING id"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ with pytest.raises(
|
|
|
|
+ AlreadyExists, match=f"record with id={record_id} already exists"
|
|
|
|
+ ):
|
|
|
|
+ await database_with_cleanup.execute(
|
|
|
|
+ insert_query_with_id, bind_params={"id": record_id}
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+
|
|
### SQLGateway integration tests
|
|
### SQLGateway integration tests
|
|
|
|
|
|
|
|
|
|
@@ -196,7 +268,9 @@ async def test_add(sql_gateway, test_transaction, obj):
|
|
|
|
|
|
|
|
|
|
async def test_add_id_exists(sql_gateway, obj_in_db):
|
|
async def test_add_id_exists(sql_gateway, obj_in_db):
|
|
- with pytest.raises(AlreadyExists):
|
|
|
|
|
|
+ with pytest.raises(
|
|
|
|
+ AlreadyExists, match=f"record with id={obj_in_db['id']} already exists"
|
|
|
|
+ ):
|
|
await sql_gateway.add(obj_in_db)
|
|
await sql_gateway.add(obj_in_db)
|
|
|
|
|
|
|
|
|