Quellcode durchsuchen

[Done] Use asyncpg directly in SQLProvider (#49)

Casper van der Wel vor 10 Monaten
Ursprung
Commit
00058d255c

+ 5 - 2
CHANGES.md

@@ -1,10 +1,13 @@
 # Changelog of clean-python
 
 
-0.9.7 (unreleased)
+0.10.0 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Changed the internals of SQLProvider: asyncpg is now used directly for
+  connection pooling, transaction management, query execution and parameter
+  binding. This removes overhead from SQL query execution and prevents the
+  use of greenlets.
 
 
 0.9.6 (2023-12-20)

+ 74 - 45
clean_python/sql/sql_provider.py

@@ -8,12 +8,12 @@ from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Sequence
+from typing import Tuple
 
+import asyncpg
+from async_lru import alru_cache
 from sqlalchemy import text
-from sqlalchemy.exc import DBAPIError
-from sqlalchemy.ext.asyncio import AsyncConnection
-from sqlalchemy.ext.asyncio import AsyncEngine
-from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.dialects.postgresql.asyncpg import dialect as asyncpg_dialect
 from sqlalchemy.sql import Executable
 
 from clean_python import AlreadyExists
@@ -24,27 +24,38 @@ __all__ = ["SQLProvider", "SQLDatabase"]
 
 
 UNIQUE_VIOLATION_DETAIL_REGEX = re.compile(
-    r"DETAIL:\s*Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
+    r"Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
 )
+DIALECT = asyncpg_dialect()
 
 
-def maybe_raise_conflict(e: DBAPIError) -> None:
-    # https://www.postgresql.org/docs/current/errcodes-appendix.html
-    if e.orig.pgcode == "40001":  # serialization_failure
-        raise Conflict("could not execute query due to concurrent update")
-
-
-def maybe_raise_already_exists(e: DBAPIError) -> None:
-    # https://www.postgresql.org/docs/current/errcodes-appendix.html
-    if e.orig.pgcode == "23505":  # unique_violation
-        match = UNIQUE_VIOLATION_DETAIL_REGEX.match(e.orig.args[0].split("\n")[-1])
-        if match:
-            raise AlreadyExists(key=match["key"], value=match["value"])
-        else:
-            raise AlreadyExists()
+def convert_unique_violation_error(
+    e: asyncpg.exceptions.UniqueViolationError,
+) -> AlreadyExists:
+    match = UNIQUE_VIOLATION_DETAIL_REGEX.match(e.detail)
+    if match:
+        return AlreadyExists(key=match["key"], value=match["value"])
+    else:
+        return AlreadyExists()
 
 
 class SQLProvider(ABC):
+    def compile(
+        self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
+    ) -> Tuple[Any, ...]:
+        # Rendering SQLAlchemy expressions to SQL, see:
+        # - https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html
+        compiled = query.compile(
+            dialect=DIALECT, compile_kwargs={"render_postcompile": True}
+        )
+        params = (
+            compiled.params
+            if bind_params is None
+            else {**compiled.params, **bind_params}
+        )
+        # add params in positional order
+        return (str(compiled),) + tuple(params[k] for k in compiled.positiontup)
+
     @abstractmethod
     async def execute(
         self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
@@ -58,17 +69,29 @@ class SQLProvider(ABC):
 
 
 class SQLDatabase(SQLProvider):
-    engine: AsyncEngine
-
-    def __init__(self, url: str, **kwargs):
-        kwargs.setdefault("isolation_level", "REPEATABLE READ")
-        self.engine = create_async_engine(url, **kwargs)
+    def __init__(
+        self, url: str, *, isolation_level: str = "repeatable_read", pool_size: int = 1
+    ):
+        # Note: disable JIT because it amakes the initial queries very slow
+        # see https://github.com/MagicStack/asyncpg/issues/530
+        if "://" in url:
+            url = url.split("://")[1]
+        self.url = url
+        self.pool_size = pool_size
+        self.isolation_level = isolation_level
+
+    @alru_cache
+    async def get_pool(self):
+        return await asyncpg.create_pool(
+            f"postgresql://{self.url}",
+            server_settings={"jit": "off"},
+            min_size=1,
+            max_size=self.pool_size,
+        )
 
     async def dispose(self) -> None:
-        await self.engine.dispose()
-
-    def dispose_sync(self) -> None:
-        self.engine.sync_engine.dispose()
+        pool = await self.get_pool()
+        await pool.close()
 
     async def execute(
         self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
@@ -78,21 +101,29 @@ class SQLDatabase(SQLProvider):
 
     @asynccontextmanager
     async def transaction(self) -> AsyncIterator[SQLProvider]:
-        async with self.engine.connect() as connection:
-            async with connection.begin():
+        pool = await self.get_pool()
+        connection: asyncpg.Connection
+        async with pool.acquire() as connection:
+            async with connection.transaction(isolation=self.isolation_level):
                 yield SQLTransaction(connection)
 
     @asynccontextmanager
     async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
-        async with self.engine.connect() as connection:
-            async with connection.begin() as transaction:
+        pool = await self.get_pool()
+        connection: asyncpg.Connection
+        async with pool.acquire() as connection:
+            transaction = connection.transaction()
+            await transaction.start()
+            try:
                 yield SQLTransaction(connection)
+            finally:
                 await transaction.rollback()
 
     async def _execute_autocommit(self, query: Executable) -> None:
-        engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
-        async with engine.connect() as connection:
-            await connection.execute(query)
+        pool = await self.get_pool()
+        connection: asyncpg.Connection
+        async with pool.acquire() as connection:
+            await connection.execute(*self.compile(query))
 
     async def create_database(self, name: str) -> None:
         await self._execute_autocommit(text(f"CREATE DATABASE {name}"))
@@ -109,23 +140,21 @@ class SQLDatabase(SQLProvider):
 
 
 class SQLTransaction(SQLProvider):
-    def __init__(self, connection: AsyncConnection):
+    def __init__(self, connection: asyncpg.Connection):
         self.connection = connection
 
     async def execute(
         self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
     ) -> List[Json]:
         try:
-            result = await self.connection.execute(query, bind_params)
-        except DBAPIError as e:
-            maybe_raise_conflict(e)
-            maybe_raise_already_exists(e)
-            raise e
-        # _asdict() is a documented method of a NamedTuple
-        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
-        return [x._asdict() for x in result.fetchall()]
+            result = await self.connection.fetch(*self.compile(query, bind_params))
+        except asyncpg.exceptions.UniqueViolationError as e:
+            raise convert_unique_violation_error(e)
+        except asyncpg.exceptions.SerializationError:
+            raise Conflict("could not execute query due to concurrent update")
+        return list(map(dict, result))
 
     @asynccontextmanager
     async def transaction(self) -> AsyncIterator[SQLProvider]:
-        async with self.connection.begin_nested():
+        async with self.connection.transaction():
             yield self

+ 15 - 25
integration_tests/test_sql_database.py

@@ -6,6 +6,7 @@ from datetime import timezone
 from unittest import mock
 
 import pytest
+from asyncpg.exceptions import NotNullViolationError
 from sqlalchemy import Boolean
 from sqlalchemy import Column
 from sqlalchemy import DateTime
@@ -14,8 +15,7 @@ from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import Table
 from sqlalchemy import Text
-from sqlalchemy.exc import IntegrityError
-from sqlalchemy.pool import NullPool
+from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.sql import text
 
 from clean_python import AlreadyExists
@@ -47,26 +47,25 @@ insert_query = text(
 update_query = text("UPDATE test_model SET t='bar' WHERE id=:id RETURNING t")
 
 
-@pytest.fixture(scope="session")
-def dburl(postgres_url):
-    return f"postgresql+asyncpg://{postgres_url}"
-
-
 @pytest.fixture(scope="session")
 def dbname():
     return "cleanpython_test"
 
 
 @pytest.fixture(scope="session")
-async def database(dburl, dbname):
-    root_provider = SQLDatabase(f"{dburl}/")
+async def database(postgres_url, dbname):
+    root_provider = SQLDatabase(f"{postgres_url}/")
     await root_provider.drop_database(dbname)
     await root_provider.create_database(dbname)
-    provider = SQLDatabase(f"{dburl}/{dbname}")
-    async with provider.engine.begin() as conn:
+    await root_provider.dispose()
+    engine = create_async_engine(f"postgresql+asyncpg://{postgres_url}/{dbname}")
+    async with engine.begin() as conn:
         await conn.run_sync(test_model.metadata.drop_all)
         await conn.run_sync(test_model.metadata.create_all)
-    yield SQLDatabase(f"{dburl}/{dbname}")
+    await engine.dispose()
+    yield SQLDatabase(
+        f"{postgres_url}/{dbname}", pool_size=2
+    )  # pool_size=2 for Conflict test
 
 
 @pytest.fixture
@@ -145,17 +144,8 @@ async def test_testing_transaction_rollback(database_with_cleanup):
     assert await database_with_cleanup.execute(count_query) == [{"count": 0}]
 
 
-@pytest.fixture
-async def database_no_pool_no_cache(dburl, dbname):
-    db = SQLDatabase(
-        f"{dburl}/{dbname}?prepared_statement_cache_size=0", poolclass=NullPool
-    )
-    yield db
-    await db.truncate_tables(["test_model"])
-
-
 async def test_handle_serialization_error(
-    database_no_pool_no_cache: SQLDatabase, record_id: int
+    database_with_cleanup: SQLDatabase, record_id: int
 ):
     """Typical 'lost update' situation will result in a Conflict error
 
@@ -169,13 +159,13 @@ async def test_handle_serialization_error(
 
     async def update(sleep_before=0.0, sleep_after=0.0):
         await asyncio.sleep(sleep_before)
-        async with database_no_pool_no_cache.transaction() as trans:
+        async with database_with_cleanup.transaction() as trans:
             res = await trans.execute(update_query, bind_params={"id": record_id})
             await asyncio.sleep(sleep_after)
         return res
 
     res1, res2 = await asyncio.gather(
-        update(sleep_after=0.02), update(sleep_before=0.01), return_exceptions=True
+        update(sleep_after=0.1), update(sleep_before=0.05), return_exceptions=True
     )
     assert res1 == [{"t": "bar"}]
     assert isinstance(res2, Conflict)
@@ -279,7 +269,7 @@ async def test_add_integrity_error(sql_gateway, obj, id):
     obj.pop("t")  # will cause the IntegrityError
     if id != "delete":
         obj["id"] = id
-    with pytest.raises(IntegrityError):
+    with pytest.raises(NotNullViolationError):
         await sql_gateway.add(obj)
 
 

+ 1 - 1
tests/fastapi/test_fastapi_access_logger.py

@@ -44,7 +44,7 @@ def req():
         ],
         "state": {},
         "method": "GET",
-        "path": "/rasters",
+        "path": "/v1-beta/rasters",
         "raw_path": b"/v1-beta/rasters",
         "query_string": b"limit=50&offset=0&order_by=id",
         "path_params": {},