repository.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. from typing import Any, Generic, List, Optional, Type, TypeVar, Union
  4. from clean_python.base.domain.exceptions import DoesNotExist
  5. from clean_python.base.infrastructure.gateway import Filter, Gateway, Json
  6. from clean_python.base.domain.pagination import Page, PageOptions
  7. from clean_python.base.domain.root_entity import RootEntity
  # Type variable for the entity class a Repository works with; it is bound to
  # RootEntity, so type checkers only accept RootEntity subclasses for it.
  T = TypeVar("T", bound=RootEntity)
  class Repository(Generic[T]):
      """Async repository that wraps a ``Gateway`` and returns entities.

      Records coming back from the gateway are turned into entities with
      ``self.entity(**record)``. Concrete repositories are declared by
      subclassing with a type argument (``Repository[SomeEntity]``); the
      entity class is then picked up in ``__init_subclass__``.
      """

      # The RootEntity subclass handled by this repository; assigned by
      # __init_subclass__ from the generic argument of the subclass.
      entity: Type[T]

      def __init__(self, gateway: Gateway):
          """Store the gateway that every persistence call is delegated to."""
          self.gateway = gateway

      def __init_subclass__(cls) -> None:
          """Derive ``cls.entity`` from the generic base of the subclass.

          Expects exactly one original base carrying exactly one type
          argument; any other shape fails while unpacking.

          Raises:
              AssertionError: if the type argument is not a RootEntity subclass.
          """
          (base,) = cls.__orig_bases__  # type: ignore
          (entity,) = base.__args__
          assert issubclass(entity, RootEntity), (
              f"{cls.__name__} must be parametrized with a RootEntity subclass"
          )
          super().__init_subclass__()
          cls.entity = entity

      async def all(self, params: Optional[PageOptions] = None) -> Page[T]:
          """Return a page of entities without applying any filter."""
          return await self.filter([], params=params)

      async def by(
          self, key: str, value: Any, params: Optional[PageOptions] = None
      ) -> Page[T]:
          """Return a page of entities filtered on a single field and value."""
          return await self.filter([Filter(field=key, values=[value])], params=params)

      async def filter(
          self, filters: List[Filter], params: Optional[PageOptions] = None
      ) -> Page[T]:
          """Return a page of entities matching ``filters``.

          Args:
              filters: passed unchanged to the gateway (and to ``count``).
              params: optional pagination options; ``None`` means no pagination.

          Returns:
              A ``Page`` holding the entities, the total, and the ``limit`` and
              ``offset`` taken from ``params`` (both ``None`` without params).
          """
          records = await self.gateway.filter(filters, params=params)
          total = len(records)
          if params is None:
              limit = offset = None
          else:
              limit, offset = params.limit, params.offset
              # When using pagination, a count in the db may be needed, except
              # in the typical 'first page' situation with fewer records than
              # the limit: then len(records) is used as the total.
              if not (offset == 0 and total < limit):
                  total = await self.count(filters)
          return Page(
              total=total,
              limit=limit,
              offset=offset,
              items=[self.entity(**x) for x in records],
          )

      async def get(self, id: int) -> T:
          """Return the entity with the given id.

          Raises:
              DoesNotExist: if the gateway returns ``None`` for this id.
          """
          res = await self.gateway.get(id)
          if res is None:
              raise DoesNotExist("object", id)
          return self.entity(**res)

      async def add(self, item: Union[T, Json]) -> T:
          """Add a new item through the gateway and return it as an entity.

          A ``dict`` is first turned into an entity with ``self.entity.create``;
          the entity's ``dict()`` is what is passed to the gateway.
          """
          if isinstance(item, dict):
              item = self.entity.create(**item)
          created = await self.gateway.add(item.dict())
          return self.entity(**created)

      async def update(self, id: int, values: Json) -> T:
          """Apply ``values`` to the entity with the given id and return it.

          With empty ``values`` no update is issued; the entity is fetched
          with ``get`` instead (which raises ``DoesNotExist`` if it is absent).
          """
          if not values:
              return await self.get(id)
          # The callback receives a record (presumably the currently stored
          # one -- confirm against the gateway), rebuilds the entity, applies
          # the changes and returns the resulting record as a dict.
          updated = await self.gateway.update_transactional(
              id, lambda x: self.entity(**x).update(**values).dict()
          )
          return self.entity(**updated)

      async def upsert(self, item: T) -> T:
          """Pass ``item.dict()`` to the gateway's upsert; return the result as an entity."""
          values = item.dict()
          upserted = await self.gateway.upsert(values)
          return self.entity(**upserted)

      async def remove(self, id: int) -> bool:
          """Delegate to ``gateway.remove`` and return its result."""
          return await self.gateway.remove(id)

      async def count(self, filters: List[Filter]) -> int:
          """Delegate to ``gateway.count`` with the given filters."""
          return await self.gateway.count(filters)

      async def exists(self, filters: List[Filter]) -> bool:
          """Delegate to ``gateway.exists`` with the given filters."""
          return await self.gateway.exists(filters)