test_repository.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. from typing import List
  2. from unittest import mock
  3. import pytest
  4. from clean_python import BadRequest
  5. from clean_python import DoesNotExist
  6. from clean_python import Filter
  7. from clean_python import InMemoryGateway
  8. from clean_python import Page
  9. from clean_python import PageOptions
  10. from clean_python import Repository
  11. from clean_python import RootEntity
  12. class User(RootEntity):
  13. name: str
  14. @pytest.fixture
  15. def users():
  16. return [
  17. User.create(id=1, name="a"),
  18. User.create(id=2, name="b"),
  19. User.create(id=3, name="c"),
  20. ]
  21. class UserRepository(Repository[User]):
  22. pass
  23. @pytest.fixture
  24. def user_repository(users: List[User]):
  25. return UserRepository(gateway=InMemoryGateway(data=[x.model_dump() for x in users]))
  26. @pytest.fixture
  27. def page_options():
  28. return PageOptions(limit=10, offset=0, order_by="id")
  29. def test_entity_attr(user_repository):
  30. assert user_repository.entity is User
  31. async def test_get(user_repository):
  32. actual = await user_repository.get(1)
  33. assert actual.name == "a"
  34. async def test_get_does_not_exist(user_repository):
  35. with pytest.raises(DoesNotExist):
  36. await user_repository.get(4)
  37. @mock.patch.object(Repository, "filter")
  38. async def test_all(filter_m, user_repository, page_options):
  39. filter_m.return_value = Page(total=0, items=[])
  40. assert await user_repository.all(page_options) is filter_m.return_value
  41. filter_m.assert_awaited_once_with([], params=page_options)
  42. async def test_add(user_repository: UserRepository):
  43. actual = await user_repository.add(User.create(name="d"))
  44. assert actual.name == "d"
  45. assert user_repository.gateway.data[4] == actual.model_dump()
  46. async def test_add_json(user_repository: UserRepository):
  47. actual = await user_repository.add({"name": "d"})
  48. assert actual.name == "d"
  49. assert user_repository.gateway.data[4] == actual.model_dump()
  50. async def test_add_json_validates(user_repository: UserRepository):
  51. with pytest.raises(BadRequest):
  52. await user_repository.add({"id": "d"})
  53. async def test_update(user_repository: UserRepository):
  54. actual = await user_repository.update(id=2, values={"name": "d"})
  55. assert actual.name == "d"
  56. assert user_repository.gateway.data[2] == actual.model_dump()
  57. async def test_update_does_not_exist(user_repository: UserRepository):
  58. with pytest.raises(DoesNotExist):
  59. await user_repository.update(id=4, values={"name": "d"})
  60. async def test_update_validates(user_repository: UserRepository):
  61. with pytest.raises(BadRequest):
  62. await user_repository.update(id=2, values={"id": 6})
  63. async def test_remove(user_repository: UserRepository):
  64. assert await user_repository.remove(2)
  65. assert 2 not in user_repository.gateway.data
  66. async def test_remove_does_not_exist(user_repository: UserRepository):
  67. assert not await user_repository.remove(4)
  68. async def test_upsert_updates(user_repository: UserRepository):
  69. actual = await user_repository.upsert(User.create(id=2, name="d"))
  70. assert actual.name == "d"
  71. assert user_repository.gateway.data[2] == actual.model_dump()
  72. async def test_upsert_adds(user_repository: UserRepository):
  73. actual = await user_repository.upsert(User.create(id=4, name="d"))
  74. assert actual.name == "d"
  75. assert user_repository.gateway.data[4] == actual.model_dump()
  76. @mock.patch.object(InMemoryGateway, "count")
  77. async def test_filter(count_m, user_repository: UserRepository, users):
  78. actual = await user_repository.filter([Filter(field="name", values=["b"])])
  79. assert actual == Page(total=1, items=[users[1]], limit=None, offest=None)
  80. assert not count_m.called
  81. @mock.patch.object(InMemoryGateway, "count")
  82. async def test_filter_with_pagination(
  83. count_m, user_repository: UserRepository, users, page_options
  84. ):
  85. actual = await user_repository.filter(
  86. [Filter(field="name", values=["b"])], page_options
  87. )
  88. assert actual == Page(
  89. total=1, items=[users[1]], limit=page_options.limit, offset=page_options.offset
  90. )
  91. assert not count_m.called
  92. @pytest.mark.parametrize(
  93. "page_options",
  94. [
  95. PageOptions(limit=3, offset=0, order_by="id"),
  96. PageOptions(limit=10, offset=1, order_by="id"),
  97. ],
  98. )
  99. @mock.patch.object(InMemoryGateway, "count")
  100. async def test_filter_with_pagination_calls_count(
  101. count_m, user_repository: UserRepository, users, page_options
  102. ):
  103. count_m.return_value = 123
  104. actual = await user_repository.filter([], page_options)
  105. assert actual == Page(
  106. total=count_m.return_value,
  107. items=users[page_options.offset :],
  108. limit=page_options.limit,
  109. offset=page_options.offset,
  110. )
  111. assert count_m.called
  112. @mock.patch.object(Repository, "filter")
  113. async def test_by(filter_m, user_repository: UserRepository, page_options):
  114. filter_m.return_value = Page(total=0, items=[])
  115. assert await user_repository.by("name", "b", page_options) is filter_m.return_value
  116. filter_m.assert_awaited_once_with(
  117. [Filter(field="name", values=["b"])], params=page_options
  118. )
  119. @mock.patch.object(InMemoryGateway, "count")
  120. async def test_count(gateway_count, user_repository):
  121. assert await user_repository.count("foo") is gateway_count.return_value
  122. gateway_count.assert_awaited_once_with("foo")
  123. @mock.patch.object(InMemoryGateway, "exists")
  124. async def test_exists(gateway_exists, user_repository):
  125. assert await user_repository.exists("foo") is gateway_exists.return_value
  126. gateway_exists.assert_awaited_once_with("foo")