# test_sql_gateway.py (14 KB)
  1. from datetime import datetime, timezone
  2. from unittest import mock
  3. import pytest
  4. from sqlalchemy import Column, DateTime, ForeignKey, Integer, MetaData, Table, Text
  5. from clean_python import (
  6. assert_query_equal,
  7. Conflict,
  8. DoesNotExist,
  9. FakeSQLDatabase,
  10. Filter,
  11. PageOptions,
  12. SQLGateway,
  13. )
  14. writer = Table(
  15. "writer",
  16. MetaData(),
  17. Column("id", Integer, primary_key=True, autoincrement=True),
  18. Column("value", Text, nullable=False),
  19. Column("updated_at", DateTime(timezone=True), nullable=False),
  20. )
  21. book = Table(
  22. "book",
  23. MetaData(),
  24. Column("id", Integer, primary_key=True, autoincrement=True),
  25. Column("title", Text, nullable=False),
  26. Column(
  27. "writer_id",
  28. Integer,
  29. ForeignKey("writer.id", ondelete="CASCADE", name="book_writer_id_fkey"),
  30. nullable=False,
  31. ),
  32. )
  33. ALL_FIELDS = "writer.id, writer.value, writer.updated_at"
  34. BOOK_FIELDS = "book.id, book.title, book.writer_id"
  35. class TstSQLGateway(SQLGateway, table=writer):
  36. pass
  37. class TstRelatedSQLGateway(SQLGateway, table=book):
  38. pass
  39. @pytest.fixture
  40. def sql_gateway():
  41. return TstSQLGateway(FakeSQLDatabase())
  42. @pytest.fixture
  43. def related_sql_gateway():
  44. return TstRelatedSQLGateway(FakeSQLDatabase())
  45. @pytest.mark.parametrize(
  46. "filters,sql",
  47. [
  48. ([], ""),
  49. ([Filter(field="value", values=[])], " WHERE false"),
  50. ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
  51. (
  52. [Filter(field="value", values=["foo", "bar"])],
  53. " WHERE writer.value IN ('foo', 'bar')",
  54. ),
  55. ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
  56. (
  57. [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
  58. " WHERE writer.id = 1 AND writer.value = 'foo'",
  59. ),
  60. ],
  61. )
  62. async def test_filter(sql_gateway, filters, sql):
  63. sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
  64. assert await sql_gateway.filter(filters) == [{"id": 2, "value": "foo"}]
  65. assert len(sql_gateway.provider.queries) == 1
  66. assert_query_equal(
  67. sql_gateway.provider.queries[0][0],
  68. f"SELECT {ALL_FIELDS} FROM writer{sql}",
  69. )
  70. @pytest.mark.parametrize(
  71. "page_options,sql",
  72. [
  73. (None, ""),
  74. (
  75. PageOptions(limit=5, order_by="id"),
  76. " ORDER BY writer.id ASC LIMIT 5 OFFSET 0",
  77. ),
  78. (
  79. PageOptions(limit=5, offset=2, order_by="id", ascending=False),
  80. " ORDER BY writer.id DESC LIMIT 5 OFFSET 2",
  81. ),
  82. ],
  83. )
  84. async def test_filter_with_pagination(sql_gateway, page_options, sql):
  85. sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
  86. assert await sql_gateway.filter([], params=page_options) == [
  87. {"id": 2, "value": "foo"}
  88. ]
  89. assert len(sql_gateway.provider.queries) == 1
  90. assert_query_equal(
  91. sql_gateway.provider.queries[0][0],
  92. f"SELECT {ALL_FIELDS} FROM writer{sql}",
  93. )
  94. async def test_filter_with_pagination_and_filter(sql_gateway):
  95. sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
  96. assert await sql_gateway.filter(
  97. [Filter(field="value", values=["foo"])],
  98. params=PageOptions(limit=5, order_by="id"),
  99. ) == [{"id": 2, "value": "foo"}]
  100. assert len(sql_gateway.provider.queries) == 1
  101. assert_query_equal(
  102. sql_gateway.provider.queries[0][0],
  103. (
  104. f"SELECT {ALL_FIELDS} FROM writer "
  105. f"WHERE writer.value = 'foo' "
  106. f"ORDER BY writer.id ASC LIMIT 5 OFFSET 0"
  107. ),
  108. )
  109. @pytest.mark.parametrize(
  110. "filters,sql",
  111. [
  112. ([], ""),
  113. ([Filter(field="value", values=[])], " WHERE false"),
  114. ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
  115. (
  116. [Filter(field="value", values=["foo", "bar"])],
  117. " WHERE writer.value IN ('foo', 'bar')",
  118. ),
  119. ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
  120. (
  121. [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
  122. " WHERE writer.id = 1 AND writer.value = 'foo'",
  123. ),
  124. ],
  125. )
  126. async def test_count(sql_gateway, filters, sql):
  127. sql_gateway.provider.result.return_value = [{"count": 4}]
  128. assert await sql_gateway.count(filters) == 4
  129. assert len(sql_gateway.provider.queries) == 1
  130. assert_query_equal(
  131. sql_gateway.provider.queries[0][0],
  132. f"SELECT count(*) AS count FROM writer{sql}",
  133. )
  134. @mock.patch.object(SQLGateway, "filter")
  135. async def test_get(filter_m, sql_gateway):
  136. filter_m.return_value = [{"id": 2, "value": "foo"}]
  137. assert await sql_gateway.get(2) == filter_m.return_value[0]
  138. assert len(sql_gateway.provider.queries) == 0
  139. filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
  140. @mock.patch.object(SQLGateway, "filter")
  141. async def test_get_does_not_exist(filter_m, sql_gateway):
  142. filter_m.return_value = []
  143. assert await sql_gateway.get(2) is None
  144. assert len(sql_gateway.provider.queries) == 0
  145. filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
  146. @pytest.mark.parametrize(
  147. "record,sql",
  148. [
  149. ({}, "DEFAULT VALUES"),
  150. ({"value": "foo"}, "(value) VALUES ('foo')"),
  151. ({"id": None, "value": "foo"}, "(value) VALUES ('foo')"),
  152. ({"id": 2, "value": "foo"}, "(id, value) VALUES (2, 'foo')"),
  153. ({"value": "foo", "nonexisting": 2}, "(value) VALUES ('foo')"),
  154. ],
  155. )
  156. async def test_add(sql_gateway, record, sql):
  157. records = [{"id": 2, "value": "foo"}]
  158. sql_gateway.provider.result.return_value = records
  159. assert await sql_gateway.add(record) == records[0]
  160. assert len(sql_gateway.provider.queries) == 1
  161. assert_query_equal(
  162. sql_gateway.provider.queries[0][0],
  163. (f"INSERT INTO writer {sql} RETURNING {ALL_FIELDS}"),
  164. )
  165. @pytest.mark.parametrize(
  166. "record,if_unmodified_since,sql",
  167. [
  168. (
  169. {"id": 2, "value": "foo"},
  170. None,
  171. "SET id=2, value='foo' WHERE writer.id = 2",
  172. ),
  173. ({"id": 2, "other": "foo"}, None, "SET id=2 WHERE writer.id = 2"),
  174. (
  175. {"id": 2, "value": "foo"},
  176. datetime(2010, 1, 1, tzinfo=timezone.utc),
  177. (
  178. "SET id=2, value='foo' WHERE writer.id = 2 "
  179. "AND writer.updated_at = '2010-01-01 00:00:00+00:00'"
  180. ),
  181. ),
  182. ],
  183. )
  184. async def test_update(sql_gateway, record, if_unmodified_since, sql):
  185. records = [{"id": 2, "value": "foo"}]
  186. sql_gateway.provider.result.return_value = records
  187. assert await sql_gateway.update(record, if_unmodified_since) == records[0]
  188. assert len(sql_gateway.provider.queries) == 1
  189. assert_query_equal(
  190. sql_gateway.provider.queries[0][0],
  191. (f"UPDATE writer {sql} RETURNING {ALL_FIELDS}"),
  192. )
  193. async def test_update_does_not_exist(sql_gateway):
  194. sql_gateway.provider.result.return_value = []
  195. with pytest.raises(DoesNotExist):
  196. await sql_gateway.update({"id": 2})
  197. assert len(sql_gateway.provider.queries) == 1
  198. @mock.patch.object(SQLGateway, "get")
  199. async def test_update_if_unmodified_since_does_not_exist(get_m, sql_gateway):
  200. get_m.return_value = None
  201. sql_gateway.provider.result.return_value = []
  202. with pytest.raises(DoesNotExist):
  203. await sql_gateway.update(
  204. {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
  205. )
  206. assert len(sql_gateway.provider.queries) == 1
  207. get_m.assert_awaited_once_with(2)
  208. @mock.patch.object(SQLGateway, "get")
  209. async def test_update_if_unmodified_since_conflict(get_m, sql_gateway):
  210. get_m.return_value = {"id": 2, "value": "foo"}
  211. sql_gateway.provider.result.return_value = []
  212. with pytest.raises(Conflict):
  213. await sql_gateway.update(
  214. {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
  215. )
  216. assert len(sql_gateway.provider.queries) == 1
  217. get_m.assert_awaited_once_with(2)
  218. async def test_remove(sql_gateway):
  219. sql_gateway.provider.result.return_value = [{"id": 2}]
  220. assert (await sql_gateway.remove(2)) is True
  221. assert len(sql_gateway.provider.queries) == 1
  222. assert_query_equal(
  223. sql_gateway.provider.queries[0][0],
  224. ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
  225. )
  226. async def test_remove_does_not_exist(sql_gateway):
  227. sql_gateway.provider.result.return_value = []
  228. assert (await sql_gateway.remove(2)) is False
  229. assert len(sql_gateway.provider.queries) == 1
  230. assert_query_equal(
  231. sql_gateway.provider.queries[0][0],
  232. ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
  233. )
  234. async def test_upsert(sql_gateway):
  235. record = {"id": 2, "value": "foo"}
  236. sql_gateway.provider.result.return_value = [record]
  237. assert await sql_gateway.upsert(record) == record
  238. assert len(sql_gateway.provider.queries) == 1
  239. assert_query_equal(
  240. sql_gateway.provider.queries[0][0],
  241. (
  242. f"INSERT INTO writer (id, value) VALUES (2, 'foo') "
  243. f"ON CONFLICT (id) DO UPDATE SET "
  244. f"id = %(param_1)s, value = %(param_2)s "
  245. f"RETURNING {ALL_FIELDS}"
  246. ),
  247. )
  248. @mock.patch.object(SQLGateway, "add")
  249. async def test_upsert_no_id(add_m, sql_gateway):
  250. add_m.return_value = {"id": 5, "value": "foo"}
  251. assert await sql_gateway.upsert({"value": "foo"}) == add_m.return_value
  252. add_m.assert_awaited_once_with({"value": "foo"})
  253. assert len(sql_gateway.provider.queries) == 0
  254. async def test_get_related_one_to_many(related_sql_gateway: SQLGateway):
  255. writers = [{"id": 2}, {"id": 3}]
  256. books = [
  257. {"id": 3, "title": "x", "writer_id": 2},
  258. {"id": 4, "title": "y", "writer_id": 2},
  259. ]
  260. related_sql_gateway.provider.result.return_value = books
  261. await related_sql_gateway._get_related_one_to_many(
  262. items=writers,
  263. field_name="books",
  264. fk_name="writer_id",
  265. )
  266. assert writers == [{"id": 2, "books": books}, {"id": 3, "books": []}]
  267. assert len(related_sql_gateway.provider.queries) == 1
  268. assert_query_equal(
  269. related_sql_gateway.provider.queries[0][0],
  270. (
  271. "SELECT book.id, book.title, book.writer_id FROM book WHERE book.writer_id IN (2, 3)"
  272. ),
  273. )
  274. @pytest.mark.parametrize(
  275. "books,current_books,expected_queries,query_results",
  276. [
  277. # no change
  278. (
  279. [{"id": 3, "title": "x", "writer_id": 2}],
  280. [{"id": 3, "title": "x", "writer_id": 2}],
  281. [],
  282. [],
  283. ),
  284. # added a book (without an id)
  285. (
  286. [{"title": "x", "writer_id": 2}],
  287. [],
  288. [
  289. f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}"
  290. ],
  291. [[{"id": 3, "title": "x", "writer_id": 2}]],
  292. ),
  293. # added a book (with an id)
  294. (
  295. [{"id": 3, "title": "x", "writer_id": 2}],
  296. [],
  297. [
  298. f"INSERT INTO book (id, title, writer_id) VALUES (3, 'x', 2) RETURNING {BOOK_FIELDS}"
  299. ],
  300. [[{"id": 3, "title": "x", "writer_id": 2}]],
  301. ),
  302. # updated a book
  303. (
  304. [{"id": 3, "title": "x", "writer_id": 2}],
  305. [{"id": 3, "title": "a", "writer_id": 2}],
  306. [
  307. f"UPDATE book SET id=3, title='x', writer_id=2 WHERE book.id = 3 RETURNING {BOOK_FIELDS}"
  308. ],
  309. [[{"id": 3, "title": "x", "writer_id": 2}]],
  310. ),
  311. # replaced a book with a new one
  312. (
  313. [{"title": "x", "writer_id": 2}],
  314. [{"id": 15, "title": "a", "writer_id": 2}],
  315. [
  316. f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}",
  317. "DELETE FROM book WHERE book.id = 15 RETURNING book.id",
  318. ],
  319. [[{"id": 3, "title": "x", "writer_id": 2}], [{"id": 15}]],
  320. ),
  321. ],
  322. )
  323. async def test_set_related_one_to_many(
  324. related_sql_gateway: SQLGateway,
  325. books,
  326. current_books,
  327. expected_queries,
  328. query_results,
  329. ):
  330. writer = {"id": 2, "books": books}
  331. related_sql_gateway.provider.result.side_effect = [current_books] + query_results
  332. result = writer.copy()
  333. await related_sql_gateway._set_related_one_to_many(
  334. item=writer,
  335. result=result,
  336. field_name="books",
  337. fk_name="writer_id",
  338. )
  339. assert result == {
  340. "id": 2,
  341. "books": [{"id": 3, "title": "x", "writer_id": 2}],
  342. }
  343. assert len(related_sql_gateway.provider.queries) == len(expected_queries) + 1
  344. assert_query_equal(
  345. related_sql_gateway.provider.queries[0][0],
  346. f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id = 2",
  347. )
  348. for (actual_query,), expected_query in zip(
  349. related_sql_gateway.provider.queries[1:], expected_queries
  350. ):
  351. assert_query_equal(actual_query, expected_query)
  352. async def test_update_transactional(sql_gateway):
  353. existing = {"id": 2, "value": "foo"}
  354. expected = {"id": 2, "value": "bar"}
  355. sql_gateway.provider.result.side_effect = ([existing], [expected])
  356. actual = await sql_gateway.update_transactional(
  357. 2, lambda x: {"id": x["id"], "value": "bar"}
  358. )
  359. assert actual == expected
  360. (queries,) = sql_gateway.provider.queries
  361. assert len(queries) == 2
  362. assert_query_equal(
  363. queries[0],
  364. f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 FOR UPDATE",
  365. )
  366. assert_query_equal(
  367. queries[1],
  368. (
  369. f"UPDATE writer SET id=2, value='bar' WHERE writer.id = 2 RETURNING {ALL_FIELDS}"
  370. ),
  371. )
  372. @pytest.mark.parametrize(
  373. "filters,sql",
  374. [
  375. ([], ""),
  376. ([Filter(field="value", values=[])], " WHERE false"),
  377. ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
  378. (
  379. [Filter(field="value", values=["foo", "bar"])],
  380. " WHERE writer.value IN ('foo', 'bar')",
  381. ),
  382. ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
  383. (
  384. [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
  385. " WHERE writer.id = 1 AND writer.value = 'foo'",
  386. ),
  387. ],
  388. )
  389. async def test_exists(sql_gateway, filters, sql):
  390. sql_gateway.provider.result.return_value = [{"exists": True}]
  391. assert await sql_gateway.exists(filters) is True
  392. assert len(sql_gateway.provider.queries) == 1
  393. assert_query_equal(
  394. sql_gateway.provider.queries[0][0],
  395. f"SELECT true AS exists FROM writer{sql} LIMIT 1",
  396. )