|
@@ -8,12 +8,12 @@ from typing import Dict
|
|
|
from typing import List
|
|
|
from typing import Optional
|
|
|
from typing import Sequence
|
|
|
+from typing import Tuple
|
|
|
|
|
|
+import asyncpg
|
|
|
+from async_lru import alru_cache
|
|
|
from sqlalchemy import text
|
|
|
-from sqlalchemy.exc import DBAPIError
|
|
|
-from sqlalchemy.ext.asyncio import AsyncConnection
|
|
|
-from sqlalchemy.ext.asyncio import AsyncEngine
|
|
|
-from sqlalchemy.ext.asyncio import create_async_engine
|
|
|
+from sqlalchemy.dialects.postgresql.asyncpg import dialect as asyncpg_dialect
|
|
|
from sqlalchemy.sql import Executable
|
|
|
|
|
|
from clean_python import AlreadyExists
|
|
@@ -24,27 +24,38 @@ __all__ = ["SQLProvider", "SQLDatabase"]
|
|
|
|
|
|
|
|
|
UNIQUE_VIOLATION_DETAIL_REGEX = re.compile(
|
|
|
- r"DETAIL:\s*Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
|
|
|
+ r"Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
|
|
|
)
|
|
|
+DIALECT = asyncpg_dialect()
|
|
|
|
|
|
|
|
|
-def maybe_raise_conflict(e: DBAPIError) -> None:
|
|
|
- # https://www.postgresql.org/docs/current/errcodes-appendix.html
|
|
|
- if e.orig.pgcode == "40001": # serialization_failure
|
|
|
- raise Conflict("could not execute query due to concurrent update")
|
|
|
-
|
|
|
-
|
|
|
-def maybe_raise_already_exists(e: DBAPIError) -> None:
|
|
|
- # https://www.postgresql.org/docs/current/errcodes-appendix.html
|
|
|
- if e.orig.pgcode == "23505": # unique_violation
|
|
|
- match = UNIQUE_VIOLATION_DETAIL_REGEX.match(e.orig.args[0].split("\n")[-1])
|
|
|
- if match:
|
|
|
- raise AlreadyExists(key=match["key"], value=match["value"])
|
|
|
- else:
|
|
|
- raise AlreadyExists()
|
|
|
+def convert_unique_violation_error(
|
|
|
+ e: asyncpg.exceptions.UniqueViolationError,
|
|
|
+) -> AlreadyExists:
|
|
|
+ match = UNIQUE_VIOLATION_DETAIL_REGEX.match(e.detail)
|
|
|
+ if match:
|
|
|
+ return AlreadyExists(key=match["key"], value=match["value"])
|
|
|
+ else:
|
|
|
+ return AlreadyExists()
|
|
|
|
|
|
|
|
|
class SQLProvider(ABC):
|
|
|
+ def compile(
|
|
|
+ self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
+ ) -> Tuple[Any, ...]:
|
|
|
+ # Rendering SQLAlchemy expressions to SQL, see:
|
|
|
+ # - https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html
|
|
|
+ compiled = query.compile(
|
|
|
+ dialect=DIALECT, compile_kwargs={"render_postcompile": True}
|
|
|
+ )
|
|
|
+ params = (
|
|
|
+ compiled.params
|
|
|
+ if bind_params is None
|
|
|
+ else {**compiled.params, **bind_params}
|
|
|
+ )
|
|
|
+ # add params in positional order
|
|
|
+ return (str(compiled),) + tuple(params[k] for k in compiled.positiontup)
|
|
|
+
|
|
|
@abstractmethod
|
|
|
async def execute(
|
|
|
self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
@@ -58,17 +69,29 @@ class SQLProvider(ABC):
|
|
|
|
|
|
|
|
|
class SQLDatabase(SQLProvider):
|
|
|
- engine: AsyncEngine
|
|
|
-
|
|
|
- def __init__(self, url: str, **kwargs):
|
|
|
- kwargs.setdefault("isolation_level", "REPEATABLE READ")
|
|
|
- self.engine = create_async_engine(url, **kwargs)
|
|
|
+ def __init__(
|
|
|
+ self, url: str, *, isolation_level: str = "repeatable_read", pool_size: int = 1
|
|
|
+ ):
|
|
|
+        # Note: disable JIT because it makes the initial queries very slow
|
|
|
+ # see https://github.com/MagicStack/asyncpg/issues/530
|
|
|
+ if "://" in url:
|
|
|
+ url = url.split("://")[1]
|
|
|
+ self.url = url
|
|
|
+ self.pool_size = pool_size
|
|
|
+ self.isolation_level = isolation_level
|
|
|
+
|
|
|
+ @alru_cache
|
|
|
+ async def get_pool(self):
|
|
|
+ return await asyncpg.create_pool(
|
|
|
+ f"postgresql://{self.url}",
|
|
|
+ server_settings={"jit": "off"},
|
|
|
+ min_size=1,
|
|
|
+ max_size=self.pool_size,
|
|
|
+ )
|
|
|
|
|
|
async def dispose(self) -> None:
|
|
|
- await self.engine.dispose()
|
|
|
-
|
|
|
- def dispose_sync(self) -> None:
|
|
|
- self.engine.sync_engine.dispose()
|
|
|
+ pool = await self.get_pool()
|
|
|
+ await pool.close()
|
|
|
|
|
|
async def execute(
|
|
|
self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
@@ -78,21 +101,29 @@ class SQLDatabase(SQLProvider):
|
|
|
|
|
|
@asynccontextmanager
|
|
|
async def transaction(self) -> AsyncIterator[SQLProvider]:
|
|
|
- async with self.engine.connect() as connection:
|
|
|
- async with connection.begin():
|
|
|
+ pool = await self.get_pool()
|
|
|
+ connection: asyncpg.Connection
|
|
|
+ async with pool.acquire() as connection:
|
|
|
+ async with connection.transaction(isolation=self.isolation_level):
|
|
|
yield SQLTransaction(connection)
|
|
|
|
|
|
@asynccontextmanager
|
|
|
async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
|
|
|
- async with self.engine.connect() as connection:
|
|
|
- async with connection.begin() as transaction:
|
|
|
+ pool = await self.get_pool()
|
|
|
+ connection: asyncpg.Connection
|
|
|
+ async with pool.acquire() as connection:
|
|
|
+ transaction = connection.transaction()
|
|
|
+ await transaction.start()
|
|
|
+ try:
|
|
|
yield SQLTransaction(connection)
|
|
|
+ finally:
|
|
|
await transaction.rollback()
|
|
|
|
|
|
async def _execute_autocommit(self, query: Executable) -> None:
|
|
|
- engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
|
|
|
- async with engine.connect() as connection:
|
|
|
- await connection.execute(query)
|
|
|
+ pool = await self.get_pool()
|
|
|
+ connection: asyncpg.Connection
|
|
|
+ async with pool.acquire() as connection:
|
|
|
+ await connection.execute(*self.compile(query))
|
|
|
|
|
|
async def create_database(self, name: str) -> None:
|
|
|
await self._execute_autocommit(text(f"CREATE DATABASE {name}"))
|
|
@@ -109,23 +140,21 @@ class SQLDatabase(SQLProvider):
|
|
|
|
|
|
|
|
|
class SQLTransaction(SQLProvider):
|
|
|
- def __init__(self, connection: AsyncConnection):
|
|
|
+ def __init__(self, connection: asyncpg.Connection):
|
|
|
self.connection = connection
|
|
|
|
|
|
async def execute(
|
|
|
self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
|
|
|
) -> List[Json]:
|
|
|
try:
|
|
|
- result = await self.connection.execute(query, bind_params)
|
|
|
- except DBAPIError as e:
|
|
|
- maybe_raise_conflict(e)
|
|
|
- maybe_raise_already_exists(e)
|
|
|
- raise e
|
|
|
- # _asdict() is a documented method of a NamedTuple
|
|
|
- # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
|
|
|
- return [x._asdict() for x in result.fetchall()]
|
|
|
+ result = await self.connection.fetch(*self.compile(query, bind_params))
|
|
|
+ except asyncpg.exceptions.UniqueViolationError as e:
|
|
|
+ raise convert_unique_violation_error(e)
|
|
|
+ except asyncpg.exceptions.SerializationError:
|
|
|
+ raise Conflict("could not execute query due to concurrent update")
|
|
|
+ return list(map(dict, result))
|
|
|
|
|
|
@asynccontextmanager
|
|
|
async def transaction(self) -> AsyncIterator[SQLProvider]:
|
|
|
- async with self.connection.begin_nested():
|
|
|
+ async with self.connection.transaction():
|
|
|
yield self
|