test_sql_database.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. import asyncio
  4. from datetime import datetime
  5. from datetime import timezone
  6. from unittest import mock
  7. import pytest
  8. from sqlalchemy import Boolean
  9. from sqlalchemy import Column
  10. from sqlalchemy import DateTime
  11. from sqlalchemy import Float
  12. from sqlalchemy import Integer
  13. from sqlalchemy import MetaData
  14. from sqlalchemy import Table
  15. from sqlalchemy import Text
  16. from sqlalchemy.exc import IntegrityError
  17. from sqlalchemy.pool import NullPool
  18. from sqlalchemy.sql import text
  19. from clean_python import AlreadyExists
  20. from clean_python import Conflict
  21. from clean_python import DoesNotExist
  22. from clean_python import Filter
  23. from clean_python.sql import SQLDatabase
  24. from clean_python.sql import SQLGateway
  25. test_model = Table(
  26. "test_model",
  27. MetaData(),
  28. Column("id", Integer, primary_key=True, autoincrement=True),
  29. Column("t", Text, nullable=False),
  30. Column("f", Float, nullable=False),
  31. Column("b", Boolean, nullable=False),
  32. Column("updated_at", DateTime(timezone=True), nullable=False),
  33. Column("n", Float, nullable=True),
  34. )
  35. ### SQLProvider integration tests
  36. count_query = text("SELECT COUNT(*) FROM test_model")
  37. insert_query = text(
  38. "INSERT INTO test_model (t, f, b, updated_at) "
  39. "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
  40. "RETURNING id"
  41. )
  42. update_query = text("UPDATE test_model SET t='bar' WHERE id=:id RETURNING t")
  43. @pytest.fixture(scope="session")
  44. def dburl(postgres_url):
  45. return f"postgresql+asyncpg://{postgres_url}"
  46. @pytest.fixture(scope="session")
  47. def dbname():
  48. return "cleanpython_test"
  49. @pytest.fixture(scope="session")
  50. async def database(dburl, dbname):
  51. root_provider = SQLDatabase(f"{dburl}/")
  52. await root_provider.drop_database(dbname)
  53. await root_provider.create_database(dbname)
  54. provider = SQLDatabase(f"{dburl}/{dbname}")
  55. async with provider.engine.begin() as conn:
  56. await conn.run_sync(test_model.metadata.drop_all)
  57. await conn.run_sync(test_model.metadata.create_all)
  58. yield SQLDatabase(f"{dburl}/{dbname}")
  59. @pytest.fixture
  60. async def database_with_cleanup(database: SQLDatabase):
  61. await database.truncate_tables(["test_model"])
  62. yield database
  63. await database.truncate_tables(["test_model"])
  64. @pytest.fixture
  65. async def transaction_with_cleanup(database_with_cleanup):
  66. async with database_with_cleanup.transaction() as trans:
  67. yield trans
  68. @pytest.fixture
  69. async def record_id(database_with_cleanup: SQLDatabase) -> int:
  70. record = await database_with_cleanup.execute(insert_query)
  71. return record[0]["id"]
  72. async def test_execute(database_with_cleanup):
  73. db = database_with_cleanup
  74. await db.execute(insert_query)
  75. assert await db.execute(count_query) == [{"count": 1}]
  76. async def test_transaction_commits(database_with_cleanup):
  77. db = database_with_cleanup
  78. async with db.transaction() as trans:
  79. await trans.execute(insert_query)
  80. assert await db.execute(count_query) == [{"count": 1}]
  81. async def test_transaction_err(database_with_cleanup):
  82. db = database_with_cleanup
  83. await db.execute(insert_query)
  84. with pytest.raises(RuntimeError):
  85. async with db.transaction() as trans:
  86. await trans.execute(insert_query)
  87. raise RuntimeError() # triggers rollback
  88. assert await db.execute(count_query) == [{"count": 1}]
  89. async def test_nested_transaction_commits(transaction_with_cleanup):
  90. db = transaction_with_cleanup
  91. async with db.transaction() as trans:
  92. await trans.execute(insert_query)
  93. assert await db.execute(count_query) == [{"count": 1}]
  94. async def test_nested_transaction_err(transaction_with_cleanup):
  95. db = transaction_with_cleanup
  96. await db.execute(insert_query)
  97. with pytest.raises(RuntimeError):
  98. async with db.transaction() as trans:
  99. await trans.execute(insert_query)
  100. raise RuntimeError() # triggers rollback
  101. assert await db.execute(count_query) == [{"count": 1}]
  102. async def test_testing_transaction_rollback(database_with_cleanup):
  103. async with database_with_cleanup.testing_transaction() as trans:
  104. await trans.execute(insert_query)
  105. assert await database_with_cleanup.execute(count_query) == [{"count": 0}]
  106. @pytest.fixture
  107. async def database_no_pool_no_cache(dburl, dbname):
  108. db = SQLDatabase(
  109. f"{dburl}/{dbname}?prepared_statement_cache_size=0", poolclass=NullPool
  110. )
  111. yield db
  112. await db.truncate_tables(["test_model"])
  113. async def test_handle_serialization_error(
  114. database_no_pool_no_cache: SQLDatabase, record_id: int
  115. ):
  116. """Typical 'lost update' situation will result in a Conflict error
  117. 1> BEGIN
  118. 1> UPDATE ... WHERE id=1
  119. 2> BEGIN
  120. 2> UPDATE ... WHERE id=1 # transaction 2 will wait until transaction 1 finishes
  121. 1> COMMIT
  122. 2> will raise SerializationError
  123. """
  124. async def update(sleep_before=0.0, sleep_after=0.0):
  125. await asyncio.sleep(sleep_before)
  126. async with database_no_pool_no_cache.transaction() as trans:
  127. res = await trans.execute(update_query, bind_params={"id": record_id})
  128. await asyncio.sleep(sleep_after)
  129. return res
  130. res1, res2 = await asyncio.gather(
  131. update(sleep_after=0.02), update(sleep_before=0.01), return_exceptions=True
  132. )
  133. assert res1 == [{"t": "bar"}]
  134. assert isinstance(res2, Conflict)
  135. assert str(res2) == "could not execute query due to concurrent update"
  136. async def test_handle_integrity_error(
  137. database_with_cleanup: SQLDatabase, record_id: int
  138. ):
  139. """Insert a record with an id that already exists"""
  140. insert_query_with_id = text(
  141. "INSERT INTO test_model (id, t, f, b, updated_at) "
  142. "VALUES (:id, 'foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
  143. "RETURNING id"
  144. )
  145. with pytest.raises(
  146. AlreadyExists, match=f"record with id={record_id} already exists"
  147. ):
  148. await database_with_cleanup.execute(
  149. insert_query_with_id, bind_params={"id": record_id}
  150. )
  151. ### SQLGateway integration tests
  152. class TstSQLGateway(SQLGateway, table=test_model):
  153. pass
  154. @pytest.fixture
  155. async def test_transaction(database):
  156. async with database.testing_transaction() as test_transaction:
  157. yield test_transaction
  158. @pytest.fixture
  159. def sql_gateway(test_transaction):
  160. return TstSQLGateway(test_transaction)
  161. @pytest.fixture
  162. def obj():
  163. return {
  164. "t": "foo",
  165. "f": 1.23,
  166. "b": True,
  167. "updated_at": datetime(2016, 6, 23, 2, 10, 25, tzinfo=timezone.utc),
  168. "n": None,
  169. }
  170. @pytest.fixture
  171. async def obj_in_db(test_transaction, obj):
  172. res = await test_transaction.execute(
  173. text(
  174. "INSERT INTO test_model (t, f, b, updated_at) "
  175. "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
  176. "RETURNING id"
  177. )
  178. )
  179. return {"id": res[0]["id"], **obj}
  180. async def test_get(sql_gateway, obj_in_db):
  181. actual = await sql_gateway.get(obj_in_db["id"])
  182. assert isinstance(actual, dict)
  183. assert actual == obj_in_db
  184. assert actual is not obj_in_db
  185. async def test_get_not_found(sql_gateway, obj_in_db):
  186. assert await sql_gateway.get(obj_in_db["id"] + 1) is None
  187. async def test_add(sql_gateway, test_transaction, obj):
  188. created = await sql_gateway.add(obj)
  189. id = created.pop("id")
  190. assert isinstance(id, int)
  191. assert created is not obj
  192. assert created == obj
  193. res = await test_transaction.execute(
  194. text(f"SELECT * FROM test_model WHERE id = {id}")
  195. )
  196. assert res[0]["t"] == obj["t"]
  197. async def test_add_id_exists(sql_gateway, obj_in_db):
  198. with pytest.raises(
  199. AlreadyExists, match=f"record with id={obj_in_db['id']} already exists"
  200. ):
  201. await sql_gateway.add(obj_in_db)
  202. @pytest.mark.parametrize("id", [10, None, "delete"])
  203. async def test_add_integrity_error(sql_gateway, obj, id):
  204. obj.pop("t") # will cause the IntegrityError
  205. if id != "delete":
  206. obj["id"] = id
  207. with pytest.raises(IntegrityError):
  208. await sql_gateway.add(obj)
  209. async def test_add_unkown_column(sql_gateway, obj):
  210. created = await sql_gateway.add({"unknown": "foo", **obj})
  211. created.pop("id")
  212. assert created == obj
  213. async def test_update(sql_gateway, test_transaction, obj_in_db):
  214. obj_in_db["t"] = "bar"
  215. updated = await sql_gateway.update(obj_in_db)
  216. assert updated is not obj_in_db
  217. assert updated == obj_in_db
  218. res = await test_transaction.execute(
  219. text(f"SELECT * FROM test_model WHERE id = {obj_in_db['id']}")
  220. )
  221. assert res[0]["t"] == "bar"
  222. async def test_update_not_found(sql_gateway, obj):
  223. obj["id"] = 42
  224. with pytest.raises(DoesNotExist):
  225. await sql_gateway.update(obj)
  226. async def test_update_unkown_column(sql_gateway, obj_in_db):
  227. obj_in_db["t"] = "bar"
  228. updated = await sql_gateway.update({"unknown": "foo", **obj_in_db})
  229. assert updated == obj_in_db
  230. async def test_upsert_does_add(sql_gateway, test_transaction, obj):
  231. obj["id"] = 42
  232. created = await sql_gateway.upsert(obj)
  233. assert created is not obj
  234. assert created == obj
  235. res = await test_transaction.execute(text("SELECT * FROM test_model WHERE id = 42"))
  236. assert res[0]["t"] == obj["t"]
  237. async def test_upsert_does_update(sql_gateway, test_transaction, obj_in_db):
  238. obj_in_db["t"] = "bar"
  239. updated = await sql_gateway.upsert(obj_in_db)
  240. assert updated is not obj_in_db
  241. assert updated == obj_in_db
  242. res = await test_transaction.execute(
  243. text(f"SELECT * FROM test_model WHERE id = {obj_in_db['id']}")
  244. )
  245. assert res[0]["t"] == "bar"
  246. async def test_upsert_no_id(sql_gateway, test_transaction, obj):
  247. with mock.patch.object(sql_gateway, "add", new_callable=mock.AsyncMock) as add_m:
  248. created = await sql_gateway.upsert(obj)
  249. add_m.assert_awaited_with(obj)
  250. assert created == add_m.return_value
  251. async def test_remove(sql_gateway, test_transaction, obj_in_db):
  252. assert await sql_gateway.remove(obj_in_db["id"])
  253. res = await test_transaction.execute(
  254. text(f"SELECT COUNT(*) FROM test_model WHERE id = {obj_in_db['id']}")
  255. )
  256. assert res[0]["count"] == 0
  257. async def test_remove_not_found(sql_gateway):
  258. assert not await sql_gateway.remove(42)
  259. async def test_update_if_unmodified_since(sql_gateway, obj_in_db):
  260. obj_in_db["t"] = "bar"
  261. updated = await sql_gateway.update(
  262. obj_in_db, if_unmodified_since=obj_in_db["updated_at"]
  263. )
  264. assert updated == obj_in_db
  265. @pytest.mark.parametrize(
  266. "if_unmodified_since", [datetime.now(timezone.utc), datetime(2010, 1, 1)]
  267. )
  268. async def test_update_if_unmodified_since_not_ok(
  269. sql_gateway, obj_in_db, if_unmodified_since
  270. ):
  271. obj_in_db["t"] = "bar"
  272. with pytest.raises(Conflict):
  273. await sql_gateway.update(obj_in_db, if_unmodified_since=if_unmodified_since)
  274. @pytest.mark.parametrize(
  275. "filters,match",
  276. [
  277. ([], True),
  278. ([Filter(field="t", values=["foo"])], True),
  279. ([Filter(field="t", values=["bar"])], False),
  280. ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.23])], True),
  281. ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.24])], False),
  282. ([Filter(field="nonexisting", values=["foo"])], False),
  283. ([Filter(field="t", values=[])], False),
  284. ([Filter(field="t", values=["foo", "bar"])], True),
  285. ],
  286. )
  287. async def test_filter(filters, match, sql_gateway, obj_in_db):
  288. actual = await sql_gateway.filter(filters)
  289. assert actual == ([obj_in_db] if match else [])
  290. @pytest.fixture
  291. async def obj2_in_db(test_transaction, obj):
  292. res = await test_transaction.execute(
  293. text(
  294. "INSERT INTO test_model (t, f, b, updated_at) "
  295. "VALUES ('bar', 1.24, TRUE, '2018-06-22 19:10:25-07') "
  296. "RETURNING id"
  297. )
  298. )
  299. return {"id": res[0]["id"], **obj}
  300. @pytest.mark.parametrize(
  301. "filters,expected",
  302. [
  303. ([], 2),
  304. ([Filter(field="t", values=["foo"])], 1),
  305. ([Filter(field="t", values=["bar"])], 1),
  306. ([Filter(field="t", values=["baz"])], 0),
  307. ],
  308. )
  309. async def test_count(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
  310. actual = await sql_gateway.count(filters)
  311. assert actual == expected
  312. @pytest.mark.parametrize(
  313. "filters,expected",
  314. [
  315. ([], True),
  316. ([Filter(field="t", values=["foo"])], True),
  317. ([Filter(field="t", values=["bar"])], True),
  318. ([Filter(field="t", values=["baz"])], False),
  319. ],
  320. )
  321. async def test_exists(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
  322. actual = await sql_gateway.exists(filters)
  323. assert actual == expected
  324. async def test_truncate(database: SQLDatabase, obj):
  325. gateway = TstSQLGateway(database)
  326. await gateway.add(obj)
  327. assert await gateway.exists([])
  328. await database.truncate_tables(["test_model"])
  329. assert not await gateway.exists([])