sql_gateway.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # (c) Nelen & Schuurmans
  2. from contextlib import asynccontextmanager
  3. from datetime import datetime
  4. from typing import AsyncIterator
  5. from typing import Callable
  6. from typing import List
  7. from typing import Optional
  8. from typing import TypeVar
  9. import inject
  10. from sqlalchemy import and_
  11. from sqlalchemy import asc
  12. from sqlalchemy import delete
  13. from sqlalchemy import desc
  14. from sqlalchemy import func
  15. from sqlalchemy import select
  16. from sqlalchemy import Table
  17. from sqlalchemy import true
  18. from sqlalchemy import update
  19. from sqlalchemy.dialects.postgresql import insert
  20. from sqlalchemy.sql import Executable
  21. from sqlalchemy.sql.expression import ColumnElement
  22. from sqlalchemy.sql.expression import false
  23. from clean_python import Conflict
  24. from clean_python import ctx
  25. from clean_python import DoesNotExist
  26. from clean_python import Filter
  27. from clean_python import Gateway
  28. from clean_python import Id
  29. from clean_python import Json
  30. from clean_python import PageOptions
  31. from .sql_provider import SQLDatabase
  32. from .sql_provider import SQLProvider
  33. __all__ = ["SQLGateway"]
  34. T = TypeVar("T", bound="SQLGateway")
  35. class SQLGateway(Gateway):
  36. table: Table
  37. nested: bool
  38. multitenant: bool
  39. def __init__(
  40. self,
  41. provider_override: Optional[SQLProvider] = None,
  42. nested: bool = False,
  43. ):
  44. self.provider_override = provider_override
  45. self.nested = nested
  46. @property
  47. def provider(self):
  48. return self.provider_override or inject.instance(SQLDatabase)
  49. def __init_subclass__(cls, table: Table, multitenant: bool = False) -> None:
  50. cls.table = table
  51. if multitenant and not hasattr(table.c, "tenant"):
  52. raise ValueError("Can't use a multitenant SQLGateway without tenant column")
  53. cls.multitenant = multitenant
  54. super().__init_subclass__()
  55. def rows_to_dict(self, rows: List[Json]) -> List[Json]:
  56. return rows
  57. def dict_to_row(self, obj: Json) -> Json:
  58. known = {c.key for c in self.table.c}
  59. result = {k: obj[k] for k in obj.keys() if k in known}
  60. if "id" in result and result["id"] is None:
  61. del result["id"]
  62. if self.multitenant:
  63. result["tenant"] = self.current_tenant
  64. return result
  65. @asynccontextmanager
  66. async def transaction(self: T) -> AsyncIterator[T]:
  67. if self.nested:
  68. yield self
  69. else:
  70. async with self.provider.transaction() as provider:
  71. yield self.__class__(provider, nested=True)
  72. @property
  73. def current_tenant(self) -> Optional[int]:
  74. if not self.multitenant:
  75. return None
  76. if ctx.tenant is None:
  77. raise RuntimeError(f"{self.__class__} requires a tenant in the context")
  78. return ctx.tenant.id
  79. async def get_related(self, items: List[Json]) -> None:
  80. pass
  81. async def set_related(self, item: Json, result: Json) -> None:
  82. pass
  83. async def execute(self, query: Executable) -> List[Json]:
  84. assert self.nested
  85. return self.rows_to_dict(await self.provider.execute(query))
  86. async def add(self, item: Json) -> Json:
  87. query = (
  88. insert(self.table).values(**self.dict_to_row(item)).returning(self.table)
  89. )
  90. async with self.transaction() as transaction:
  91. (result,) = await transaction.execute(query)
  92. await transaction.set_related(item, result)
  93. return result
  94. async def update(
  95. self, item: Json, if_unmodified_since: Optional[datetime] = None
  96. ) -> Json:
  97. id_ = item.get("id")
  98. if id_ is None:
  99. raise DoesNotExist("record", id_)
  100. q = self._id_filter_to_sql(id_)
  101. if if_unmodified_since is not None:
  102. q &= self.table.c.updated_at == if_unmodified_since
  103. query = (
  104. update(self.table)
  105. .where(q)
  106. .values(**self.dict_to_row(item))
  107. .returning(self.table)
  108. )
  109. async with self.transaction() as transaction:
  110. result = await transaction.execute(query)
  111. if not result:
  112. if if_unmodified_since is not None:
  113. # note: the get() is to maybe raise DoesNotExist
  114. if await self.get(id_):
  115. raise Conflict()
  116. raise DoesNotExist("record", id_)
  117. await transaction.set_related(item, result[0])
  118. return result[0]
  119. async def _select_for_update(self, id: Id) -> Json:
  120. async with self.transaction() as transaction:
  121. result = await transaction.execute(
  122. select(self.table).with_for_update().where(self._id_filter_to_sql(id)),
  123. )
  124. if not result:
  125. raise DoesNotExist("record", id)
  126. await transaction.get_related(result)
  127. return result[0]
  128. async def update_transactional(self, id: Id, func: Callable[[Json], Json]) -> Json:
  129. async with self.transaction() as transaction:
  130. existing = await transaction._select_for_update(id)
  131. updated = func(existing)
  132. return await transaction.update(updated)
  133. async def upsert(self, item: Json) -> Json:
  134. if item.get("id") is None:
  135. return await self.add(item)
  136. values = self.dict_to_row(item)
  137. query = (
  138. insert(self.table)
  139. .values(**values)
  140. .on_conflict_do_update(
  141. index_elements=["id", "tenant"] if self.multitenant else ["id"],
  142. set_=values,
  143. )
  144. .returning(self.table)
  145. )
  146. async with self.transaction() as transaction:
  147. result = await transaction.execute(query)
  148. await transaction.set_related(item, result[0])
  149. return result[0]
  150. async def remove(self, id: Id) -> bool:
  151. query = (
  152. delete(self.table)
  153. .where(self._id_filter_to_sql(id))
  154. .returning(self.table.c.id)
  155. )
  156. async with self.transaction() as transaction:
  157. result = await transaction.execute(query)
  158. return bool(result)
  159. def _filter_to_sql(self, filter: Filter) -> ColumnElement:
  160. try:
  161. column = getattr(self.table.c, filter.field)
  162. except AttributeError:
  163. return false()
  164. if len(filter.values) == 0:
  165. return false()
  166. elif len(filter.values) == 1:
  167. return column == filter.values[0]
  168. else:
  169. return column.in_(filter.values)
  170. def _filters_to_sql(self, filters: List[Filter]) -> ColumnElement:
  171. qs = [self._filter_to_sql(x) for x in filters]
  172. if self.multitenant:
  173. qs.append(self.table.c.tenant == self.current_tenant)
  174. return and_(*qs)
  175. def _id_filter_to_sql(self, id: Id) -> ColumnElement:
  176. return self._filters_to_sql([Filter(field="id", values=[id])])
  177. async def filter(
  178. self, filters: List[Filter], params: Optional[PageOptions] = None
  179. ) -> List[Json]:
  180. query = select(self.table).where(self._filters_to_sql(filters))
  181. if params is not None:
  182. sort = asc(params.order_by) if params.ascending else desc(params.order_by)
  183. query = query.order_by(sort).limit(params.limit).offset(params.offset)
  184. async with self.transaction() as transaction:
  185. result = await transaction.execute(query)
  186. await transaction.get_related(result)
  187. return result
  188. async def count(self, filters: List[Filter]) -> int:
  189. query = (
  190. select(func.count().label("count"))
  191. .select_from(self.table)
  192. .where(self._filters_to_sql(filters))
  193. )
  194. async with self.transaction() as transaction:
  195. return (await transaction.execute(query))[0]["count"]
  196. async def exists(self, filters: List[Filter]) -> bool:
  197. query = (
  198. select(true().label("exists"))
  199. .select_from(self.table)
  200. .where(self._filters_to_sql(filters))
  201. .limit(1)
  202. )
  203. async with self.transaction() as transaction:
  204. return len(await transaction.execute(query)) > 0
  205. async def _get_related_one_to_many(
  206. self,
  207. items: List[Json],
  208. field_name: str,
  209. fk_name: str,
  210. ) -> None:
  211. """Fetch related objects for `items` and add them inplace.
  212. The result is `items` having an additional field containing a list of related
  213. objects which were retrieved from self in 1 SELECT query.
  214. Args:
  215. items: The items for which to fetch related objects. Changed inplace.
  216. field_name: The key in item to put the fetched related objects into.
  217. fk_name: The column name on the related object that refers to item["id"]
  218. Example:
  219. Writer has a one-to-many relation to books.
  220. >>> writers = [{"id": 2, "name": "John Doe"}]
  221. >>> _get_related_one_to_many(
  222. items=writers,
  223. related_gateway=BookSQLGateway,
  224. field_name="books",
  225. fk_name="writer_id",
  226. )
  227. >>> writers[0]
  228. {
  229. "id": 2,
  230. "name": "John Doe",
  231. "books": [
  232. {
  233. "id": 1",
  234. "title": "How to write an ORM",
  235. "writer_id": 2
  236. }
  237. ]
  238. }
  239. """
  240. assert not self.multitenant
  241. for x in items:
  242. x[field_name] = []
  243. item_lut = {x["id"]: x for x in items}
  244. related_objs = await self.filter(
  245. [Filter(field=fk_name, values=list(item_lut.keys()))]
  246. )
  247. for related_obj in related_objs:
  248. item_lut[related_obj[fk_name]][field_name].append(related_obj)
  249. async def _set_related_one_to_many(
  250. self,
  251. item: Json,
  252. result: Json,
  253. field_name: str,
  254. fk_name: str,
  255. ) -> None:
  256. """Set related objects for `item`
  257. This method first fetches the current situation and then adds / updates / removes
  258. where appropriate.
  259. Args:
  260. item: The item for which to set related objects.
  261. result: The dictionary to put the resulting (added / updated) objects into
  262. field_name: The key in result to put the (added / updated) related objects into.
  263. fk_name: The column name on the related object that refers to item["id"]
  264. Example:
  265. Writer has a one-to-many relation to books.
  266. >>> writer = {"id": 2, "name": "John Doe", "books": {"title": "Foo"}}
  267. >>> _set_related_one_to_many(
  268. item=writer,
  269. result=writer,
  270. related_gateway=BookSQLGateway,
  271. field_name="books",
  272. fk_name="writer_id",
  273. )
  274. >>> result
  275. {
  276. "id": 2,
  277. "name": "John Doe",
  278. "books": [
  279. {
  280. "id": 1",
  281. "title": "Foo",
  282. "writer_id": 2
  283. }
  284. ]
  285. }
  286. """
  287. assert not self.multitenant
  288. # list existing related objects
  289. existing_lut = {
  290. x["id"]: x
  291. for x in await self.filter([Filter(field=fk_name, values=[result["id"]])])
  292. }
  293. # add / update them where necessary
  294. returned = []
  295. for new_value in item.get(field_name, []):
  296. new_value = {fk_name: result["id"], **new_value}
  297. existing = existing_lut.pop(new_value.get("id"), None)
  298. if existing is None:
  299. returned.append(await self.add(new_value))
  300. elif new_value == existing:
  301. returned.append(existing)
  302. else:
  303. returned.append(await self.update(new_value))
  304. result[field_name] = returned
  305. # remove remaining
  306. for to_remove in existing_lut:
  307. assert await self.remove(to_remove)