test_sql_database.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. from datetime import datetime
  4. from datetime import timezone
  5. from unittest import mock
  6. import pytest
  7. from sqlalchemy import Boolean
  8. from sqlalchemy import Column
  9. from sqlalchemy import DateTime
  10. from sqlalchemy import Float
  11. from sqlalchemy import Integer
  12. from sqlalchemy import MetaData
  13. from sqlalchemy import Table
  14. from sqlalchemy import Text
  15. from sqlalchemy.exc import IntegrityError
  16. from sqlalchemy.sql import text
  17. from clean_python import AlreadyExists
  18. from clean_python import Conflict
  19. from clean_python import DoesNotExist
  20. from clean_python import Filter
  21. from clean_python.sql import SQLDatabase
  22. from clean_python.sql import SQLGateway
  23. test_model = Table(
  24. "test_model",
  25. MetaData(),
  26. Column("id", Integer, primary_key=True, autoincrement=True),
  27. Column("t", Text, nullable=False),
  28. Column("f", Float, nullable=False),
  29. Column("b", Boolean, nullable=False),
  30. Column("updated_at", DateTime(timezone=True), nullable=False),
  31. Column("n", Float, nullable=True),
  32. )
  33. ### SQLProvider integration tests
  34. count_query = text("SELECT COUNT(*) FROM test_model")
  35. insert_query = text(
  36. "INSERT INTO test_model (t, f, b, updated_at) "
  37. "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
  38. "RETURNING id"
  39. )
  40. @pytest.fixture(scope="session")
  41. async def database(postgres_url):
  42. dburl = f"postgresql+asyncpg://{postgres_url}"
  43. dbname = "cleanpython_test"
  44. root_provider = SQLDatabase(f"{dburl}/")
  45. await root_provider.drop_database(dbname)
  46. await root_provider.create_database(dbname)
  47. provider = SQLDatabase(f"{dburl}/{dbname}")
  48. async with provider.engine.begin() as conn:
  49. await conn.run_sync(test_model.metadata.drop_all)
  50. await conn.run_sync(test_model.metadata.create_all)
  51. yield SQLDatabase(f"{dburl}/{dbname}")
  52. @pytest.fixture
  53. async def database_with_cleanup(database: SQLDatabase):
  54. await database.truncate_tables(["test_model"])
  55. yield database
  56. await database.truncate_tables(["test_model"])
  57. @pytest.fixture
  58. async def transaction_with_cleanup(database_with_cleanup):
  59. async with database_with_cleanup.transaction() as trans:
  60. yield trans
  61. async def test_execute(database_with_cleanup):
  62. db = database_with_cleanup
  63. await db.execute(insert_query)
  64. assert await db.execute(count_query) == [{"count": 1}]
  65. async def test_transaction_commits(database_with_cleanup):
  66. db = database_with_cleanup
  67. async with db.transaction() as trans:
  68. await trans.execute(insert_query)
  69. assert await db.execute(count_query) == [{"count": 1}]
  70. async def test_transaction_err(database_with_cleanup):
  71. db = database_with_cleanup
  72. await db.execute(insert_query)
  73. with pytest.raises(RuntimeError):
  74. async with db.transaction() as trans:
  75. await trans.execute(insert_query)
  76. raise RuntimeError() # triggers rollback
  77. assert await db.execute(count_query) == [{"count": 1}]
  78. async def test_nested_transaction_commits(transaction_with_cleanup):
  79. db = transaction_with_cleanup
  80. async with db.transaction() as trans:
  81. await trans.execute(insert_query)
  82. assert await db.execute(count_query) == [{"count": 1}]
  83. async def test_nested_transaction_err(transaction_with_cleanup):
  84. db = transaction_with_cleanup
  85. await db.execute(insert_query)
  86. with pytest.raises(RuntimeError):
  87. async with db.transaction() as trans:
  88. await trans.execute(insert_query)
  89. raise RuntimeError() # triggers rollback
  90. assert await db.execute(count_query) == [{"count": 1}]
  91. async def test_testing_transaction_rollback(database_with_cleanup):
  92. async with database_with_cleanup.testing_transaction() as trans:
  93. await trans.execute(insert_query)
  94. assert await database_with_cleanup.execute(count_query) == [{"count": 0}]
  95. ### SQLGateway integration tests
  96. class TstSQLGateway(SQLGateway, table=test_model):
  97. pass
  98. @pytest.fixture
  99. async def test_transaction(database):
  100. async with database.testing_transaction() as test_transaction:
  101. yield test_transaction
  102. @pytest.fixture
  103. def sql_gateway(test_transaction):
  104. return TstSQLGateway(test_transaction)
  105. @pytest.fixture
  106. def obj():
  107. return {
  108. "t": "foo",
  109. "f": 1.23,
  110. "b": True,
  111. "updated_at": datetime(2016, 6, 23, 2, 10, 25, tzinfo=timezone.utc),
  112. "n": None,
  113. }
  114. @pytest.fixture
  115. async def obj_in_db(test_transaction, obj):
  116. res = await test_transaction.execute(
  117. text(
  118. "INSERT INTO test_model (t, f, b, updated_at) "
  119. "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
  120. "RETURNING id"
  121. )
  122. )
  123. return {"id": res[0]["id"], **obj}
  124. async def test_get(sql_gateway, obj_in_db):
  125. actual = await sql_gateway.get(obj_in_db["id"])
  126. assert isinstance(actual, dict)
  127. assert actual == obj_in_db
  128. assert actual is not obj_in_db
  129. async def test_get_not_found(sql_gateway, obj_in_db):
  130. assert await sql_gateway.get(obj_in_db["id"] + 1) is None
  131. async def test_add(sql_gateway, test_transaction, obj):
  132. created = await sql_gateway.add(obj)
  133. id = created.pop("id")
  134. assert isinstance(id, int)
  135. assert created is not obj
  136. assert created == obj
  137. res = await test_transaction.execute(
  138. text(f"SELECT * FROM test_model WHERE id = {id}")
  139. )
  140. assert res[0]["t"] == obj["t"]
  141. async def test_add_id_exists(sql_gateway, obj_in_db):
  142. with pytest.raises(AlreadyExists):
  143. await sql_gateway.add(obj_in_db)
  144. @pytest.mark.parametrize("id", [10, None, "delete"])
  145. async def test_add_integrity_error(sql_gateway, obj, id):
  146. obj.pop("t") # will cause the IntegrityError
  147. if id != "delete":
  148. obj["id"] = id
  149. with pytest.raises(IntegrityError):
  150. await sql_gateway.add(obj)
  151. async def test_add_unkown_column(sql_gateway, obj):
  152. created = await sql_gateway.add({"unknown": "foo", **obj})
  153. created.pop("id")
  154. assert created == obj
  155. async def test_update(sql_gateway, test_transaction, obj_in_db):
  156. obj_in_db["t"] = "bar"
  157. updated = await sql_gateway.update(obj_in_db)
  158. assert updated is not obj_in_db
  159. assert updated == obj_in_db
  160. res = await test_transaction.execute(
  161. text(f"SELECT * FROM test_model WHERE id = {obj_in_db['id']}")
  162. )
  163. assert res[0]["t"] == "bar"
  164. async def test_update_not_found(sql_gateway, obj):
  165. obj["id"] = 42
  166. with pytest.raises(DoesNotExist):
  167. await sql_gateway.update(obj)
  168. async def test_update_unkown_column(sql_gateway, obj_in_db):
  169. obj_in_db["t"] = "bar"
  170. updated = await sql_gateway.update({"unknown": "foo", **obj_in_db})
  171. assert updated == obj_in_db
  172. async def test_upsert_does_add(sql_gateway, test_transaction, obj):
  173. obj["id"] = 42
  174. created = await sql_gateway.upsert(obj)
  175. assert created is not obj
  176. assert created == obj
  177. res = await test_transaction.execute(text("SELECT * FROM test_model WHERE id = 42"))
  178. assert res[0]["t"] == obj["t"]
  179. async def test_upsert_does_update(sql_gateway, test_transaction, obj_in_db):
  180. obj_in_db["t"] = "bar"
  181. updated = await sql_gateway.upsert(obj_in_db)
  182. assert updated is not obj_in_db
  183. assert updated == obj_in_db
  184. res = await test_transaction.execute(
  185. text(f"SELECT * FROM test_model WHERE id = {obj_in_db['id']}")
  186. )
  187. assert res[0]["t"] == "bar"
  188. async def test_upsert_no_id(sql_gateway, test_transaction, obj):
  189. with mock.patch.object(sql_gateway, "add", new_callable=mock.AsyncMock) as add_m:
  190. created = await sql_gateway.upsert(obj)
  191. add_m.assert_awaited_with(obj)
  192. assert created == add_m.return_value
  193. async def test_remove(sql_gateway, test_transaction, obj_in_db):
  194. assert await sql_gateway.remove(obj_in_db["id"])
  195. res = await test_transaction.execute(
  196. text(f"SELECT COUNT(*) FROM test_model WHERE id = {obj_in_db['id']}")
  197. )
  198. assert res[0]["count"] == 0
  199. async def test_remove_not_found(sql_gateway):
  200. assert not await sql_gateway.remove(42)
  201. async def test_update_if_unmodified_since(sql_gateway, obj_in_db):
  202. obj_in_db["t"] = "bar"
  203. updated = await sql_gateway.update(
  204. obj_in_db, if_unmodified_since=obj_in_db["updated_at"]
  205. )
  206. assert updated == obj_in_db
  207. @pytest.mark.parametrize(
  208. "if_unmodified_since", [datetime.now(timezone.utc), datetime(2010, 1, 1)]
  209. )
  210. async def test_update_if_unmodified_since_not_ok(
  211. sql_gateway, obj_in_db, if_unmodified_since
  212. ):
  213. obj_in_db["t"] = "bar"
  214. with pytest.raises(Conflict):
  215. await sql_gateway.update(obj_in_db, if_unmodified_since=if_unmodified_since)
  216. @pytest.mark.parametrize(
  217. "filters,match",
  218. [
  219. ([], True),
  220. ([Filter(field="t", values=["foo"])], True),
  221. ([Filter(field="t", values=["bar"])], False),
  222. ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.23])], True),
  223. ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.24])], False),
  224. ([Filter(field="nonexisting", values=["foo"])], False),
  225. ([Filter(field="t", values=[])], False),
  226. ([Filter(field="t", values=["foo", "bar"])], True),
  227. ],
  228. )
  229. async def test_filter(filters, match, sql_gateway, obj_in_db):
  230. actual = await sql_gateway.filter(filters)
  231. assert actual == ([obj_in_db] if match else [])
  232. @pytest.fixture
  233. async def obj2_in_db(test_transaction, obj):
  234. res = await test_transaction.execute(
  235. text(
  236. "INSERT INTO test_model (t, f, b, updated_at) "
  237. "VALUES ('bar', 1.24, TRUE, '2018-06-22 19:10:25-07') "
  238. "RETURNING id"
  239. )
  240. )
  241. return {"id": res[0]["id"], **obj}
  242. @pytest.mark.parametrize(
  243. "filters,expected",
  244. [
  245. ([], 2),
  246. ([Filter(field="t", values=["foo"])], 1),
  247. ([Filter(field="t", values=["bar"])], 1),
  248. ([Filter(field="t", values=["baz"])], 0),
  249. ],
  250. )
  251. async def test_count(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
  252. actual = await sql_gateway.count(filters)
  253. assert actual == expected
  254. @pytest.mark.parametrize(
  255. "filters,expected",
  256. [
  257. ([], True),
  258. ([Filter(field="t", values=["foo"])], True),
  259. ([Filter(field="t", values=["bar"])], True),
  260. ([Filter(field="t", values=["baz"])], False),
  261. ],
  262. )
  263. async def test_exists(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
  264. actual = await sql_gateway.exists(filters)
  265. assert actual == expected
  266. async def test_truncate(database: SQLDatabase, obj):
  267. gateway = TstSQLGateway(database)
  268. await gateway.add(obj)
  269. assert await gateway.exists([])
  270. await database.truncate_tables(["test_model"])
  271. assert not await gateway.exists([])