test_sql_gateway.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. from datetime import datetime
  2. from datetime import timezone
  3. from unittest import mock
  4. import pytest
  5. from sqlalchemy import Column
  6. from sqlalchemy import DateTime
  7. from sqlalchemy import ForeignKey
  8. from sqlalchemy import Integer
  9. from sqlalchemy import MetaData
  10. from sqlalchemy import Table
  11. from sqlalchemy import Text
  12. from clean_python import Conflict
  13. from clean_python import DoesNotExist
  14. from clean_python import Filter
  15. from clean_python import PageOptions
  16. from clean_python.sql import SQLGateway
  17. from clean_python.sql.testing import assert_query_equal
  18. from clean_python.sql.testing import FakeSQLDatabase
  19. writer = Table(
  20. "writer",
  21. MetaData(),
  22. Column("id", Integer, primary_key=True, autoincrement=True),
  23. Column("value", Text, nullable=False),
  24. Column("updated_at", DateTime(timezone=True), nullable=False),
  25. )
  26. book = Table(
  27. "book",
  28. MetaData(),
  29. Column("id", Integer, primary_key=True, autoincrement=True),
  30. Column("title", Text, nullable=False),
  31. Column(
  32. "writer_id",
  33. Integer,
  34. ForeignKey("writer.id", ondelete="CASCADE", name="book_writer_id_fkey"),
  35. nullable=False,
  36. ),
  37. )
  38. ALL_FIELDS = "writer.id, writer.value, writer.updated_at"
  39. BOOK_FIELDS = "book.id, book.title, book.writer_id"
  40. class TstSQLGateway(SQLGateway, table=writer):
  41. pass
  42. class TstRelatedSQLGateway(SQLGateway, table=book):
  43. pass
  44. @pytest.fixture
  45. def sql_gateway():
  46. return TstSQLGateway(FakeSQLDatabase())
  47. @pytest.fixture
  48. def related_sql_gateway():
  49. return TstRelatedSQLGateway(FakeSQLDatabase())
  50. @pytest.mark.parametrize(
  51. "filters,sql",
  52. [
  53. ([], ""),
  54. ([Filter(field="value", values=[])], " WHERE false"),
  55. ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
  56. (
  57. [Filter(field="value", values=["foo", "bar"])],
  58. " WHERE writer.value IN ('foo', 'bar')",
  59. ),
  60. ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
  61. (
  62. [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
  63. " WHERE writer.id = 1 AND writer.value = 'foo'",
  64. ),
  65. ],
  66. )
  67. async def test_filter(sql_gateway, filters, sql):
  68. sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
  69. assert await sql_gateway.filter(filters) == [{"id": 2, "value": "foo"}]
  70. assert len(sql_gateway.provider.queries) == 1
  71. assert_query_equal(
  72. sql_gateway.provider.queries[0][0],
  73. f"SELECT {ALL_FIELDS} FROM writer{sql}",
  74. )
  75. @pytest.mark.parametrize(
  76. "page_options,sql",
  77. [
  78. (None, ""),
  79. (
  80. PageOptions(limit=5, order_by="id"),
  81. " ORDER BY writer.id ASC LIMIT 5 OFFSET 0",
  82. ),
  83. (
  84. PageOptions(limit=5, offset=2, order_by="id", ascending=False),
  85. " ORDER BY writer.id DESC LIMIT 5 OFFSET 2",
  86. ),
  87. ],
  88. )
  89. async def test_filter_with_pagination(sql_gateway, page_options, sql):
  90. sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
  91. assert await sql_gateway.filter([], params=page_options) == [
  92. {"id": 2, "value": "foo"}
  93. ]
  94. assert len(sql_gateway.provider.queries) == 1
  95. assert_query_equal(
  96. sql_gateway.provider.queries[0][0],
  97. f"SELECT {ALL_FIELDS} FROM writer{sql}",
  98. )
  99. async def test_filter_with_pagination_and_filter(sql_gateway):
  100. sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
  101. assert await sql_gateway.filter(
  102. [Filter(field="value", values=["foo"])],
  103. params=PageOptions(limit=5, order_by="id"),
  104. ) == [{"id": 2, "value": "foo"}]
  105. assert len(sql_gateway.provider.queries) == 1
  106. assert_query_equal(
  107. sql_gateway.provider.queries[0][0],
  108. (
  109. f"SELECT {ALL_FIELDS} FROM writer "
  110. f"WHERE writer.value = 'foo' "
  111. f"ORDER BY writer.id ASC LIMIT 5 OFFSET 0"
  112. ),
  113. )
  114. @pytest.mark.parametrize(
  115. "filters,sql",
  116. [
  117. ([], ""),
  118. ([Filter(field="value", values=[])], " WHERE false"),
  119. ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
  120. (
  121. [Filter(field="value", values=["foo", "bar"])],
  122. " WHERE writer.value IN ('foo', 'bar')",
  123. ),
  124. ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
  125. (
  126. [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
  127. " WHERE writer.id = 1 AND writer.value = 'foo'",
  128. ),
  129. ],
  130. )
  131. async def test_count(sql_gateway, filters, sql):
  132. sql_gateway.provider.result.return_value = [{"count": 4}]
  133. assert await sql_gateway.count(filters) == 4
  134. assert len(sql_gateway.provider.queries) == 1
  135. assert_query_equal(
  136. sql_gateway.provider.queries[0][0],
  137. f"SELECT count(*) AS count FROM writer{sql}",
  138. )
  139. @mock.patch.object(SQLGateway, "filter")
  140. async def test_get(filter_m, sql_gateway):
  141. filter_m.return_value = [{"id": 2, "value": "foo"}]
  142. assert await sql_gateway.get(2) == filter_m.return_value[0]
  143. assert len(sql_gateway.provider.queries) == 0
  144. filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
  145. @mock.patch.object(SQLGateway, "filter")
  146. async def test_get_does_not_exist(filter_m, sql_gateway):
  147. filter_m.return_value = []
  148. assert await sql_gateway.get(2) is None
  149. assert len(sql_gateway.provider.queries) == 0
  150. filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
  151. @pytest.mark.parametrize(
  152. "record,sql",
  153. [
  154. ({}, "DEFAULT VALUES"),
  155. ({"value": "foo"}, "(value) VALUES ('foo')"),
  156. ({"id": None, "value": "foo"}, "(value) VALUES ('foo')"),
  157. ({"id": 2, "value": "foo"}, "(id, value) VALUES (2, 'foo')"),
  158. ({"value": "foo", "nonexisting": 2}, "(value) VALUES ('foo')"),
  159. ],
  160. )
  161. async def test_add(sql_gateway, record, sql):
  162. records = [{"id": 2, "value": "foo"}]
  163. sql_gateway.provider.result.return_value = records
  164. assert await sql_gateway.add(record) == records[0]
  165. assert len(sql_gateway.provider.queries) == 1
  166. assert_query_equal(
  167. sql_gateway.provider.queries[0][0],
  168. (f"INSERT INTO writer {sql} RETURNING {ALL_FIELDS}"),
  169. )
  170. @pytest.mark.parametrize(
  171. "record,if_unmodified_since,sql",
  172. [
  173. (
  174. {"id": 2, "value": "foo"},
  175. None,
  176. "SET id=2, value='foo' WHERE writer.id = 2",
  177. ),
  178. ({"id": 2, "other": "foo"}, None, "SET id=2 WHERE writer.id = 2"),
  179. (
  180. {"id": 2, "value": "foo"},
  181. datetime(2010, 1, 1, tzinfo=timezone.utc),
  182. (
  183. "SET id=2, value='foo' WHERE writer.id = 2 "
  184. "AND writer.updated_at = '2010-01-01 00:00:00+00:00'"
  185. ),
  186. ),
  187. ],
  188. )
  189. async def test_update(sql_gateway, record, if_unmodified_since, sql):
  190. records = [{"id": 2, "value": "foo"}]
  191. sql_gateway.provider.result.return_value = records
  192. assert await sql_gateway.update(record, if_unmodified_since) == records[0]
  193. assert len(sql_gateway.provider.queries) == 1
  194. assert_query_equal(
  195. sql_gateway.provider.queries[0][0],
  196. (f"UPDATE writer {sql} RETURNING {ALL_FIELDS}"),
  197. )
  198. async def test_update_does_not_exist(sql_gateway):
  199. sql_gateway.provider.result.return_value = []
  200. with pytest.raises(DoesNotExist):
  201. await sql_gateway.update({"id": 2})
  202. assert len(sql_gateway.provider.queries) == 1
  203. @mock.patch.object(SQLGateway, "get")
  204. async def test_update_if_unmodified_since_does_not_exist(get_m, sql_gateway):
  205. get_m.return_value = None
  206. sql_gateway.provider.result.return_value = []
  207. with pytest.raises(DoesNotExist):
  208. await sql_gateway.update(
  209. {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
  210. )
  211. assert len(sql_gateway.provider.queries) == 1
  212. get_m.assert_awaited_once_with(2)
  213. @mock.patch.object(SQLGateway, "get")
  214. async def test_update_if_unmodified_since_conflict(get_m, sql_gateway):
  215. get_m.return_value = {"id": 2, "value": "foo"}
  216. sql_gateway.provider.result.return_value = []
  217. with pytest.raises(Conflict):
  218. await sql_gateway.update(
  219. {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
  220. )
  221. assert len(sql_gateway.provider.queries) == 1
  222. get_m.assert_awaited_once_with(2)
  223. async def test_remove(sql_gateway):
  224. sql_gateway.provider.result.return_value = [{"id": 2}]
  225. assert (await sql_gateway.remove(2)) is True
  226. assert len(sql_gateway.provider.queries) == 1
  227. assert_query_equal(
  228. sql_gateway.provider.queries[0][0],
  229. ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
  230. )
  231. async def test_remove_does_not_exist(sql_gateway):
  232. sql_gateway.provider.result.return_value = []
  233. assert (await sql_gateway.remove(2)) is False
  234. assert len(sql_gateway.provider.queries) == 1
  235. assert_query_equal(
  236. sql_gateway.provider.queries[0][0],
  237. ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
  238. )
  239. async def test_upsert(sql_gateway):
  240. record = {"id": 2, "value": "foo"}
  241. sql_gateway.provider.result.return_value = [record]
  242. assert await sql_gateway.upsert(record) == record
  243. assert len(sql_gateway.provider.queries) == 1
  244. assert_query_equal(
  245. sql_gateway.provider.queries[0][0],
  246. (
  247. f"INSERT INTO writer (id, value) VALUES (2, 'foo') "
  248. f"ON CONFLICT (id) DO UPDATE SET "
  249. f"id = %(param_1)s, value = %(param_2)s "
  250. f"RETURNING {ALL_FIELDS}"
  251. ),
  252. )
  253. @mock.patch.object(SQLGateway, "add")
  254. async def test_upsert_no_id(add_m, sql_gateway):
  255. add_m.return_value = {"id": 5, "value": "foo"}
  256. assert await sql_gateway.upsert({"value": "foo"}) == add_m.return_value
  257. add_m.assert_awaited_once_with({"value": "foo"})
  258. assert len(sql_gateway.provider.queries) == 0
  259. async def test_get_related_one_to_many(related_sql_gateway: SQLGateway):
  260. writers = [{"id": 2}, {"id": 3}]
  261. books = [
  262. {"id": 3, "title": "x", "writer_id": 2},
  263. {"id": 4, "title": "y", "writer_id": 2},
  264. ]
  265. related_sql_gateway.provider.result.return_value = books
  266. await related_sql_gateway._get_related_one_to_many(
  267. items=writers,
  268. field_name="books",
  269. fk_name="writer_id",
  270. )
  271. assert writers == [{"id": 2, "books": books}, {"id": 3, "books": []}]
  272. assert len(related_sql_gateway.provider.queries) == 1
  273. assert_query_equal(
  274. related_sql_gateway.provider.queries[0][0],
  275. (
  276. "SELECT book.id, book.title, book.writer_id FROM book WHERE book.writer_id IN (2, 3)"
  277. ),
  278. )
  279. @pytest.mark.parametrize(
  280. "books,current_books,expected_queries,query_results",
  281. [
  282. # no change
  283. (
  284. [{"id": 3, "title": "x", "writer_id": 2}],
  285. [{"id": 3, "title": "x", "writer_id": 2}],
  286. [],
  287. [],
  288. ),
  289. # added a book (without an id)
  290. (
  291. [{"title": "x", "writer_id": 2}],
  292. [],
  293. [
  294. f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}"
  295. ],
  296. [[{"id": 3, "title": "x", "writer_id": 2}]],
  297. ),
  298. # added a book (with an id)
  299. (
  300. [{"id": 3, "title": "x", "writer_id": 2}],
  301. [],
  302. [
  303. f"INSERT INTO book (id, title, writer_id) VALUES (3, 'x', 2) RETURNING {BOOK_FIELDS}"
  304. ],
  305. [[{"id": 3, "title": "x", "writer_id": 2}]],
  306. ),
  307. # updated a book
  308. (
  309. [{"id": 3, "title": "x", "writer_id": 2}],
  310. [{"id": 3, "title": "a", "writer_id": 2}],
  311. [
  312. f"UPDATE book SET id=3, title='x', writer_id=2 WHERE book.id = 3 RETURNING {BOOK_FIELDS}"
  313. ],
  314. [[{"id": 3, "title": "x", "writer_id": 2}]],
  315. ),
  316. # replaced a book with a new one
  317. (
  318. [{"title": "x", "writer_id": 2}],
  319. [{"id": 15, "title": "a", "writer_id": 2}],
  320. [
  321. f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}",
  322. "DELETE FROM book WHERE book.id = 15 RETURNING book.id",
  323. ],
  324. [[{"id": 3, "title": "x", "writer_id": 2}], [{"id": 15}]],
  325. ),
  326. ],
  327. )
  328. async def test_set_related_one_to_many(
  329. related_sql_gateway: SQLGateway,
  330. books,
  331. current_books,
  332. expected_queries,
  333. query_results,
  334. ):
  335. writer = {"id": 2, "books": books}
  336. related_sql_gateway.provider.result.side_effect = [current_books] + query_results
  337. result = writer.copy()
  338. await related_sql_gateway._set_related_one_to_many(
  339. item=writer,
  340. result=result,
  341. field_name="books",
  342. fk_name="writer_id",
  343. )
  344. assert result == {
  345. "id": 2,
  346. "books": [{"id": 3, "title": "x", "writer_id": 2}],
  347. }
  348. assert len(related_sql_gateway.provider.queries) == len(expected_queries) + 1
  349. assert_query_equal(
  350. related_sql_gateway.provider.queries[0][0],
  351. f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id = 2",
  352. )
  353. for (actual_query,), expected_query in zip(
  354. related_sql_gateway.provider.queries[1:], expected_queries
  355. ):
  356. assert_query_equal(actual_query, expected_query)
  357. async def test_update_transactional(sql_gateway):
  358. existing = {"id": 2, "value": "foo"}
  359. expected = {"id": 2, "value": "bar"}
  360. sql_gateway.provider.result.side_effect = ([existing], [expected])
  361. actual = await sql_gateway.update_transactional(
  362. 2, lambda x: {"id": x["id"], "value": "bar"}
  363. )
  364. assert actual == expected
  365. (queries,) = sql_gateway.provider.queries
  366. assert len(queries) == 2
  367. assert_query_equal(
  368. queries[0],
  369. f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 FOR UPDATE",
  370. )
  371. assert_query_equal(
  372. queries[1],
  373. (
  374. f"UPDATE writer SET id=2, value='bar' WHERE writer.id = 2 RETURNING {ALL_FIELDS}"
  375. ),
  376. )
  377. @pytest.mark.parametrize(
  378. "filters,sql",
  379. [
  380. ([], ""),
  381. ([Filter(field="value", values=[])], " WHERE false"),
  382. ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
  383. (
  384. [Filter(field="value", values=["foo", "bar"])],
  385. " WHERE writer.value IN ('foo', 'bar')",
  386. ),
  387. ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
  388. (
  389. [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
  390. " WHERE writer.id = 1 AND writer.value = 'foo'",
  391. ),
  392. ],
  393. )
  394. async def test_exists(sql_gateway, filters, sql):
  395. sql_gateway.provider.result.return_value = [{"exists": True}]
  396. assert await sql_gateway.exists(filters) is True
  397. assert len(sql_gateway.provider.queries) == 1
  398. assert_query_equal(
  399. sql_gateway.provider.queries[0][0],
  400. f"SELECT true AS exists FROM writer{sql} LIMIT 1",
  401. )