123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- # (c) Nelen & Schuurmans
- from typing import Any
- from typing import Generic
- from typing import List
- from typing import Optional
- from typing import Type
- from typing import TypeVar
- from typing import Union
- from .exceptions import DoesNotExist
- from .filter import Filter
- from .gateway import Gateway
- from .gateway import SyncGateway
- from .pagination import Page
- from .pagination import PageOptions
- from .types import Id
- from .types import Json
- from .value_object import ValueObject
# Names exported by ``from <module> import *``.
__all__ = [
    "Repository",
    "SyncRepository",
]

# Entity type handled by a repository; must be a ValueObject (sub)class.
T = TypeVar("T", bound=ValueObject)
class Repository(Generic[T]):
    """Asynchronous repository that turns gateway records into ``entity`` objects.

    Subclass it with a concrete entity type, e.g. ``class FooRepo(Repository[Foo])``;
    ``entity`` is then filled in from the type argument by ``__init_subclass__``.
    All storage access is delegated to ``self.gateway``; this class converts
    between plain dicts (``Json``) and entity instances.
    """

    # Entity class built from gateway records; assigned in ``__init_subclass__``.
    entity: Type[T]

    def __init__(self, gateway: Gateway):
        self.gateway = gateway

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Resolve ``cls.entity`` from the ``Repository[...]`` type argument.

        The parametrised ``Repository`` base is searched for among the
        original bases, so mixins may be listed next to it. If there is not
        exactly one such base, the original strict rule applies: the class
        must have exactly one original base with exactly one type argument,
        otherwise the tuple unpacking raises ``ValueError``. Class keyword
        arguments are forwarded to ``super().__init_subclass__``.
        """
        from typing import get_origin

        orig_bases = cls.__orig_bases__  # type: ignore
        candidates = [
            base
            for base in orig_bases
            if isinstance(get_origin(base), type)
            and issubclass(get_origin(base), Repository)
        ]
        # Fall back to the previous behaviour so existing subclasses resolve
        # exactly as before.
        (base,) = candidates if len(candidates) == 1 else orig_bases
        (entity,) = base.__args__
        super().__init_subclass__(**kwargs)
        cls.entity = entity

    async def all(self, params: Optional[PageOptions] = None) -> Page[T]:
        """Return a page of entities without any filter applied."""
        return await self.filter([], params=params)

    async def by(
        self, key: str, value: Any, params: Optional[PageOptions] = None
    ) -> Page[T]:
        """Return a page of entities filtered on field ``key`` with the single value ``value``."""
        return await self.filter([Filter(field=key, values=[value])], params=params)

    async def filter(
        self, filters: List[Filter], params: Optional[PageOptions] = None
    ) -> Page[T]:
        """Return a page of entities selected by ``filters``.

        Without ``params``, ``total`` is the number of records the gateway
        returned. With ``params``, ``total`` comes from a separate
        :meth:`count` query. The exception is a first page (offset 0) holding
        fewer than ``limit`` records, which is known to be the complete result.
        """
        records = await self.gateway.filter(filters, params=params)
        total = len(records)
        if params is not None and not (params.offset == 0 and total < params.limit):
            # A page may be truncated by ``limit`` or shifted by ``offset``, so
            # its length alone does not give the number of matches.
            total = await self.count(filters)
        return Page(
            total=total,
            limit=params.limit if params is not None else None,
            offset=params.offset if params is not None else None,
            items=[self.entity(**record) for record in records],
        )

    async def get(self, id: Id) -> T:
        """Return the entity with ``id``.

        Raises:
            DoesNotExist: if the gateway returns ``None`` for ``id``.
        """
        res = await self.gateway.get(id)
        if res is None:
            raise DoesNotExist("object", id)
        return self.entity(**res)

    async def add(self, item: Union[T, Json]) -> T:
        """Store a new entity and return it rebuilt from ``gateway.add``'s result.

        A ``dict`` is first converted with ``entity.create(**item)``.
        """
        if isinstance(item, dict):
            item = self.entity.create(**item)
        created = await self.gateway.add(item.model_dump())
        return self.entity(**created)

    async def update(self, id: Id, values: Json) -> T:
        """Apply ``values`` to the entity with ``id`` and return the result.

        With empty ``values`` nothing is written. The current entity is
        returned via :meth:`get`, which raises ``DoesNotExist`` if absent.

        Otherwise ``gateway.update_transactional`` receives a callback that:
        - rebuilds the entity from a stored record;
        - applies ``entity.update(**values)``;
        - dumps the result back to a dict.
        """
        if not values:
            return await self.get(id)
        updated = await self.gateway.update_transactional(
            id, lambda record: self.entity(**record).update(**values).model_dump()
        )
        return self.entity(**updated)

    async def upsert(self, item: T) -> T:
        """Insert or replace ``item`` via ``gateway.upsert`` and return the stored entity."""
        values = item.model_dump()
        upserted = await self.gateway.upsert(values)
        return self.entity(**upserted)

    async def remove(self, id: Id) -> bool:
        """Remove the entity with ``id``.

        Returns the gateway's boolean result, presumably whether a record was
        removed.
        """
        return await self.gateway.remove(id)

    async def count(self, filters: List[Filter]) -> int:
        """Return the gateway's count of records selected by ``filters``."""
        return await self.gateway.count(filters)

    async def exists(self, filters: List[Filter]) -> bool:
        """Return ``gateway.exists(filters)``, presumably whether any record matches."""
        return await self.gateway.exists(filters)
# SyncRepository is a copy of Repository with all async/await removed;
# keep the two classes in sync when changing either one.
class SyncRepository(Generic[T]):
    """Synchronous counterpart of ``Repository`` backed by a ``SyncGateway``.

    Subclass it with a concrete entity type, e.g.
    ``class FooRepo(SyncRepository[Foo])``; ``entity`` is then filled in from
    the type argument by ``__init_subclass__``. All storage access is
    delegated to ``self.gateway``; this class converts between plain dicts
    (``Json``) and entity instances.
    """

    # Entity class built from gateway records; assigned in ``__init_subclass__``.
    entity: Type[T]

    def __init__(self, gateway: SyncGateway):
        self.gateway = gateway

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Resolve ``cls.entity`` from the ``SyncRepository[...]`` type argument.

        The parametrised ``SyncRepository`` base is searched for among the
        original bases, so mixins may be listed next to it. If there is not
        exactly one such base, the original strict rule applies: the class
        must have exactly one original base with exactly one type argument,
        otherwise the tuple unpacking raises ``ValueError``. Class keyword
        arguments are forwarded to ``super().__init_subclass__``.
        """
        from typing import get_origin

        orig_bases = cls.__orig_bases__  # type: ignore
        candidates = [
            base
            for base in orig_bases
            if isinstance(get_origin(base), type)
            and issubclass(get_origin(base), SyncRepository)
        ]
        # Fall back to the previous behaviour so existing subclasses resolve
        # exactly as before.
        (base,) = candidates if len(candidates) == 1 else orig_bases
        (entity,) = base.__args__
        super().__init_subclass__(**kwargs)
        cls.entity = entity

    def all(self, params: Optional[PageOptions] = None) -> Page[T]:
        """Return a page of entities without any filter applied."""
        return self.filter([], params=params)

    def by(self, key: str, value: Any, params: Optional[PageOptions] = None) -> Page[T]:
        """Return a page of entities filtered on field ``key`` with the single value ``value``."""
        return self.filter([Filter(field=key, values=[value])], params=params)

    def filter(
        self, filters: List[Filter], params: Optional[PageOptions] = None
    ) -> Page[T]:
        """Return a page of entities selected by ``filters``.

        Without ``params``, ``total`` is the number of records the gateway
        returned. With ``params``, ``total`` comes from a separate
        :meth:`count` query. The exception is a first page (offset 0) holding
        fewer than ``limit`` records, which is known to be the complete result.
        """
        records = self.gateway.filter(filters, params=params)
        total = len(records)
        if params is not None and not (params.offset == 0 and total < params.limit):
            # A page may be truncated by ``limit`` or shifted by ``offset``, so
            # its length alone does not give the number of matches.
            total = self.count(filters)
        return Page(
            total=total,
            limit=params.limit if params is not None else None,
            offset=params.offset if params is not None else None,
            items=[self.entity(**record) for record in records],
        )

    def get(self, id: Id) -> T:
        """Return the entity with ``id``.

        Raises:
            DoesNotExist: if the gateway returns ``None`` for ``id``.
        """
        res = self.gateway.get(id)
        if res is None:
            raise DoesNotExist("object", id)
        return self.entity(**res)

    def add(self, item: Union[T, Json]) -> T:
        """Store a new entity and return it rebuilt from ``gateway.add``'s result.

        A ``dict`` is first converted with ``entity.create(**item)``.
        """
        if isinstance(item, dict):
            item = self.entity.create(**item)
        created = self.gateway.add(item.model_dump())
        return self.entity(**created)

    def update(self, id: Id, values: Json) -> T:
        """Apply ``values`` to the entity with ``id`` and return the result.

        With empty ``values`` nothing is written. The current entity is
        returned via :meth:`get`, which raises ``DoesNotExist`` if absent.

        Otherwise ``gateway.update_transactional`` receives a callback that:
        - rebuilds the entity from a stored record;
        - applies ``entity.update(**values)``;
        - dumps the result back to a dict.
        """
        if not values:
            return self.get(id)
        updated = self.gateway.update_transactional(
            id, lambda record: self.entity(**record).update(**values).model_dump()
        )
        return self.entity(**updated)

    def upsert(self, item: T) -> T:
        """Insert or replace ``item`` via ``gateway.upsert`` and return the stored entity."""
        values = item.model_dump()
        upserted = self.gateway.upsert(values)
        return self.entity(**upserted)

    def remove(self, id: Id) -> bool:
        """Remove the entity with ``id``.

        Returns the gateway's boolean result, presumably whether a record was
        removed.
        """
        return self.gateway.remove(id)

    def count(self, filters: List[Filter]) -> int:
        """Return the gateway's count of records selected by ``filters``."""
        return self.gateway.count(filters)

    def exists(self, filters: List[Filter]) -> bool:
        """Return ``gateway.exists(filters)``, presumably whether any record matches."""
        return self.gateway.exists(filters)
|