repository.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. from typing import Any
  4. from typing import Generic
  5. from typing import List
  6. from typing import Optional
  7. from typing import Type
  8. from typing import TypeVar
  9. from typing import Union
  10. from clean_python.base.domain.exceptions import DoesNotExist
  11. from clean_python.base.domain.pagination import Page
  12. from clean_python.base.domain.pagination import PageOptions
  13. from clean_python.base.domain.root_entity import RootEntity
  14. from clean_python.base.infrastructure.gateway import Filter
  15. from clean_python.base.infrastructure.gateway import Gateway
  16. from clean_python.base.infrastructure.gateway import Json
  17. T = TypeVar("T", bound=RootEntity)
  18. class Repository(Generic[T]):
  19. entity: Type[T]
  20. def __init__(self, gateway: Gateway):
  21. self.gateway = gateway
  22. def __init_subclass__(cls) -> None:
  23. (base,) = cls.__orig_bases__ # type: ignore
  24. (entity,) = base.__args__
  25. assert issubclass(entity, RootEntity)
  26. super().__init_subclass__()
  27. cls.entity = entity
  28. async def all(self, params: Optional[PageOptions] = None) -> Page[T]:
  29. return await self.filter([], params=params)
  30. async def by(
  31. self, key: str, value: Any, params: Optional[PageOptions] = None
  32. ) -> Page[T]:
  33. return await self.filter([Filter(field=key, values=[value])], params=params)
  34. async def filter(
  35. self, filters: List[Filter], params: Optional[PageOptions] = None
  36. ) -> Page[T]:
  37. records = await self.gateway.filter(filters, params=params)
  38. total = len(records)
  39. # when using pagination, we may need to do a count in the db
  40. # except in a typical 'first page' situation with few records
  41. if params is not None and not (params.offset == 0 and total < params.limit):
  42. total = await self.count(filters)
  43. return Page(
  44. total=total,
  45. limit=params.limit if params else None,
  46. offset=params.offset if params else None,
  47. items=[self.entity(**x) for x in records],
  48. )
  49. async def get(self, id: int) -> T:
  50. res = await self.gateway.get(id)
  51. if res is None:
  52. raise DoesNotExist("object", id)
  53. else:
  54. return self.entity(**res)
  55. async def add(self, item: Union[T, Json]) -> T:
  56. if isinstance(item, dict):
  57. item = self.entity.create(**item)
  58. created = await self.gateway.add(item.dict())
  59. return self.entity(**created)
  60. async def update(self, id: int, values: Json) -> T:
  61. if not values:
  62. return await self.get(id)
  63. updated = await self.gateway.update_transactional(
  64. id, lambda x: self.entity(**x).update(**values).dict()
  65. )
  66. return self.entity(**updated)
  67. async def upsert(self, item: T) -> T:
  68. values = item.dict()
  69. upserted = await self.gateway.upsert(values)
  70. return self.entity(**upserted)
  71. async def remove(self, id: int) -> bool:
  72. return await self.gateway.remove(id)
  73. async def count(self, filters: List[Filter]) -> int:
  74. return await self.gateway.count(filters)
  75. async def exists(self, filters: List[Filter]) -> bool:
  76. return await self.gateway.exists(filters)