sql_gateway.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. # (c) Nelen & Schuurmans
  2. from contextlib import asynccontextmanager
  3. from datetime import datetime
  4. from typing import AsyncIterator
  5. from typing import Callable
  6. from typing import List
  7. from typing import Optional
  8. from typing import TypeVar
  9. import inject
  10. from sqlalchemy import asc
  11. from sqlalchemy import delete
  12. from sqlalchemy import desc
  13. from sqlalchemy import func
  14. from sqlalchemy import select
  15. from sqlalchemy import Table
  16. from sqlalchemy import true
  17. from sqlalchemy import update
  18. from sqlalchemy.dialects.postgresql import insert
  19. from sqlalchemy.exc import IntegrityError
  20. from sqlalchemy.sql import Executable
  21. from sqlalchemy.sql.expression import ColumnElement
  22. from sqlalchemy.sql.expression import false
  23. from clean_python import AlreadyExists
  24. from clean_python import Conflict
  25. from clean_python import DoesNotExist
  26. from clean_python import Filter
  27. from clean_python import Gateway
  28. from clean_python import Json
  29. from clean_python import PageOptions
  30. from .sql_provider import SQLDatabase
  31. from .sql_provider import SQLProvider
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["SQLGateway"]
  33. def _is_unique_violation_error_id(e: IntegrityError, id: int):
  34. # sqlalchemy wraps the asyncpg error
  35. msg = e.orig.args[0]
  36. return ("duplicate key value violates unique constraint" in msg) and (
  37. f"Key (id)=({id}) already exists." in msg
  38. )
  39. T = TypeVar("T", bound="SQLGateway")
  40. class SQLGateway(Gateway):
  41. table: Table
  42. nested: bool
  43. def __init__(
  44. self, provider_override: Optional[SQLProvider] = None, nested: bool = False
  45. ):
  46. self.provider_override = provider_override
  47. self.nested = nested
  48. @property
  49. def provider(self):
  50. return self.provider_override or inject.instance(SQLDatabase)
  51. def __init_subclass__(cls, table: Table) -> None:
  52. cls.table = table
  53. super().__init_subclass__()
  54. def rows_to_dict(self, rows: List[Json]) -> List[Json]:
  55. return rows
  56. def dict_to_row(self, obj: Json) -> Json:
  57. known = {c.key for c in self.table.c}
  58. result = {k: obj[k] for k in obj.keys() if k in known}
  59. if "id" in result and result["id"] is None:
  60. del result["id"]
  61. return result
  62. @asynccontextmanager
  63. async def transaction(self: T) -> AsyncIterator[T]:
  64. if self.nested:
  65. yield self
  66. else:
  67. async with self.provider.transaction() as provider:
  68. yield self.__class__(provider, nested=True)
  69. async def get_related(self, items: List[Json]) -> None:
  70. pass
  71. async def set_related(self, item: Json, result: Json) -> None:
  72. pass
  73. async def execute(self, query: Executable) -> List[Json]:
  74. assert self.nested
  75. return self.rows_to_dict(await self.provider.execute(query))
  76. async def add(self, item: Json) -> Json:
  77. query = (
  78. insert(self.table).values(**self.dict_to_row(item)).returning(self.table)
  79. )
  80. async with self.transaction() as transaction:
  81. try:
  82. (result,) = await transaction.execute(query)
  83. except IntegrityError as e:
  84. id_ = item.get("id")
  85. if id_ is not None and _is_unique_violation_error_id(e, id_):
  86. raise AlreadyExists(id_)
  87. raise
  88. await transaction.set_related(item, result)
  89. return result
  90. async def update(
  91. self, item: Json, if_unmodified_since: Optional[datetime] = None
  92. ) -> Json:
  93. id_ = item.get("id")
  94. if id_ is None:
  95. raise DoesNotExist("record", id_)
  96. q = self.table.c.id == id_
  97. if if_unmodified_since is not None:
  98. q &= self.table.c.updated_at == if_unmodified_since
  99. query = (
  100. update(self.table)
  101. .where(q)
  102. .values(**self.dict_to_row(item))
  103. .returning(self.table)
  104. )
  105. async with self.transaction() as transaction:
  106. result = await transaction.execute(query)
  107. if not result:
  108. if if_unmodified_since is not None:
  109. # note: the get() is to maybe raise DoesNotExist
  110. if await self.get(id_):
  111. raise Conflict()
  112. raise DoesNotExist("record", id_)
  113. await transaction.set_related(item, result[0])
  114. return result[0]
  115. async def _select_for_update(self, id: int) -> Json:
  116. async with self.transaction() as transaction:
  117. result = await transaction.execute(
  118. select(self.table).with_for_update().where(self.table.c.id == id),
  119. )
  120. if not result:
  121. raise DoesNotExist("record", id)
  122. await transaction.get_related(result)
  123. return result[0]
  124. async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
  125. async with self.transaction() as transaction:
  126. existing = await transaction._select_for_update(id)
  127. updated = func(existing)
  128. return await transaction.update(updated)
  129. async def upsert(self, item: Json) -> Json:
  130. if item.get("id") is None:
  131. return await self.add(item)
  132. values = self.dict_to_row(item)
  133. query = (
  134. insert(self.table)
  135. .values(**values)
  136. .on_conflict_do_update(index_elements=["id"], set_=values)
  137. .returning(self.table)
  138. )
  139. async with self.transaction() as transaction:
  140. result = await transaction.execute(query)
  141. await transaction.set_related(item, result[0])
  142. return result[0]
  143. async def remove(self, id) -> bool:
  144. query = (
  145. delete(self.table).where(self.table.c.id == id).returning(self.table.c.id)
  146. )
  147. async with self.transaction() as transaction:
  148. result = await transaction.execute(query)
  149. return bool(result)
  150. def _to_sqlalchemy_expression(self, filter: Filter) -> ColumnElement:
  151. try:
  152. column = getattr(self.table.c, filter.field)
  153. except AttributeError:
  154. return false()
  155. if len(filter.values) == 0:
  156. return false()
  157. elif len(filter.values) == 1:
  158. return column == filter.values[0]
  159. else:
  160. return column.in_(filter.values)
  161. async def filter(
  162. self, filters: List[Filter], params: Optional[PageOptions] = None
  163. ) -> List[Json]:
  164. query = select(self.table).where(
  165. *[self._to_sqlalchemy_expression(x) for x in filters]
  166. )
  167. if params is not None:
  168. sort = asc(params.order_by) if params.ascending else desc(params.order_by)
  169. query = query.order_by(sort).limit(params.limit).offset(params.offset)
  170. async with self.transaction() as transaction:
  171. result = await transaction.execute(query)
  172. await transaction.get_related(result)
  173. return result
  174. async def count(self, filters: List[Filter]) -> int:
  175. query = (
  176. select(func.count().label("count"))
  177. .select_from(self.table)
  178. .where(*[self._to_sqlalchemy_expression(x) for x in filters])
  179. )
  180. async with self.transaction() as transaction:
  181. return (await transaction.execute(query))[0]["count"]
  182. async def exists(self, filters: List[Filter]) -> bool:
  183. query = (
  184. select(true().label("exists"))
  185. .select_from(self.table)
  186. .where(*[self._to_sqlalchemy_expression(x) for x in filters])
  187. .limit(1)
  188. )
  189. async with self.transaction() as transaction:
  190. return len(await transaction.execute(query)) > 0
  191. async def _get_related_one_to_many(
  192. self,
  193. items: List[Json],
  194. field_name: str,
  195. fk_name: str,
  196. ) -> None:
  197. """Fetch related objects for `items` and add them inplace.
  198. The result is `items` having an additional field containing a list of related
  199. objects which were retrieved from self in 1 SELECT query.
  200. Args:
  201. items: The items for which to fetch related objects. Changed inplace.
  202. field_name: The key in item to put the fetched related objects into.
  203. fk_name: The column name on the related object that refers to item["id"]
  204. Example:
  205. Writer has a one-to-many relation to books.
  206. >>> writers = [{"id": 2, "name": "John Doe"}]
  207. >>> _get_related_one_to_many(
  208. items=writers,
  209. related_gateway=BookSQLGateway,
  210. field_name="books",
  211. fk_name="writer_id",
  212. )
  213. >>> writers[0]
  214. {
  215. "id": 2,
  216. "name": "John Doe",
  217. "books": [
  218. {
  219. "id": 1",
  220. "title": "How to write an ORM",
  221. "writer_id": 2
  222. }
  223. ]
  224. }
  225. """
  226. for x in items:
  227. x[field_name] = []
  228. item_lut = {x["id"]: x for x in items}
  229. related_objs = await self.filter(
  230. [Filter(field=fk_name, values=list(item_lut.keys()))]
  231. )
  232. for related_obj in related_objs:
  233. item_lut[related_obj[fk_name]][field_name].append(related_obj)
  234. async def _set_related_one_to_many(
  235. self,
  236. item: Json,
  237. result: Json,
  238. field_name: str,
  239. fk_name: str,
  240. ) -> None:
  241. """Set related objects for `item`
  242. This method first fetches the current situation and then adds / updates / removes
  243. where appropriate.
  244. Args:
  245. item: The item for which to set related objects.
  246. result: The dictionary to put the resulting (added / updated) objects into
  247. field_name: The key in result to put the (added / updated) related objects into.
  248. fk_name: The column name on the related object that refers to item["id"]
  249. Example:
  250. Writer has a one-to-many relation to books.
  251. >>> writer = {"id": 2, "name": "John Doe", "books": {"title": "Foo"}}
  252. >>> _set_related_one_to_many(
  253. item=writer,
  254. result=writer,
  255. related_gateway=BookSQLGateway,
  256. field_name="books",
  257. fk_name="writer_id",
  258. )
  259. >>> result
  260. {
  261. "id": 2,
  262. "name": "John Doe",
  263. "books": [
  264. {
  265. "id": 1",
  266. "title": "Foo",
  267. "writer_id": 2
  268. }
  269. ]
  270. }
  271. """
  272. # list existing related objects
  273. existing_lut = {
  274. x["id"]: x
  275. for x in await self.filter([Filter(field=fk_name, values=[result["id"]])])
  276. }
  277. # add / update them where necessary
  278. returned = []
  279. for new_value in item.get(field_name, []):
  280. new_value = {fk_name: result["id"], **new_value}
  281. existing = existing_lut.pop(new_value.get("id"), None)
  282. if existing is None:
  283. returned.append(await self.add(new_value))
  284. elif new_value == existing:
  285. returned.append(existing)
  286. else:
  287. returned.append(await self.update(new_value))
  288. result[field_name] = returned
  289. # remove remaining
  290. for to_remove in existing_lut:
  291. assert await self.remove(to_remove)