Просмотр исходного кода

Handle SQL unique constraint violations and test serialization errors (#30)

Casper van der Wel 1 год назад
Родитель
Сommit
4cb8b6eb0f

+ 2 - 2
clean_python/base/domain/exceptions.py

@@ -42,8 +42,8 @@ class Conflict(Exception):
 
 
 class AlreadyExists(Conflict):
-    def __init__(self, id: Optional[int] = None):
-        super().__init__(f"record with id={id} already exists")
+    def __init__(self, value: Any = None, key: str = "id"):
+        super().__init__(f"record with {key}={value} already exists")
 
 
 class PreconditionFailed(Exception):

+ 1 - 17
clean_python/sql/sql_gateway.py

@@ -18,12 +18,10 @@ from sqlalchemy import Table
 from sqlalchemy import true
 from sqlalchemy import update
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy.exc import IntegrityError
 from sqlalchemy.sql import Executable
 from sqlalchemy.sql.expression import ColumnElement
 from sqlalchemy.sql.expression import false
 
-from clean_python import AlreadyExists
 from clean_python import Conflict
 from clean_python import ctx
 from clean_python import DoesNotExist
@@ -39,14 +37,6 @@ from .sql_provider import SQLProvider
 __all__ = ["SQLGateway"]
 
 
-def _is_unique_violation_error_id(e: IntegrityError, id: int):
-    # sqlalchemy wraps the asyncpg error
-    msg = e.orig.args[0]
-    return ("duplicate key value violates unique constraint" in msg) and (
-        f"Key (id)=({id}) already exists." in msg
-    )
-
-
 T = TypeVar("T", bound="SQLGateway")
 
 
@@ -117,13 +107,7 @@ class SQLGateway(Gateway):
             insert(self.table).values(**self.dict_to_row(item)).returning(self.table)
         )
         async with self.transaction() as transaction:
-            try:
-                (result,) = await transaction.execute(query)
-            except IntegrityError as e:
-                id_ = item.get("id")
-                if id_ is not None and _is_unique_violation_error_id(e, id_):
-                    raise AlreadyExists(id_)
-                raise
+            (result,) = await transaction.execute(query)
             await transaction.set_related(item, result)
         return result
 

+ 25 - 7
clean_python/sql/sql_provider.py

@@ -1,3 +1,4 @@
+import re
 from abc import ABC
 from abc import abstractmethod
 from contextlib import asynccontextmanager
@@ -15,14 +16,32 @@ from sqlalchemy.ext.asyncio import AsyncEngine
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.sql import Executable
 
+from clean_python import AlreadyExists
 from clean_python import Conflict
 from clean_python import Json
 
 __all__ = ["SQLProvider", "SQLDatabase"]
 
 
-def is_serialization_error(e: DBAPIError) -> bool:
-    return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")
+UNIQUE_VIOLATION_DETAIL_REGEX = re.compile(
+    r"DETAIL:\s*Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
+)
+
+
+def maybe_raise_conflict(e: DBAPIError) -> None:
+    # https://www.postgresql.org/docs/current/errcodes-appendix.html
+    if e.orig.pgcode == "40001":  # serialization_failure
+        raise Conflict("could not execute query due to concurrent update")
+
+
+def maybe_raise_already_exists(e: DBAPIError) -> None:
+    # https://www.postgresql.org/docs/current/errcodes-appendix.html
+    if e.orig.pgcode == "23505":  # unique_violation
+        match = UNIQUE_VIOLATION_DETAIL_REGEX.match(e.orig.args[0].split("\n")[-1])
+        if match:
+            raise AlreadyExists(key=match["key"], value=match["value"])
+        else:
+            raise AlreadyExists()
 
 
 class SQLProvider(ABC):
@@ -42,7 +61,7 @@ class SQLDatabase(SQLProvider):
     engine: AsyncEngine
 
     def __init__(self, url: str, **kwargs):
-        kwargs.setdefault("isolation_level", "READ COMMITTED")
+        kwargs.setdefault("isolation_level", "REPEATABLE READ")
         self.engine = create_async_engine(url, **kwargs)
 
     async def dispose(self) -> None:
@@ -99,10 +118,9 @@ class SQLTransaction(SQLProvider):
         try:
             result = await self.connection.execute(query, bind_params)
         except DBAPIError as e:
-            if is_serialization_error(e):
-                raise Conflict(str(e))
-            else:
-                raise e
+            maybe_raise_conflict(e)
+            maybe_raise_already_exists(e)
+            raise e
         # _asdict() is a documented method of a NamedTuple
         # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
         return [x._asdict() for x in result.fetchall()]

+ 78 - 4
integration_tests/test_sql_database.py

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # (c) Nelen & Schuurmans
+import asyncio
 from datetime import datetime
 from datetime import timezone
 from unittest import mock
@@ -14,6 +15,7 @@ from sqlalchemy import MetaData
 from sqlalchemy import Table
 from sqlalchemy import Text
 from sqlalchemy.exc import IntegrityError
+from sqlalchemy.pool import NullPool
 from sqlalchemy.sql import text
 
 from clean_python import AlreadyExists
@@ -42,12 +44,21 @@ insert_query = text(
     "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
     "RETURNING id"
 )
+update_query = text("UPDATE test_model SET t='bar' WHERE id=:id RETURNING t")
 
 
 @pytest.fixture(scope="session")
-async def database(postgres_url):
-    dburl = f"postgresql+asyncpg://{postgres_url}"
-    dbname = "cleanpython_test"
+def dburl(postgres_url):
+    return f"postgresql+asyncpg://{postgres_url}"
+
+
+@pytest.fixture(scope="session")
+def dbname():
+    return "cleanpython_test"
+
+
+@pytest.fixture(scope="session")
+async def database(dburl, dbname):
     root_provider = SQLDatabase(f"{dburl}/")
     await root_provider.drop_database(dbname)
     await root_provider.create_database(dbname)
@@ -71,6 +82,12 @@ async def transaction_with_cleanup(database_with_cleanup):
         yield trans
 
 
+@pytest.fixture
+async def record_id(database_with_cleanup: SQLDatabase) -> int:
+    record = await database_with_cleanup.execute(insert_query)
+    return record[0]["id"]
+
+
 async def test_execute(database_with_cleanup):
     db = database_with_cleanup
     await db.execute(insert_query)
@@ -128,6 +145,61 @@ async def test_testing_transaction_rollback(database_with_cleanup):
     assert await database_with_cleanup.execute(count_query) == [{"count": 0}]
 
 
+@pytest.fixture
+async def database_no_pool_no_cache(dburl, dbname):
+    db = SQLDatabase(
+        f"{dburl}/{dbname}?prepared_statement_cache_size=0", poolclass=NullPool
+    )
+    yield db
+    await db.truncate_tables(["test_model"])
+
+
+async def test_handle_serialization_error(
+    database_no_pool_no_cache: SQLDatabase, record_id: int
+):
+    """Typical 'lost update' situation will result in a Conflict error
+
+    1> BEGIN
+    1> UPDATE ... WHERE id=1
+    2> BEGIN
+    2> UPDATE ... WHERE id=1   # transaction 2 will wait until transaction 1 finishes
+    1> COMMIT
+    2> will raise SerializationError
+    """
+
+    async def update(sleep_before=0.0, sleep_after=0.0):
+        await asyncio.sleep(sleep_before)
+        async with database_no_pool_no_cache.transaction() as trans:
+            res = await trans.execute(update_query, bind_params={"id": record_id})
+            await asyncio.sleep(sleep_after)
+        return res
+
+    res1, res2 = await asyncio.gather(
+        update(sleep_after=0.02), update(sleep_before=0.01), return_exceptions=True
+    )
+    assert res1 == [{"t": "bar"}]
+    assert isinstance(res2, Conflict)
+    assert str(res2) == "could not execute query due to concurrent update"
+
+
+async def test_handle_integrity_error(
+    database_with_cleanup: SQLDatabase, record_id: int
+):
+    """Insert a record with an id that already exists"""
+    insert_query_with_id = text(
+        "INSERT INTO test_model (id, t, f, b, updated_at) "
+        "VALUES (:id, 'foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
+        "RETURNING id"
+    )
+
+    with pytest.raises(
+        AlreadyExists, match=f"record with id={record_id} already exists"
+    ):
+        await database_with_cleanup.execute(
+            insert_query_with_id, bind_params={"id": record_id}
+        )
+
+
 ### SQLGateway integration tests
 
 
@@ -196,7 +268,9 @@ async def test_add(sql_gateway, test_transaction, obj):
 
 
 async def test_add_id_exists(sql_gateway, obj_in_db):
-    with pytest.raises(AlreadyExists):
+    with pytest.raises(
+        AlreadyExists, match=f"record with id={obj_in_db['id']} already exists"
+    ):
         await sql_gateway.add(obj_in_db)