repository.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. from typing import Any, Generic, List, Optional, Type, TypeVar, Union
  4. from .exceptions import DoesNotExist
  5. from .gateway import Filter, Gateway, Json
  6. from .pagination import Page, PageOptions
  7. from .root_entity import RootEntity
# Type variable for the entity class handled by a Repository; its bound tells
# type checkers that only RootEntity (sub)classes are valid parameters.
T = TypeVar("T", bound=RootEntity)
  9. class Repository(Generic[T]):
  10. entity: Type[T]
  11. def __init__(self, gateway: Gateway):
  12. self.gateway = gateway
  13. def __init_subclass__(cls) -> None:
  14. (base,) = cls.__orig_bases__ # type: ignore
  15. (entity,) = base.__args__
  16. assert issubclass(entity, RootEntity)
  17. super().__init_subclass__()
  18. cls.entity = entity
  19. async def all(self, params: Optional[PageOptions] = None) -> Page[T]:
  20. return await self.filter([], params=params)
  21. async def by(
  22. self, key: str, value: Any, params: Optional[PageOptions] = None
  23. ) -> Page[T]:
  24. return await self.filter([Filter(field=key, values=[value])], params=params)
  25. async def filter(
  26. self, filters: List[Filter], params: Optional[PageOptions] = None
  27. ) -> Page[T]:
  28. records = await self.gateway.filter(filters, params=params)
  29. total = len(records)
  30. # when using pagination, we may need to do a count in the db
  31. # except in a typical 'first page' situation with few records
  32. if params is not None and not (params.offset == 0 and total < params.limit):
  33. total = await self.count(filters)
  34. return Page(
  35. total=total,
  36. limit=params.limit if params else None,
  37. offset=params.offset if params else None,
  38. items=[self.entity(**x) for x in records],
  39. )
  40. async def get(self, id: int) -> T:
  41. res = await self.gateway.get(id)
  42. if res is None:
  43. raise DoesNotExist("object", id)
  44. else:
  45. return self.entity(**res)
  46. async def add(self, item: Union[T, Json]) -> T:
  47. if isinstance(item, dict):
  48. item = self.entity.create(**item)
  49. created = await self.gateway.add(item.dict())
  50. return self.entity(**created)
  51. async def update(self, id: int, values: Json) -> T:
  52. if not values:
  53. return await self.get(id)
  54. updated = await self.gateway.update_transactional(
  55. id, lambda x: self.entity(**x).update(**values).dict()
  56. )
  57. return self.entity(**updated)
  58. async def upsert(self, item: T) -> T:
  59. values = item.dict()
  60. upserted = await self.gateway.upsert(values)
  61. return self.entity(**upserted)
  62. async def remove(self, id: int) -> bool:
  63. return await self.gateway.remove(id)
  64. async def count(self, filters: List[Filter]) -> int:
  65. return await self.gateway.count(filters)
  66. async def exists(self, filters: List[Filter]) -> bool:
  67. return await self.gateway.exists(filters)