sql_gateway.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. from contextlib import asynccontextmanager
  4. from datetime import datetime
  5. from typing import AsyncIterator, Callable, List, Optional, TypeVar
  6. import inject
  7. from sqlalchemy import asc, delete, desc, func, select, Table, true, update
  8. from sqlalchemy.dialects.postgresql import insert
  9. from sqlalchemy.exc import IntegrityError
  10. from sqlalchemy.sql import Executable
  11. from sqlalchemy.sql.expression import ColumnElement, false
  12. from clean_python.base.domain.exceptions import AlreadyExists, Conflict, DoesNotExist
  13. from clean_python.base.infrastructure.gateway import Filter, Gateway, Json
  14. from clean_python.base.domain.pagination import PageOptions
  15. from .sql_provider import SQLDatabase, SQLProvider
  16. def _is_unique_violation_error_id(e: IntegrityError, id: int):
  17. # sqlalchemy wraps the asyncpg error
  18. msg = e.orig.args[0]
  19. return ("duplicate key value violates unique constraint" in msg) and (
  20. f"Key (id)=({id}) already exists." in msg
  21. )
  22. T = TypeVar("T", bound="SQLGateway")
  23. class SQLGateway(Gateway):
  24. table: Table
  25. nested: bool
  26. def __init__(
  27. self, provider_override: Optional[SQLProvider] = None, nested: bool = False
  28. ):
  29. self.provider_override = provider_override
  30. self.nested = nested
  31. @property
  32. def provider(self):
  33. return self.provider_override or inject.instance(SQLDatabase)
  34. def __init_subclass__(cls, table: Table) -> None:
  35. cls.table = table
  36. super().__init_subclass__()
  37. def rows_to_dict(self, rows: List[Json]) -> List[Json]:
  38. return rows
  39. def dict_to_row(self, obj: Json) -> Json:
  40. known = {c.key for c in self.table.c}
  41. result = {k: obj[k] for k in obj.keys() if k in known}
  42. if "id" in result and result["id"] is None:
  43. del result["id"]
  44. return result
  45. @asynccontextmanager
  46. async def transaction(self: T) -> AsyncIterator[T]:
  47. if self.nested:
  48. yield self
  49. else:
  50. async with self.provider.transaction() as provider:
  51. yield self.__class__(provider, nested=True)
  52. async def get_related(self, items: List[Json]) -> None:
  53. pass
  54. async def set_related(self, item: Json, result: Json) -> None:
  55. pass
  56. async def execute(self, query: Executable) -> List[Json]:
  57. assert self.nested
  58. return self.rows_to_dict(await self.provider.execute(query))
  59. async def add(self, item: Json) -> Json:
  60. query = (
  61. insert(self.table).values(**self.dict_to_row(item)).returning(self.table)
  62. )
  63. async with self.transaction() as transaction:
  64. try:
  65. (result,) = await transaction.execute(query)
  66. except IntegrityError as e:
  67. id_ = item.get("id")
  68. if id_ is not None and _is_unique_violation_error_id(e, id_):
  69. raise AlreadyExists(id_)
  70. raise
  71. await transaction.set_related(item, result)
  72. return result
  73. async def update(
  74. self, item: Json, if_unmodified_since: Optional[datetime] = None
  75. ) -> Json:
  76. id_ = item.get("id")
  77. if id_ is None:
  78. raise DoesNotExist("record", id_)
  79. q = self.table.c.id == id_
  80. if if_unmodified_since is not None:
  81. q &= self.table.c.updated_at == if_unmodified_since
  82. query = (
  83. update(self.table)
  84. .where(q)
  85. .values(**self.dict_to_row(item))
  86. .returning(self.table)
  87. )
  88. async with self.transaction() as transaction:
  89. result = await transaction.execute(query)
  90. if not result:
  91. if if_unmodified_since is not None:
  92. # note: the get() is to maybe raise DoesNotExist
  93. if await self.get(id_):
  94. raise Conflict()
  95. raise DoesNotExist("record", id_)
  96. await transaction.set_related(item, result[0])
  97. return result[0]
  98. async def _select_for_update(self, id: int) -> Json:
  99. async with self.transaction() as transaction:
  100. result = await transaction.execute(
  101. select(self.table).with_for_update().where(self.table.c.id == id),
  102. )
  103. if not result:
  104. raise DoesNotExist("record", id)
  105. await transaction.get_related(result)
  106. return result[0]
  107. async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
  108. async with self.transaction() as transaction:
  109. existing = await transaction._select_for_update(id)
  110. updated = func(existing)
  111. return await transaction.update(updated)
  112. async def upsert(self, item: Json) -> Json:
  113. if item.get("id") is None:
  114. return await self.add(item)
  115. values = self.dict_to_row(item)
  116. query = (
  117. insert(self.table)
  118. .values(**values)
  119. .on_conflict_do_update(index_elements=["id"], set_=values)
  120. .returning(self.table)
  121. )
  122. async with self.transaction() as transaction:
  123. result = await transaction.execute(query)
  124. await transaction.set_related(item, result[0])
  125. return result[0]
  126. async def remove(self, id) -> bool:
  127. query = (
  128. delete(self.table).where(self.table.c.id == id).returning(self.table.c.id)
  129. )
  130. async with self.transaction() as transaction:
  131. result = await transaction.execute(query)
  132. return bool(result)
  133. def _to_sqlalchemy_expression(self, filter: Filter) -> ColumnElement:
  134. try:
  135. column = getattr(self.table.c, filter.field)
  136. except AttributeError:
  137. return false()
  138. if len(filter.values) == 0:
  139. return false()
  140. elif len(filter.values) == 1:
  141. return column == filter.values[0]
  142. else:
  143. return column.in_(filter.values)
  144. async def filter(
  145. self, filters: List[Filter], params: Optional[PageOptions] = None
  146. ) -> List[Json]:
  147. query = select(self.table).where(
  148. *[self._to_sqlalchemy_expression(x) for x in filters]
  149. )
  150. if params is not None:
  151. sort = asc(params.order_by) if params.ascending else desc(params.order_by)
  152. query = query.order_by(sort).limit(params.limit).offset(params.offset)
  153. async with self.transaction() as transaction:
  154. result = await transaction.execute(query)
  155. await transaction.get_related(result)
  156. return result
  157. async def count(self, filters: List[Filter]) -> int:
  158. query = (
  159. select(func.count().label("count"))
  160. .select_from(self.table)
  161. .where(*[self._to_sqlalchemy_expression(x) for x in filters])
  162. )
  163. async with self.transaction() as transaction:
  164. return (await transaction.execute(query))[0]["count"]
  165. async def exists(self, filters: List[Filter]) -> bool:
  166. query = (
  167. select(true().label("exists"))
  168. .select_from(self.table)
  169. .where(*[self._to_sqlalchemy_expression(x) for x in filters])
  170. .limit(1)
  171. )
  172. async with self.transaction() as transaction:
  173. return len(await transaction.execute(query)) > 0
  174. async def _get_related_one_to_many(
  175. self,
  176. items: List[Json],
  177. field_name: str,
  178. fk_name: str,
  179. ) -> None:
  180. """Fetch related objects for `items` and add them inplace.
  181. The result is `items` having an additional field containing a list of related
  182. objects which were retrieved from self in 1 SELECT query.
  183. Args:
  184. items: The items for which to fetch related objects. Changed inplace.
  185. field_name: The key in item to put the fetched related objects into.
  186. fk_name: The column name on the related object that refers to item["id"]
  187. Example:
  188. Writer has a one-to-many relation to books.
  189. >>> writers = [{"id": 2, "name": "John Doe"}]
  190. >>> _get_related_one_to_many(
  191. items=writers,
  192. related_gateway=BookSQLGateway,
  193. field_name="books",
  194. fk_name="writer_id",
  195. )
  196. >>> writers[0]
  197. {
  198. "id": 2,
  199. "name": "John Doe",
  200. "books": [
  201. {
  202. "id": 1",
  203. "title": "How to write an ORM",
  204. "writer_id": 2
  205. }
  206. ]
  207. }
  208. """
  209. for x in items:
  210. x[field_name] = []
  211. item_lut = {x["id"]: x for x in items}
  212. related_objs = await self.filter(
  213. [Filter(field=fk_name, values=list(item_lut.keys()))]
  214. )
  215. for related_obj in related_objs:
  216. item_lut[related_obj[fk_name]][field_name].append(related_obj)
  217. async def _set_related_one_to_many(
  218. self,
  219. item: Json,
  220. result: Json,
  221. field_name: str,
  222. fk_name: str,
  223. ) -> None:
  224. """Set related objects for `item`
  225. This method first fetches the current situation and then adds / updates / removes
  226. where appropriate.
  227. Args:
  228. item: The item for which to set related objects.
  229. result: The dictionary to put the resulting (added / updated) objects into
  230. field_name: The key in result to put the (added / updated) related objects into.
  231. fk_name: The column name on the related object that refers to item["id"]
  232. Example:
  233. Writer has a one-to-many relation to books.
  234. >>> writer = {"id": 2, "name": "John Doe", "books": {"title": "Foo"}}
  235. >>> _set_related_one_to_many(
  236. item=writer,
  237. result=writer,
  238. related_gateway=BookSQLGateway,
  239. field_name="books",
  240. fk_name="writer_id",
  241. )
  242. >>> result
  243. {
  244. "id": 2,
  245. "name": "John Doe",
  246. "books": [
  247. {
  248. "id": 1",
  249. "title": "Foo",
  250. "writer_id": 2
  251. }
  252. ]
  253. }
  254. """
  255. # list existing related objects
  256. existing_lut = {
  257. x["id"]: x
  258. for x in await self.filter([Filter(field=fk_name, values=[result["id"]])])
  259. }
  260. # add / update them where necessary
  261. returned = []
  262. for new_value in item.get(field_name, []):
  263. new_value = {fk_name: result["id"], **new_value}
  264. existing = existing_lut.pop(new_value.get("id"), None)
  265. if existing is None:
  266. returned.append(await self.add(new_value))
  267. elif new_value == existing:
  268. returned.append(existing)
  269. else:
  270. returned.append(await self.update(new_value))
  271. result[field_name] = returned
  272. # remove remaining
  273. for to_remove in existing_lut:
  274. assert await self.remove(to_remove)