gateway.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. from copy import deepcopy
  4. from datetime import datetime
  5. from typing import Any, Callable, Dict, List, Optional
  6. from .exceptions import AlreadyExists, Conflict, DoesNotExist
  7. from .pagination import PageOptions
  8. from .value_object import ValueObject
  9. __all__ = ["Gateway", "Json", "Filter", "InMemoryGateway"]
  10. Json = Dict[str, Any]
  11. class Filter(ValueObject):
  12. field: str
  13. values: List[Any]
  14. class Gateway:
  15. async def filter(
  16. self, filters: List[Filter], params: Optional[PageOptions] = None
  17. ) -> List[Json]:
  18. raise NotImplementedError()
  19. async def count(self, filters: List[Filter]) -> int:
  20. return len(await self.filter(filters, params=None))
  21. async def exists(self, filters: List[Filter]) -> bool:
  22. return len(await self.filter(filters, params=PageOptions(limit=1))) > 0
  23. async def get(self, id: int) -> Optional[Json]:
  24. result = await self.filter([Filter(field="id", values=[id])], params=None)
  25. return result[0] if result else None
  26. async def add(self, item: Json) -> Json:
  27. raise NotImplementedError()
  28. async def update(
  29. self, item: Json, if_unmodified_since: Optional[datetime] = None
  30. ) -> Json:
  31. raise NotImplementedError()
  32. async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
  33. existing = await self.get(id)
  34. if existing is None:
  35. raise DoesNotExist("record", id)
  36. return await self.update(
  37. func(existing), if_unmodified_since=existing["updated_at"]
  38. )
  39. async def upsert(self, item: Json) -> Json:
  40. try:
  41. return await self.update(item)
  42. except DoesNotExist:
  43. return await self.add(item)
  44. async def remove(self, id: int) -> bool:
  45. raise NotImplementedError()
  46. class InMemoryGateway(Gateway):
  47. """For testing purposes"""
  48. def __init__(self, data: List[Json]):
  49. self.data = {x["id"]: deepcopy(x) for x in data}
  50. def _get_next_id(self) -> int:
  51. if len(self.data) == 0:
  52. return 1
  53. else:
  54. return max(self.data) + 1
  55. def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:
  56. objs = sorted(
  57. objs,
  58. key=lambda x: (x.get(params.order_by) is None, x.get(params.order_by)),
  59. reverse=not params.ascending,
  60. )
  61. return objs[params.offset : params.offset + params.limit]
  62. async def filter(
  63. self, filters: List[Filter], params: Optional[PageOptions] = None
  64. ) -> List[Json]:
  65. result = []
  66. for x in self.data.values():
  67. for filter in filters:
  68. if x.get(filter.field) not in filter.values:
  69. break
  70. else:
  71. result.append(deepcopy(x))
  72. if params is not None:
  73. result = self._paginate(result, params)
  74. return result
  75. async def add(self, item: Json) -> Json:
  76. item = item.copy()
  77. id_ = item.pop("id", None)
  78. # autoincrement (like SQL does)
  79. if id_ is None:
  80. id_ = self._get_next_id()
  81. elif id_ in self.data:
  82. raise AlreadyExists(id_)
  83. self.data[id_] = {"id": id_, **item}
  84. return deepcopy(self.data[id_])
  85. async def update(
  86. self, item: Json, if_unmodified_since: Optional[datetime] = None
  87. ) -> Json:
  88. _id = item.get("id")
  89. if _id is None or _id not in self.data:
  90. raise DoesNotExist("item", _id)
  91. existing = self.data[_id]
  92. if if_unmodified_since and existing.get("updated_at") != if_unmodified_since:
  93. raise Conflict()
  94. existing.update(item)
  95. return deepcopy(existing)
  96. async def remove(self, id: int) -> bool:
  97. if id not in self.data:
  98. return False
  99. del self.data[id]
  100. return True