gateway.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. from copy import deepcopy
  4. from datetime import datetime
  5. from typing import Any
  6. from typing import Callable
  7. from typing import Dict
  8. from typing import List
  9. from typing import Optional
  10. from clean_python.base.domain.exceptions import AlreadyExists
  11. from clean_python.base.domain.exceptions import Conflict
  12. from clean_python.base.domain.exceptions import DoesNotExist
  13. from clean_python.base.domain.pagination import PageOptions
  14. from clean_python.base.domain.value_object import ValueObject
# Public names exported by this module.
__all__ = ["Gateway", "Json", "Filter", "InMemoryGateway"]

# A record as exchanged with a Gateway: a dict keyed by field name.
Json = Dict[str, Any]
class Filter(ValueObject):
    """A single filtering criterion: a field name plus a list of accepted values."""

    # Name of the record key the criterion applies to.
    field: str
    # Accepted values for that key. InMemoryGateway.filter treats a record as
    # matching when its value for ``field`` is in this list; other gateways are
    # presumably expected to do the same (confirm per implementation).
    values: List[Any]
  20. class Gateway:
  21. async def filter(
  22. self, filters: List[Filter], params: Optional[PageOptions] = None
  23. ) -> List[Json]:
  24. raise NotImplementedError()
  25. async def count(self, filters: List[Filter]) -> int:
  26. return len(await self.filter(filters, params=None))
  27. async def exists(self, filters: List[Filter]) -> bool:
  28. return len(await self.filter(filters, params=PageOptions(limit=1))) > 0
  29. async def get(self, id: int) -> Optional[Json]:
  30. result = await self.filter([Filter(field="id", values=[id])], params=None)
  31. return result[0] if result else None
  32. async def add(self, item: Json) -> Json:
  33. raise NotImplementedError()
  34. async def update(
  35. self, item: Json, if_unmodified_since: Optional[datetime] = None
  36. ) -> Json:
  37. raise NotImplementedError()
  38. async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
  39. existing = await self.get(id)
  40. if existing is None:
  41. raise DoesNotExist("record", id)
  42. return await self.update(
  43. func(existing), if_unmodified_since=existing["updated_at"]
  44. )
  45. async def upsert(self, item: Json) -> Json:
  46. try:
  47. return await self.update(item)
  48. except DoesNotExist:
  49. return await self.add(item)
  50. async def remove(self, id: int) -> bool:
  51. raise NotImplementedError()
  52. class InMemoryGateway(Gateway):
  53. """For testing purposes"""
  54. def __init__(self, data: List[Json]):
  55. self.data = {x["id"]: deepcopy(x) for x in data}
  56. def _get_next_id(self) -> int:
  57. if len(self.data) == 0:
  58. return 1
  59. else:
  60. return max(self.data) + 1
  61. def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:
  62. objs = sorted(
  63. objs,
  64. key=lambda x: (x.get(params.order_by) is None, x.get(params.order_by)),
  65. reverse=not params.ascending,
  66. )
  67. return objs[params.offset : params.offset + params.limit]
  68. async def filter(
  69. self, filters: List[Filter], params: Optional[PageOptions] = None
  70. ) -> List[Json]:
  71. result = []
  72. for x in self.data.values():
  73. for filter in filters:
  74. if x.get(filter.field) not in filter.values:
  75. break
  76. else:
  77. result.append(deepcopy(x))
  78. if params is not None:
  79. result = self._paginate(result, params)
  80. return result
  81. async def add(self, item: Json) -> Json:
  82. item = item.copy()
  83. id_ = item.pop("id", None)
  84. # autoincrement (like SQL does)
  85. if id_ is None:
  86. id_ = self._get_next_id()
  87. elif id_ in self.data:
  88. raise AlreadyExists(id_)
  89. self.data[id_] = {"id": id_, **item}
  90. return deepcopy(self.data[id_])
  91. async def update(
  92. self, item: Json, if_unmodified_since: Optional[datetime] = None
  93. ) -> Json:
  94. _id = item.get("id")
  95. if _id is None or _id not in self.data:
  96. raise DoesNotExist("item", _id)
  97. existing = self.data[_id]
  98. if if_unmodified_since and existing.get("updated_at") != if_unmodified_since:
  99. raise Conflict()
  100. existing.update(item)
  101. return deepcopy(existing)
  102. async def remove(self, id: int) -> bool:
  103. if id not in self.data:
  104. return False
  105. del self.data[id]
  106. return True