sql_provider.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import re
  2. from abc import ABC
  3. from abc import abstractmethod
  4. from contextlib import asynccontextmanager
  5. from typing import Any
  6. from typing import AsyncIterator
  7. from typing import Dict
  8. from typing import List
  9. from typing import Optional
  10. from typing import Sequence
  11. from typing import Tuple
  12. import asyncpg
  13. from async_lru import alru_cache
  14. from sqlalchemy import text
  15. from sqlalchemy.dialects.postgresql.asyncpg import dialect as asyncpg_dialect
  16. from sqlalchemy.sql import Executable
  17. from clean_python import AlreadyExists
  18. from clean_python import Conflict
  19. from clean_python import Json
  # Names exported by a star import of this module.
  __all__ = [
      "SQLProvider",
      "SQLDatabase",
  ]
  # Matches the detail text of a unique violation and captures the offending
  # column(s) as ``key`` and the duplicated value(s) as ``value``.
  _UNIQUE_VIOLATION_DETAIL_PATTERN = (
      r"Key\s\((?P<key>.*)\)=\((?P<value>.*)\)\s+already exists"
  )
  UNIQUE_VIOLATION_DETAIL_REGEX = re.compile(_UNIQUE_VIOLATION_DETAIL_PATTERN)
  # Single SQLAlchemy asyncpg dialect instance, reused for every query that
  # gets compiled in this module.
  DIALECT = asyncpg_dialect()
  def convert_unique_violation_error(
      e: asyncpg.exceptions.UniqueViolationError,
  ) -> AlreadyExists:
      """Translate an asyncpg unique violation into an ``AlreadyExists`` error.

      The error's ``detail`` text is matched against
      ``UNIQUE_VIOLATION_DETAIL_REGEX``. On a match, the captured ``key`` and
      ``value`` groups are handed to ``AlreadyExists``; otherwise an
      ``AlreadyExists`` without arguments is produced.

      Args:
          e: The unique violation raised by asyncpg.

      Returns:
          The ``AlreadyExists`` exception instance (returned, not raised).
      """
      # ``detail`` is an optional field of the server's error message, so it
      # may be missing. Never feed ``None`` to the regex: that would raise a
      # TypeError and hide the actual unique violation.
      detail = e.detail
      match = UNIQUE_VIOLATION_DETAIL_REGEX.match(detail) if detail else None
      if match is None:
          return AlreadyExists()
      return AlreadyExists(key=match["key"], value=match["value"])
  class SQLProvider(ABC):
      """Abstract base for objects that can execute SQLAlchemy queries."""

      def compile(
          self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
      ) -> Tuple[Any, ...]:
          """Render ``query`` to SQL text followed by its positional parameters.

          Args:
              query: The SQLAlchemy statement to compile with ``DIALECT``.
              bind_params: Optional values that are merged over the parameters
                  of the compiled statement (same-named entries are replaced).

          Returns:
              A tuple whose first item is the SQL string; the remaining items
              are the parameter values in the order of ``compiled.positiontup``.
          """
          # Rendering SQLAlchemy expressions to SQL, see:
          # - https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html
          compiled = query.compile(
              dialect=DIALECT, compile_kwargs={"render_postcompile": True}
          )
          values = dict(compiled.params)
          if bind_params is not None:
              values.update(bind_params)
          # Emit the values in the positional order expected by the statement.
          ordered = [values[name] for name in compiled.positiontup]
          return (str(compiled), *ordered)

      @abstractmethod
      async def execute(
          self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
      ) -> List[Json]:
          """Execute ``query`` and return the resulting rows.

          Must be implemented by subclasses.
          """

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator["SQLProvider"]:
          """Context manager for a transaction; not implemented on the base class.

          Raises:
              NotImplementedError: Always, upon entering the context.
          """
          raise NotImplementedError()
          # Unreachable, but the ``yield`` is what makes this function an async
          # generator, as ``asynccontextmanager`` requires.
          yield
  class SQLDatabase(SQLProvider):
      """``SQLProvider`` that runs queries on a pool of asyncpg connections.

      Args:
          url: Connection URL. Everything up to and including the first
              ``://`` is dropped; ``postgresql://`` is prepended again when
              the pool is created.
          isolation_level: Isolation level passed to asyncpg for transactions
              opened through ``transaction()``.
          pool_size: Maximum number of connections in the pool.
      """

      def __init__(
          self, url: str, *, isolation_level: str = "repeatable_read", pool_size: int = 1
      ):
          # Strip only the scheme. Limiting the split to the first "://" keeps
          # the remainder of the URL intact if "://" occurs in it again.
          if "://" in url:
              url = url.split("://", 1)[1]
          self.url = url
          self.pool_size = pool_size
          self.isolation_level = isolation_level

      @alru_cache
      async def get_pool(self):
          """Return the asyncpg connection pool for this database.

          The result is cached by ``alru_cache``, so the pool is created by the
          first call and reused by later calls on the same instance.
          """
          # Note: disable JIT because it makes the initial queries very slow
          # see https://github.com/MagicStack/asyncpg/issues/530
          return await asyncpg.create_pool(
              f"postgresql://{self.url}",
              server_settings={"jit": "off"},
              min_size=1,
              max_size=self.pool_size,
          )

      async def dispose(self) -> None:
          """Close the connection pool returned by ``get_pool()``."""
          pool = await self.get_pool()
          await pool.close()

      async def execute(
          self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
      ) -> List[Json]:
          """Execute ``query`` inside its own transaction and return the rows."""
          async with self.transaction() as transaction:
              return await transaction.execute(query, bind_params)

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator[SQLProvider]:
          """Acquire a pooled connection and yield a provider bound to it.

          The body runs inside an asyncpg transaction that uses
          ``self.isolation_level``.
          """
          pool = await self.get_pool()
          connection: asyncpg.Connection
          async with pool.acquire() as connection:
              async with connection.transaction(isolation=self.isolation_level):
                  yield SQLTransaction(connection)

      @asynccontextmanager
      async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
          """Like ``transaction()``, but the transaction is always rolled back.

          The rollback happens on exit regardless of whether the body raised.
          """
          pool = await self.get_pool()
          connection: asyncpg.Connection
          async with pool.acquire() as connection:
              transaction = connection.transaction()
              await transaction.start()
              try:
                  yield SQLTransaction(connection)
              finally:
                  await transaction.rollback()

      async def _execute_autocommit(self, query: Executable) -> None:
          """Execute ``query`` directly on a pooled connection.

          No transaction block is opened around the statement.
          """
          pool = await self.get_pool()
          connection: asyncpg.Connection
          async with pool.acquire() as connection:
              await connection.execute(*self.compile(query))

      async def create_database(self, name: str) -> None:
          """Issue ``CREATE DATABASE`` for ``name`` (inserted into the SQL as-is)."""
          await self._execute_autocommit(text(f"CREATE DATABASE {name}"))

      async def create_extension(self, name: str) -> None:
          """Issue ``CREATE EXTENSION IF NOT EXISTS`` for ``name`` (inserted as-is)."""
          await self._execute_autocommit(text(f"CREATE EXTENSION IF NOT EXISTS {name}"))

      async def drop_database(self, name: str) -> None:
          """Issue ``DROP DATABASE IF EXISTS`` for ``name`` (inserted as-is)."""
          await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))

      async def truncate_tables(self, names: Sequence[str]) -> None:
          """Truncate the given tables in a single ``TRUNCATE TABLE`` statement.

          Each name is written as a double-quoted identifier. Nothing is
          executed when ``names`` is empty.
          """
          # "TRUNCATE TABLE" without any table is not valid SQL.
          if not names:
              return
          # Inside a quoted identifier a double quote is written by doubling it.
          quoted = ", ".join('"' + name.replace('"', '""') + '"' for name in names)
          await self._execute_autocommit(text(f"TRUNCATE TABLE {quoted}"))
  class SQLTransaction(SQLProvider):
      """``SQLProvider`` that executes every query on one asyncpg connection."""

      def __init__(self, connection: asyncpg.Connection):
          self.connection = connection

      async def execute(
          self, query: Executable, bind_params: Optional[Dict[str, Any]] = None
      ) -> List[Json]:
          """Compile and run ``query`` on the connection; return rows as dicts.

          Args:
              query: The SQLAlchemy statement to execute.
              bind_params: Optional parameter values, forwarded to ``compile``.

          Returns:
              One ``dict`` per fetched record.

          Raises:
              AlreadyExists: When asyncpg reports a unique violation.
              Conflict: When asyncpg reports a serialization error.
          """
          try:
              result = await self.connection.fetch(*self.compile(query, bind_params))
          except asyncpg.exceptions.UniqueViolationError as e:
              # Chain explicitly so the driver error stays attached as the cause.
              raise convert_unique_violation_error(e) from e
          except asyncpg.exceptions.SerializationError as e:
              raise Conflict("could not execute query due to concurrent update") from e
          return [dict(record) for record in result]

      @asynccontextmanager
      async def transaction(self) -> AsyncIterator[SQLProvider]:
          """Run the body inside ``self.connection.transaction()``; yields ``self``."""
          async with self.connection.transaction():
              yield self