test_sql_database.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. import asyncio
  4. from datetime import datetime
  5. from datetime import timezone
  6. from unittest import mock
  7. import pytest
  8. from asyncpg.exceptions import NotNullViolationError
  9. from sqlalchemy import Boolean
  10. from sqlalchemy import Column
  11. from sqlalchemy import DateTime
  12. from sqlalchemy import Float
  13. from sqlalchemy import Integer
  14. from sqlalchemy import MetaData
  15. from sqlalchemy import Table
  16. from sqlalchemy import Text
  17. from sqlalchemy.ext.asyncio import create_async_engine
  18. from sqlalchemy.sql import text
  19. from clean_python import AlreadyExists
  20. from clean_python import Conflict
  21. from clean_python import DoesNotExist
  22. from clean_python import Filter
  23. from clean_python.sql import SQLDatabase
  24. from clean_python.sql import SQLGateway
  # Table used by every test in this module. "id" is an autoincrementing
  # primary key, "n" is the only nullable column, and "updated_at" is
  # timezone-aware.
  test_model = Table(
      "test_model",
      MetaData(),
      Column("id", Integer, primary_key=True, autoincrement=True),
      *[
          Column(col_name, col_type, nullable=False)
          for col_name, col_type in [
              ("t", Text),
              ("f", Float),
              ("b", Boolean),
              ("updated_at", DateTime(timezone=True)),
          ]
      ],
      Column("n", Float, nullable=True),
  )
  35. ### SQLProvider integration tests
  # Raw SQL shared by the SQLDatabase (provider) tests.
  count_query = text("SELECT COUNT(*) FROM test_model")
  # Inserts one fixed row and returns its generated id.
  insert_query = text(
      " ".join(
          [
              "INSERT INTO test_model (t, f, b, updated_at)",
              "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07')",
              "RETURNING id",
          ]
      )
  )
  # Sets t='bar' on the row with the bound :id and returns the new value.
  update_query = text(
      "UPDATE test_model SET t='bar' " "WHERE id=:id RETURNING t"
  )
  @pytest.fixture(scope="session")
  def dbname():
      """Name of the database (re)created for this test session."""
      return "cleanpython_test"


  @pytest.fixture(scope="session")
  async def database(postgres_url, dbname):
      """Session-wide SQLDatabase connected to a freshly recreated test database.

      The database is dropped and recreated through a root-level SQLDatabase.
      The ``test_model`` table is then (re)created through a short-lived
      SQLAlchemy async engine. Both helpers are disposed even if setup fails.
      The yielded SQLDatabase is disposed at session teardown so its
      connection pool is released.
      """
      root_provider = SQLDatabase(f"{postgres_url}/")
      try:
          await root_provider.drop_database(dbname)
          await root_provider.create_database(dbname)
      finally:
          await root_provider.dispose()

      engine = create_async_engine(f"postgresql+asyncpg://{postgres_url}/{dbname}")
      try:
          async with engine.begin() as conn:
              await conn.run_sync(test_model.metadata.drop_all)
              await conn.run_sync(test_model.metadata.create_all)
      finally:
          await engine.dispose()

      # pool_size=2 for the Conflict test: test_handle_serialization_error runs
      # two transactions concurrently, each needing its own connection.
      db = SQLDatabase(f"{postgres_url}/{dbname}", pool_size=2)
      yield db
      # Teardown: release the pooled connections held by the session database.
      await db.dispose()
  @pytest.fixture
  async def database_with_cleanup(database: SQLDatabase):
      """The session database, with ``test_model`` truncated before and after use."""
      await database.truncate_tables(["test_model"])
      yield database
      await database.truncate_tables(["test_model"])


  @pytest.fixture
  async def transaction_with_cleanup(database_with_cleanup):
      """An open transaction on the cleaned-up database, held for the test's duration."""
      async with database_with_cleanup.transaction() as txn:
          yield txn


  @pytest.fixture
  async def record_id(database_with_cleanup: SQLDatabase) -> int:
      """Insert one row with ``insert_query`` and return its generated id."""
      rows = await database_with_cleanup.execute(insert_query)
      first_row = rows[0]
      return first_row["id"]
  async def test_execute(database_with_cleanup):
      """A plain execute() outside any transaction is persisted."""
      await database_with_cleanup.execute(insert_query)
      counted = await database_with_cleanup.execute(count_query)
      assert counted == [{"count": 1}]


  async def test_transaction_commits(database_with_cleanup):
      """Leaving a transaction block normally commits its statements."""
      async with database_with_cleanup.transaction() as txn:
          await txn.execute(insert_query)
      counted = await database_with_cleanup.execute(count_query)
      assert counted == [{"count": 1}]


  async def test_transaction_err(database_with_cleanup):
      """An exception inside a transaction block rolls its statements back."""
      db = database_with_cleanup
      await db.execute(insert_query)
      with pytest.raises(RuntimeError):
          async with db.transaction() as txn:
              await txn.execute(insert_query)
              raise RuntimeError()  # triggers rollback
      # Only the insert made outside the failed transaction survives.
      counted = await db.execute(count_query)
      assert counted == [{"count": 1}]
  async def test_nested_transaction_commits(transaction_with_cleanup):
      """Work from a nested transaction is visible to the outer one once it exits."""
      outer = transaction_with_cleanup
      async with outer.transaction() as inner:
          await inner.execute(insert_query)
      assert await outer.execute(count_query) == [{"count": 1}]


  async def test_nested_transaction_err(transaction_with_cleanup):
      """An exception in a nested transaction rolls back only the nested block."""
      outer = transaction_with_cleanup
      await outer.execute(insert_query)
      with pytest.raises(RuntimeError):
          async with outer.transaction() as inner:
              await inner.execute(insert_query)
              raise RuntimeError()  # triggers rollback of the nested block
      # The outer transaction's own insert is still there.
      assert await outer.execute(count_query) == [{"count": 1}]
  async def test_testing_transaction_rollback(database_with_cleanup):
      """Statements run inside testing_transaction() are discarded afterwards."""
      db = database_with_cleanup
      async with db.testing_transaction() as txn:
          await txn.execute(insert_query)
      remaining = await db.execute(count_query)
      assert remaining == [{"count": 0}]
  async def test_handle_serialization_error(
      database_with_cleanup: SQLDatabase, record_id: int
  ):
      """A typical 'lost update' race surfaces as a Conflict error.

      Timeline (both transactions update the same row):

          1> BEGIN
          1> UPDATE ... WHERE id=1
          2> BEGIN
          2> UPDATE ... WHERE id=1   # transaction 2 waits until transaction 1 finishes
          1> COMMIT
          2> raises a serialization error, reported as Conflict
      """

      async def run_update(sleep_before=0.0, sleep_after=0.0):
          # sleep_before delays starting the transaction; sleep_after keeps
          # it open (holding the row) after the UPDATE has run.
          await asyncio.sleep(sleep_before)
          async with database_with_cleanup.transaction() as txn:
              rows = await txn.execute(update_query, bind_params={"id": record_id})
              await asyncio.sleep(sleep_after)
              return rows

      first, second = await asyncio.gather(
          run_update(sleep_after=0.1),
          run_update(sleep_before=0.05),
          return_exceptions=True,
      )
      assert first == [{"t": "bar"}]
      assert isinstance(second, Conflict)
      assert str(second) == "could not execute query due to concurrent update"
  async def test_handle_integrity_error(
      database_with_cleanup: SQLDatabase, record_id: int
  ):
      """Inserting a row whose id already exists raises AlreadyExists."""
      insert_query_with_id = text(
          "INSERT INTO test_model (id, t, f, b, updated_at) "
          "VALUES (:id, 'foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
          "RETURNING id"
      )
      expected_message = f"record with id={record_id} already exists"
      with pytest.raises(AlreadyExists, match=expected_message):
          await database_with_cleanup.execute(
              insert_query_with_id, bind_params={"id": record_id}
          )
  145. ### SQLGateway integration tests
  class TstSQLGateway(SQLGateway, table=test_model):
      """Concrete SQLGateway bound to the ``test_model`` table.

      Presumably prefixed ``Tst`` rather than ``Test`` so that pytest's default
      ``Test*`` class collection does not pick it up.
      """
  @pytest.fixture
  async def test_transaction(database):
      """A testing transaction on the session database whose changes are discarded."""
      async with database.testing_transaction() as txn:
          yield txn


  @pytest.fixture
  def sql_gateway(test_transaction):
      """A TstSQLGateway running all of its queries inside ``test_transaction``."""
      gateway = TstSQLGateway(test_transaction)
      return gateway
  @pytest.fixture
  def obj():
      """Column values (without id) matching the row that ``obj_in_db`` inserts."""
      return dict(
          t="foo",
          f=1.23,
          b=True,
          # '2016-06-22 19:10:25-07' expressed in UTC
          updated_at=datetime(2016, 6, 23, 2, 10, 25, tzinfo=timezone.utc),
          n=None,
      )


  @pytest.fixture
  async def obj_in_db(test_transaction, obj):
      """Insert the ``obj`` row inside ``test_transaction``; return it with its id.

      The SQL literals below must be kept in sync with the ``obj`` fixture.
      """
      rows = await test_transaction.execute(
          text(
              "INSERT INTO test_model (t, f, b, updated_at) "
              "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
              "RETURNING id"
          )
      )
      new_id = rows[0]["id"]
      return {"id": new_id, **obj}
  async def test_get(sql_gateway, obj_in_db):
      """get() returns a new dict equal to the stored row."""
      found = await sql_gateway.get(obj_in_db["id"])
      assert isinstance(found, dict)
      assert found == obj_in_db
      assert found is not obj_in_db


  async def test_get_not_found(sql_gateway, obj_in_db):
      """get() returns None for an id that has no row."""
      missing_id = obj_in_db["id"] + 1
      assert await sql_gateway.get(missing_id) is None
  async def test_add(sql_gateway, test_transaction, obj):
      """add() inserts the row and returns a new dict that includes the generated id."""
      created = await sql_gateway.add(obj)
      new_id = created.pop("id")
      assert isinstance(new_id, int)
      assert created is not obj
      assert created == obj
      # Read the row back through the same transaction. The id is bound as a
      # parameter instead of being formatted into the SQL string.
      res = await test_transaction.execute(
          text("SELECT * FROM test_model WHERE id = :id"),
          bind_params={"id": new_id},
      )
      assert res[0]["t"] == obj["t"]
  async def test_add_id_exists(sql_gateway, obj_in_db):
      """add() with the id of an existing row raises AlreadyExists."""
      existing_id = obj_in_db["id"]
      with pytest.raises(
          AlreadyExists, match=f"record with id={existing_id} already exists"
      ):
          await sql_gateway.add(obj_in_db)


  @pytest.mark.parametrize("id_value", [10, None, "delete"])
  async def test_add_integrity_error(sql_gateway, obj, id_value):
      """A NOT NULL violation surfaces as asyncpg's NotNullViolationError.

      Dropping the non-nullable ``t`` column triggers it. The parametrization
      covers an explicit id, an explicit ``None`` id, and no "id" key at all
      (the ``"delete"`` marker).
      """
      del obj["t"]  # will cause the IntegrityError
      if id_value != "delete":
          obj["id"] = id_value
      with pytest.raises(NotNullViolationError):
          await sql_gateway.add(obj)


  async def test_add_unkown_column(sql_gateway, obj):
      """Keys that are not table columns are ignored by add()."""
      payload = {"unknown": "foo", **obj}
      created = await sql_gateway.add(payload)
      del created["id"]
      assert created == obj
  async def test_update(sql_gateway, test_transaction, obj_in_db):
      """update() persists the change and returns a new dict equal to the input."""
      obj_in_db["t"] = "bar"
      updated = await sql_gateway.update(obj_in_db)
      assert updated is not obj_in_db
      assert updated == obj_in_db
      # Bind the id instead of formatting it into the SQL string.
      res = await test_transaction.execute(
          text("SELECT * FROM test_model WHERE id = :id"),
          bind_params={"id": obj_in_db["id"]},
      )
      assert res[0]["t"] == "bar"


  async def test_update_not_found(sql_gateway, obj):
      """update() of an id that has no row raises DoesNotExist."""
      obj["id"] = 42
      with pytest.raises(DoesNotExist):
          await sql_gateway.update(obj)


  async def test_update_unkown_column(sql_gateway, obj_in_db):
      """Keys that are not table columns are ignored by update()."""
      obj_in_db["t"] = "bar"
      updated = await sql_gateway.update({"unknown": "foo", **obj_in_db})
      assert updated == obj_in_db
  async def test_upsert_does_add(sql_gateway, test_transaction, obj):
      """upsert() with an id that has no row inserts it."""
      obj["id"] = 42
      created = await sql_gateway.upsert(obj)
      assert created is not obj
      assert created == obj
      res = await test_transaction.execute(
          text("SELECT * FROM test_model WHERE id = :id"),
          bind_params={"id": obj["id"]},
      )
      assert res[0]["t"] == obj["t"]


  async def test_upsert_does_update(sql_gateway, test_transaction, obj_in_db):
      """upsert() with the id of an existing row updates it."""
      obj_in_db["t"] = "bar"
      updated = await sql_gateway.upsert(obj_in_db)
      assert updated is not obj_in_db
      assert updated == obj_in_db
      # Bind the id instead of formatting it into the SQL string.
      res = await test_transaction.execute(
          text("SELECT * FROM test_model WHERE id = :id"),
          bind_params={"id": obj_in_db["id"]},
      )
      assert res[0]["t"] == "bar"


  async def test_upsert_no_id(sql_gateway, obj):
      """upsert() without an id delegates to add() and returns add()'s result."""
      # test_transaction is still set up, because sql_gateway depends on it.
      with mock.patch.object(sql_gateway, "add", new_callable=mock.AsyncMock) as add_m:
          created = await sql_gateway.upsert(obj)
          add_m.assert_awaited_with(obj)
          assert created == add_m.return_value
  async def test_remove(sql_gateway, test_transaction, obj_in_db):
      """remove() deletes an existing row and returns a truthy result."""
      assert await sql_gateway.remove(obj_in_db["id"])
      # Bind the id instead of formatting it into the SQL string.
      res = await test_transaction.execute(
          text("SELECT COUNT(*) FROM test_model WHERE id = :id"),
          bind_params={"id": obj_in_db["id"]},
      )
      assert res[0]["count"] == 0


  async def test_remove_not_found(sql_gateway):
      """remove() of an id that has no row returns a falsy result."""
      assert not await sql_gateway.remove(42)
  async def test_update_if_unmodified_since(sql_gateway, obj_in_db):
      """update() succeeds when if_unmodified_since equals the stored updated_at."""
      obj_in_db["t"] = "bar"
      stored_at = obj_in_db["updated_at"]
      updated = await sql_gateway.update(obj_in_db, if_unmodified_since=stored_at)
      assert updated == obj_in_db


  @pytest.mark.parametrize(
      "if_unmodified_since",
      [
          datetime.now(timezone.utc),
          # NOTE(review): this value is timezone-naive, while updated_at is a
          # timezone-aware column; confirm that covering naive input is intended.
          datetime(2010, 1, 1),
      ],
  )
  async def test_update_if_unmodified_since_not_ok(
      sql_gateway, obj_in_db, if_unmodified_since
  ):
      """update() raises Conflict when if_unmodified_since does not match updated_at."""
      obj_in_db["t"] = "bar"
      with pytest.raises(Conflict):
          await sql_gateway.update(obj_in_db, if_unmodified_since=if_unmodified_since)
  @pytest.mark.parametrize(
      "filters,match",
      [
          # no filters at all
          ([], True),
          # single filter, matching / not matching
          ([Filter(field="t", values=["foo"])], True),
          ([Filter(field="t", values=["bar"])], False),
          # multiple filters must all match
          ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.23])], True),
          ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.24])], False),
          # unknown column
          ([Filter(field="nonexisting", values=["foo"])], False),
          # empty value list
          ([Filter(field="t", values=[])], False),
          # any of several values
          ([Filter(field="t", values=["foo", "bar"])], True),
      ],
  )
  async def test_filter(filters, match, sql_gateway, obj_in_db):
      """filter() returns the single stored row exactly when every filter matches it."""
      expected = [obj_in_db] if match else []
      assert await sql_gateway.filter(filters) == expected
  @pytest.fixture
  async def obj2_in_db(test_transaction, obj):
      """Insert a second row (t='bar') inside ``test_transaction`` and return it.

      The returned dict reflects the values actually inserted. Previously it
      echoed ``obj``, which has t='foo', f=1.23 and a 2016 timestamp. Only
      ``b`` and ``n`` are shared with ``obj``. ``n`` is not inserted, so it is NULL.
      """
      res = await test_transaction.execute(
          text(
              "INSERT INTO test_model (t, f, b, updated_at) "
              "VALUES ('bar', 1.24, TRUE, '2018-06-22 19:10:25-07') "
              "RETURNING id"
          )
      )
      return {
          "id": res[0]["id"],
          **obj,
          "t": "bar",
          "f": 1.24,
          # '2018-06-22 19:10:25-07' expressed in UTC
          "updated_at": datetime(2018, 6, 23, 2, 10, 25, tzinfo=timezone.utc),
      }
  @pytest.mark.parametrize(
      "filters,expected",
      [
          ([], 2),  # both rows
          ([Filter(field="t", values=["foo"])], 1),  # obj_in_db only
          ([Filter(field="t", values=["bar"])], 1),  # obj2_in_db only
          ([Filter(field="t", values=["baz"])], 0),  # neither
      ],
  )
  async def test_count(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
      """count() returns the number of rows matching all filters."""
      assert await sql_gateway.count(filters) == expected


  @pytest.mark.parametrize(
      "filters,expected",
      [
          ([], True),
          ([Filter(field="t", values=["foo"])], True),
          ([Filter(field="t", values=["bar"])], True),
          ([Filter(field="t", values=["baz"])], False),
      ],
  )
  async def test_exists(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
      """exists() reports whether any row matches all filters."""
      assert await sql_gateway.exists(filters) == expected
  async def test_truncate(database: SQLDatabase, obj):
      """truncate_tables() empties test_model.

      Uses the session database directly rather than a testing transaction, so
      the row added here goes straight to the database before being truncated.
      """
      gw = TstSQLGateway(database)
      await gw.add(obj)
      assert await gw.exists([])
      await database.truncate_tables(["test_model"])
      assert not await gw.exists([])