repository.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # (c) Nelen & Schuurmans
  2. from typing import Any
  3. from typing import Generic
  4. from typing import List
  5. from typing import Optional
  6. from typing import Type
  7. from typing import TypeVar
  8. from typing import Union
  9. from .exceptions import DoesNotExist
  10. from .filter import Filter
  11. from .gateway import Gateway
  12. from .gateway import SyncGateway
  13. from .pagination import Page
  14. from .pagination import PageOptions
  15. from .root_entity import RootEntity
  16. from .types import Id
  17. from .types import Json
# Names exported by ``from <module> import *``.
__all__ = ["Repository", "SyncRepository"]

# Type variable for the entity class a repository is parametrized with;
# bound to RootEntity, so only RootEntity subclasses type-check.
T = TypeVar("T", bound=RootEntity)
  20. class Repository(Generic[T]):
  21. entity: Type[T]
  22. def __init__(self, gateway: Gateway):
  23. self.gateway = gateway
  24. def __init_subclass__(cls) -> None:
  25. (base,) = cls.__orig_bases__ # type: ignore
  26. (entity,) = base.__args__
  27. assert issubclass(entity, RootEntity)
  28. super().__init_subclass__()
  29. cls.entity = entity
  30. async def all(self, params: Optional[PageOptions] = None) -> Page[T]:
  31. return await self.filter([], params=params)
  32. async def by(
  33. self, key: str, value: Any, params: Optional[PageOptions] = None
  34. ) -> Page[T]:
  35. return await self.filter([Filter(field=key, values=[value])], params=params)
  36. async def filter(
  37. self, filters: List[Filter], params: Optional[PageOptions] = None
  38. ) -> Page[T]:
  39. records = await self.gateway.filter(filters, params=params)
  40. total = len(records)
  41. # when using pagination, we may need to do a count in the db
  42. # except in a typical 'first page' situation with few records
  43. if params is not None and not (params.offset == 0 and total < params.limit):
  44. total = await self.count(filters)
  45. return Page(
  46. total=total,
  47. limit=params.limit if params else None,
  48. offset=params.offset if params else None,
  49. items=[self.entity(**x) for x in records],
  50. )
  51. async def get(self, id: Id) -> T:
  52. res = await self.gateway.get(id)
  53. if res is None:
  54. raise DoesNotExist("object", id)
  55. else:
  56. return self.entity(**res)
  57. async def add(self, item: Union[T, Json]) -> T:
  58. if isinstance(item, dict):
  59. item = self.entity.create(**item)
  60. created = await self.gateway.add(item.model_dump())
  61. return self.entity(**created)
  62. async def update(self, id: Id, values: Json) -> T:
  63. if not values:
  64. return await self.get(id)
  65. updated = await self.gateway.update_transactional(
  66. id, lambda x: self.entity(**x).update(**values).model_dump()
  67. )
  68. return self.entity(**updated)
  69. async def upsert(self, item: T) -> T:
  70. values = item.model_dump()
  71. upserted = await self.gateway.upsert(values)
  72. return self.entity(**upserted)
  73. async def remove(self, id: Id) -> bool:
  74. return await self.gateway.remove(id)
  75. async def count(self, filters: List[Filter]) -> int:
  76. return await self.gateway.count(filters)
  77. async def exists(self, filters: List[Filter]) -> bool:
  78. return await self.gateway.exists(filters)
  79. # This is a copy-paste from Repository, but with all the async / await removed
  80. class SyncRepository(Generic[T]):
  81. entity: Type[T]
  82. def __init__(self, gateway: SyncGateway):
  83. self.gateway = gateway
  84. def __init_subclass__(cls) -> None:
  85. (base,) = cls.__orig_bases__ # type: ignore
  86. (entity,) = base.__args__
  87. assert issubclass(entity, RootEntity)
  88. super().__init_subclass__()
  89. cls.entity = entity
  90. def all(self, params: Optional[PageOptions] = None) -> Page[T]:
  91. return self.filter([], params=params)
  92. def by(self, key: str, value: Any, params: Optional[PageOptions] = None) -> Page[T]:
  93. return self.filter([Filter(field=key, values=[value])], params=params)
  94. def filter(
  95. self, filters: List[Filter], params: Optional[PageOptions] = None
  96. ) -> Page[T]:
  97. records = self.gateway.filter(filters, params=params)
  98. total = len(records)
  99. # when using pagination, we may need to do a count in the db
  100. # except in a typical 'first page' situation with few records
  101. if params is not None and not (params.offset == 0 and total < params.limit):
  102. total = self.count(filters)
  103. return Page(
  104. total=total,
  105. limit=params.limit if params else None,
  106. offset=params.offset if params else None,
  107. items=[self.entity(**x) for x in records],
  108. )
  109. def get(self, id: Id) -> T:
  110. res = self.gateway.get(id)
  111. if res is None:
  112. raise DoesNotExist("object", id)
  113. else:
  114. return self.entity(**res)
  115. def add(self, item: Union[T, Json]) -> T:
  116. if isinstance(item, dict):
  117. item = self.entity.create(**item)
  118. created = self.gateway.add(item.model_dump())
  119. return self.entity(**created)
  120. def update(self, id: Id, values: Json) -> T:
  121. if not values:
  122. return self.get(id)
  123. updated = self.gateway.update_transactional(
  124. id, lambda x: self.entity(**x).update(**values).model_dump()
  125. )
  126. return self.entity(**updated)
  127. def upsert(self, item: T) -> T:
  128. values = item.model_dump()
  129. upserted = self.gateway.upsert(values)
  130. return self.entity(**upserted)
  131. def remove(self, id: Id) -> bool:
  132. return self.gateway.remove(id)
  133. def count(self, filters: List[Filter]) -> int:
  134. return self.gateway.count(filters)
  135. def exists(self, filters: List[Filter]) -> bool:
  136. return self.gateway.exists(filters)