sql_provider.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import re
  2. from abc import ABC
  3. from abc import abstractmethod
  4. from contextlib import asynccontextmanager
  5. from typing import Any
  6. from typing import AsyncIterator
  7. from typing import Dict
  8. from typing import List
  9. from typing import Optional
  10. from typing import Sequence
  11. from sqlalchemy import text
  12. from sqlalchemy.exc import DBAPIError
  13. from sqlalchemy.ext.asyncio import AsyncConnection
  14. from sqlalchemy.ext.asyncio import AsyncEngine
  15. from sqlalchemy.ext.asyncio import create_async_engine
  16. from sqlalchemy.sql import Executable
  17. from clean_python import AlreadyExists
  18. from clean_python import Conflict
  19. from clean_python import Json
  20. __all__ = ["SQLProvider", "SQLDatabase"]
  21. UNIQUE_VIOLATION_DETAIL_REGEX = re.compile(
  22. r"DETAIL:\s*Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
  23. )
  24. def maybe_raise_conflict(e: DBAPIError) -> None:
  25. # https://www.postgresql.org/docs/current/errcodes-appendix.html
  26. if e.orig.pgcode == "40001": # serialization_failure
  27. raise Conflict("could not execute query due to concurrent update")
  28. def maybe_raise_already_exists(e: DBAPIError) -> None:
  29. # https://www.postgresql.org/docs/current/errcodes-appendix.html
  30. if e.orig.pgcode == "23505": # unique_violation
  31. match = UNIQUE_VIOLATION_DETAIL_REGEX.match(e.orig.args[0].split("\n")[-1])
  32. if match:
  33. raise AlreadyExists(key=match["key"], value=match["value"])
  34. else:
  35. raise AlreadyExists()
  36. class SQLProvider(ABC):
  37. @abstractmethod
  38. async def execute(
  39. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  40. ) -> List[Json]:
  41. pass
  42. @asynccontextmanager
  43. async def transaction(self) -> AsyncIterator["SQLProvider"]:
  44. raise NotImplementedError()
  45. yield
  46. class SQLDatabase(SQLProvider):
  47. engine: AsyncEngine
  48. def __init__(self, url: str, **kwargs):
  49. kwargs.setdefault("isolation_level", "REPEATABLE READ")
  50. self.engine = create_async_engine(url, **kwargs)
  51. async def dispose(self) -> None:
  52. await self.engine.dispose()
  53. def dispose_sync(self) -> None:
  54. self.engine.sync_engine.dispose()
  55. async def execute(
  56. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  57. ) -> List[Json]:
  58. async with self.transaction() as transaction:
  59. return await transaction.execute(query, bind_params)
  60. @asynccontextmanager
  61. async def transaction(self) -> AsyncIterator[SQLProvider]:
  62. async with self.engine.connect() as connection:
  63. async with connection.begin():
  64. yield SQLTransaction(connection)
  65. @asynccontextmanager
  66. async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
  67. async with self.engine.connect() as connection:
  68. async with connection.begin() as transaction:
  69. yield SQLTransaction(connection)
  70. await transaction.rollback()
  71. async def _execute_autocommit(self, query: Executable) -> None:
  72. engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
  73. async with engine.connect() as connection:
  74. await connection.execute(query)
  75. async def create_database(self, name: str) -> None:
  76. await self._execute_autocommit(text(f"CREATE DATABASE {name}"))
  77. async def create_extension(self, name: str) -> None:
  78. await self._execute_autocommit(text(f"CREATE EXTENSION IF NOT EXISTS {name}"))
  79. async def drop_database(self, name: str) -> None:
  80. await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))
  81. async def truncate_tables(self, names: Sequence[str]) -> None:
  82. quoted = [f'"{x}"' for x in names]
  83. await self._execute_autocommit(text(f"TRUNCATE TABLE {', '.join(quoted)}"))
  84. class SQLTransaction(SQLProvider):
  85. def __init__(self, connection: AsyncConnection):
  86. self.connection = connection
  87. async def execute(
  88. self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
  89. ) -> List[Json]:
  90. try:
  91. result = await self.connection.execute(query, bind_params)
  92. except DBAPIError as e:
  93. maybe_raise_conflict(e)
  94. maybe_raise_already_exists(e)
  95. raise e
  96. # _asdict() is a documented method of a NamedTuple
  97. # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
  98. return [x._asdict() for x in result.fetchall()]
  99. @asynccontextmanager
  100. async def transaction(self) -> AsyncIterator[SQLProvider]:
  101. async with self.connection.begin_nested():
  102. yield self