in_memory_gateway.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # (c) Nelen & Schuurmans
  2. from copy import deepcopy
  3. from datetime import datetime
  4. from typing import List
  5. from typing import Optional
  6. from clean_python.base.domain import AlreadyExists
  7. from clean_python.base.domain import Conflict
  8. from clean_python.base.domain import DoesNotExist
  9. from clean_python.base.domain import Filter
  10. from clean_python.base.domain import Gateway
  11. from clean_python.base.domain import Id
  12. from clean_python.base.domain import Json
  13. from clean_python.base.domain import PageOptions
  14. from clean_python.base.domain import SyncGateway
  # Names exported by ``from <this module> import *``.
  __all__ = [
      "InMemoryGateway",
      "InMemorySyncGateway",
  ]
  class InMemoryGateway(Gateway):
      """Dict-backed asynchronous gateway, for testing purposes.

      Records live in ``self.data``, a dict that maps each record's ``"id"``
      to the record itself. The methods deep-copy records on the way in and
      on the way out, so they never keep or hand out references to objects
      owned by the caller.
      """

      def __init__(self, data: List[Json]):
          """Seed the gateway with deep copies of ``data``, keyed by ``"id"``.

          Raises:
              KeyError: if a record in ``data`` has no ``"id"`` key.
          """
          self.data = {record["id"]: deepcopy(record) for record in data}

      def _get_next_id(self) -> int:
          """Return the next autoincrement id: 1 when empty, else max id + 1.

          NOTE(review): ``max(...) + 1`` presumes the stored ids are mutually
          comparable and support ``+ 1`` (e.g. ints) -- confirm against callers
          that rely on autoincrement.
          """
          if not self.data:
              return 1
          return max(self.data) + 1

      def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:
          """Sort ``objs`` by ``params.order_by`` and return the requested page.

          Records whose sort field is absent or ``None`` are ordered after all
          other records when ascending; because the whole ordering is reversed
          for descending sorts, they come first in that case.
          """
          order_by = params.order_by

          def sort_key(obj: Json):
              value = obj.get(order_by)
              # The leading flag separates None from real values so the two
              # are never compared with each other.
              return (value is None, value)

          ordered = sorted(objs, key=sort_key, reverse=not params.ascending)
          start = params.offset
          return ordered[start : start + params.limit]

      async def filter(
          self, filters: List[Filter], params: Optional[PageOptions] = None
      ) -> List[Json]:
          """Return deep copies of the records that satisfy every filter.

          A record satisfies a filter when its value for the filter's ``field``
          (``None`` when the key is absent) is contained in the filter's
          ``values``. With an empty ``filters`` list every record matches.
          When ``params`` is given, the matches are sorted and sliced by
          ``_paginate``.
          """
          matches = [
              deepcopy(record)
              for record in self.data.values()
              if all(
                  record.get(criterion.field) in criterion.values
                  for criterion in filters
              )
          ]
          if params is not None:
              matches = self._paginate(matches, params)
          return matches

      async def add(self, item: Json) -> Json:
          """Store a new record and return a deep copy of what was stored.

          When ``item`` has no ``"id"`` (or it is ``None``), an id is generated
          with ``_get_next_id``. The caller's ``item`` is not modified.

          Raises:
              AlreadyExists: if the supplied id is already present.
          """
          # Deep copy rather than ``item.copy()``: a shallow copy would leave
          # nested lists/dicts shared between the caller and the stored record.
          record = deepcopy(item)
          id_ = record.pop("id", None)
          # autoincrement (like SQL does)
          if id_ is None:
              id_ = self._get_next_id()
          elif id_ in self.data:
              raise AlreadyExists(id_)
          self.data[id_] = {"id": id_, **record}
          return deepcopy(self.data[id_])

      async def update(
          self, item: Json, if_unmodified_since: Optional[datetime] = None
      ) -> Json:
          """Merge ``item`` into the stored record that has the same ``"id"``.

          Keys present in ``item`` overwrite the stored values; other stored
          keys are left as they are. A deep copy of the merged record is
          returned.

          Raises:
              DoesNotExist: if ``item`` has no ``"id"`` or the id is unknown.
              Conflict: if ``if_unmodified_since`` is given and differs from
                  the stored record's ``"updated_at"`` value.
          """
          id_ = item.get("id")
          if id_ is None or id_ not in self.data:
              raise DoesNotExist("item", id_)
          existing = self.data[id_]
          if (
              if_unmodified_since is not None
              and existing.get("updated_at") != if_unmodified_since
          ):
              raise Conflict()
          # Copy the incoming values so later mutation of the caller's item
          # cannot change the stored record.
          existing.update(deepcopy(item))
          return deepcopy(existing)

      async def remove(self, id: Id) -> bool:
          """Delete the record with the given id.

          Returns:
              ``True`` if a record was deleted, ``False`` if the id was unknown.
          """
          if id not in self.data:
              return False
          del self.data[id]
          return True
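  # Synchronous copy of InMemoryGateway: same logic, but with plain (non-async)
  # methods. Keep the two implementations in step when changing either one.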
  class InMemorySyncGateway(SyncGateway):
      """Dict-backed synchronous gateway, for testing purposes.

      Records live in ``self.data``, a dict that maps each record's ``"id"``
      to the record itself. The methods deep-copy records on the way in and
      on the way out, so they never keep or hand out references to objects
      owned by the caller.
      """

      def __init__(self, data: List[Json]):
          """Seed the gateway with deep copies of ``data``, keyed by ``"id"``.

          Raises:
              KeyError: if a record in ``data`` has no ``"id"`` key.
          """
          self.data = {record["id"]: deepcopy(record) for record in data}

      def _get_next_id(self) -> int:
          """Return the next autoincrement id: 1 when empty, else max id + 1.

          NOTE(review): ``max(...) + 1`` presumes the stored ids are mutually
          comparable and support ``+ 1`` (e.g. ints) -- confirm against callers
          that rely on autoincrement.
          """
          if not self.data:
              return 1
          return max(self.data) + 1

      def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:
          """Sort ``objs`` by ``params.order_by`` and return the requested page.

          Records whose sort field is absent or ``None`` are ordered after all
          other records when ascending; because the whole ordering is reversed
          for descending sorts, they come first in that case.
          """
          order_by = params.order_by

          def sort_key(obj: Json):
              value = obj.get(order_by)
              # The leading flag separates None from real values so the two
              # are never compared with each other.
              return (value is None, value)

          ordered = sorted(objs, key=sort_key, reverse=not params.ascending)
          start = params.offset
          return ordered[start : start + params.limit]

      def filter(
          self, filters: List[Filter], params: Optional[PageOptions] = None
      ) -> List[Json]:
          """Return deep copies of the records that satisfy every filter.

          A record satisfies a filter when its value for the filter's ``field``
          (``None`` when the key is absent) is contained in the filter's
          ``values``. With an empty ``filters`` list every record matches.
          When ``params`` is given, the matches are sorted and sliced by
          ``_paginate``.
          """
          matches = [
              deepcopy(record)
              for record in self.data.values()
              if all(
                  record.get(criterion.field) in criterion.values
                  for criterion in filters
              )
          ]
          if params is not None:
              matches = self._paginate(matches, params)
          return matches

      def add(self, item: Json) -> Json:
          """Store a new record and return a deep copy of what was stored.

          When ``item`` has no ``"id"`` (or it is ``None``), an id is generated
          with ``_get_next_id``. The caller's ``item`` is not modified.

          Raises:
              AlreadyExists: if the supplied id is already present.
          """
          # Deep copy rather than ``item.copy()``: a shallow copy would leave
          # nested lists/dicts shared between the caller and the stored record.
          record = deepcopy(item)
          id_ = record.pop("id", None)
          # autoincrement (like SQL does)
          if id_ is None:
              id_ = self._get_next_id()
          elif id_ in self.data:
              raise AlreadyExists(id_)
          self.data[id_] = {"id": id_, **record}
          return deepcopy(self.data[id_])

      def update(
          self, item: Json, if_unmodified_since: Optional[datetime] = None
      ) -> Json:
          """Merge ``item`` into the stored record that has the same ``"id"``.

          Keys present in ``item`` overwrite the stored values; other stored
          keys are left as they are. A deep copy of the merged record is
          returned.

          Raises:
              DoesNotExist: if ``item`` has no ``"id"`` or the id is unknown.
              Conflict: if ``if_unmodified_since`` is given and differs from
                  the stored record's ``"updated_at"`` value.
          """
          id_ = item.get("id")
          if id_ is None or id_ not in self.data:
              raise DoesNotExist("item", id_)
          existing = self.data[id_]
          if (
              if_unmodified_since is not None
              and existing.get("updated_at") != if_unmodified_since
          ):
              raise Conflict()
          # Copy the incoming values so later mutation of the caller's item
          # cannot change the stored record.
          existing.update(deepcopy(item))
          return deepcopy(existing)

      def remove(self, id: Id) -> bool:
          """Delete the record with the given id.

          Returns:
              ``True`` if a record was deleted, ``False`` if the id was unknown.
          """
          if id not in self.data:
              return False
          del self.data[id]
          return True
  126. return True