# (c) Nelen & Schuurmans

from typing import Any
from typing import Generic
from typing import List
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union

from .exceptions import DoesNotExist
from .filter import Filter
from .gateway import Gateway
from .gateway import SyncGateway
from .pagination import Page
from .pagination import PageOptions
from .types import Id
from .types import Json
from .value_object import ValueObject

__all__ = ["Repository", "SyncRepository"]

T = TypeVar("T", bound=ValueObject)


def _resolve_entity(cls: type, repository_type: type) -> Any:
    """Return the type argument that ``cls`` passes to ``repository_type``.

    Scans the parameterized bases of ``cls`` (e.g. ``Repository[User]``) for
    the first one whose origin is ``repository_type`` or a subclass of it and
    returns that base's single type argument. Other bases, such as mixins, are
    ignored. Returns ``None`` when no such base is found.
    """
    # Normal attribute lookup (not cls.__dict__) on purpose: a plain subclass
    # of a concrete repository (``class B(A)``) then resolves through A's
    # bases, exactly like the previous implementation did.
    for base in getattr(cls, "__orig_bases__", ()):
        origin = getattr(base, "__origin__", None)
        if isinstance(origin, type) and issubclass(origin, repository_type):
            (entity,) = base.__args__
            return entity
    return None


class Repository(Generic[T]):
    """Async repository that turns gateway records into ``entity`` instances.

    Use by subclassing with a concrete entity type, for example
    ``class UserRepository(Repository[User])``; ``__init_subclass__`` stores
    the type argument on the subclass as ``entity``. All storage access is
    delegated to the ``Gateway`` passed to the constructor.
    """

    # Entity class built from gateway records; set per subclass by
    # __init_subclass__ from the type argument (Repository[User] -> User).
    entity: Type[T]

    def __init__(self, gateway: Gateway):
        self.gateway = gateway

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        entity = _resolve_entity(cls, Repository)
        # Without a Repository[...] base there is nothing to record; the
        # subclass keeps whatever ``entity`` it inherits.
        if entity is not None:
            cls.entity = entity

    async def all(self, params: Optional[PageOptions] = None) -> Page[T]:
        """Return a page of entities without applying any filter."""
        return await self.filter([], params=params)

    async def by(
        self, key: str, value: Any, params: Optional[PageOptions] = None
    ) -> Page[T]:
        """Return a page filtered on field ``key`` with the single ``value``."""
        return await self.filter([Filter(field=key, values=[value])], params=params)

    async def filter(
        self, filters: List[Filter], params: Optional[PageOptions] = None
    ) -> Page[T]:
        """Return a page of entities for the records matching ``filters``.

        Without ``params`` the page's ``total`` is the number of records the
        gateway returned, and ``limit``/``offset`` are ``None``. With
        ``params``, ``total`` comes from a separate ``count`` call, except
        when a first page (offset 0) came back shorter than ``limit``: the
        returned records are then taken to be the complete result.
        """
        records = await self.gateway.filter(filters, params=params)
        total = len(records)
        # when using pagination, we may need to do a count in the db
        # except in a typical 'first page' situation with few records
        if params is not None and not (params.offset == 0 and total < params.limit):
            total = await self.count(filters)
        return Page(
            total=total,
            limit=params.limit if params else None,
            offset=params.offset if params else None,
            items=[self.entity(**x) for x in records],
        )

    async def get(self, id: Id) -> T:
        """Return the entity with ``id``.

        Raises:
            DoesNotExist: if the gateway returns ``None`` for ``id``.
        """
        res = await self.gateway.get(id)
        if res is None:
            raise DoesNotExist("object", id)
        else:
            return self.entity(**res)

    async def add(self, item: Union[T, Json]) -> T:
        """Store ``item`` and return the entity built from the stored record.

        A dict ``item`` is first turned into an entity via ``entity.create``.
        """
        if isinstance(item, dict):
            item = self.entity.create(**item)
        created = await self.gateway.add(item.model_dump())
        return self.entity(**created)

    async def update(self, id: Id, values: Json) -> T:
        """Apply ``values`` to the entity with ``id`` and return the result.

        Empty ``values`` skip the write and return the current entity (which
        raises ``DoesNotExist`` if it is missing). Otherwise the gateway's
        ``update_transactional`` receives a function that rebuilds the current
        record as an entity, applies ``entity.update(**values)`` and dumps it.
        """
        if not values:
            return await self.get(id)
        updated = await self.gateway.update_transactional(
            id, lambda x: self.entity(**x).update(**values).model_dump()
        )
        return self.entity(**updated)

    async def upsert(self, item: T) -> T:
        """Insert or replace ``item`` through the gateway; return the result."""
        values = item.model_dump()
        upserted = await self.gateway.upsert(values)
        return self.entity(**upserted)

    async def remove(self, id: Id) -> bool:
        """Remove the record with ``id``; return the gateway's boolean result."""
        return await self.gateway.remove(id)

    async def count(self, filters: List[Filter]) -> int:
        """Return the gateway's count of records matching ``filters``."""
        return await self.gateway.count(filters)

    async def exists(self, filters: List[Filter]) -> bool:
        """Return whether the gateway reports a record matching ``filters``."""
        return await self.gateway.exists(filters)


# This is a copy-paste from Repository, but with all the async / await removed.
# Keep the two classes in sync when changing either one.
class SyncRepository(Generic[T]):
    """Synchronous twin of ``Repository`` backed by a ``SyncGateway``.

    Use by subclassing with a concrete entity type, for example
    ``class UserRepository(SyncRepository[User])``; ``__init_subclass__``
    stores the type argument on the subclass as ``entity``.
    """

    # Entity class built from gateway records; set per subclass by
    # __init_subclass__ from the type argument (SyncRepository[User] -> User).
    entity: Type[T]

    def __init__(self, gateway: SyncGateway):
        self.gateway = gateway

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        entity = _resolve_entity(cls, SyncRepository)
        # Without a SyncRepository[...] base there is nothing to record; the
        # subclass keeps whatever ``entity`` it inherits.
        if entity is not None:
            cls.entity = entity

    def all(self, params: Optional[PageOptions] = None) -> Page[T]:
        """Return a page of entities without applying any filter."""
        return self.filter([], params=params)

    def by(self, key: str, value: Any, params: Optional[PageOptions] = None) -> Page[T]:
        """Return a page filtered on field ``key`` with the single ``value``."""
        return self.filter([Filter(field=key, values=[value])], params=params)

    def filter(
        self, filters: List[Filter], params: Optional[PageOptions] = None
    ) -> Page[T]:
        """Return a page of entities for the records matching ``filters``.

        Without ``params`` the page's ``total`` is the number of records the
        gateway returned, and ``limit``/``offset`` are ``None``. With
        ``params``, ``total`` comes from a separate ``count`` call, except
        when a first page (offset 0) came back shorter than ``limit``: the
        returned records are then taken to be the complete result.
        """
        records = self.gateway.filter(filters, params=params)
        total = len(records)
        # when using pagination, we may need to do a count in the db
        # except in a typical 'first page' situation with few records
        if params is not None and not (params.offset == 0 and total < params.limit):
            total = self.count(filters)
        return Page(
            total=total,
            limit=params.limit if params else None,
            offset=params.offset if params else None,
            items=[self.entity(**x) for x in records],
        )

    def get(self, id: Id) -> T:
        """Return the entity with ``id``.

        Raises:
            DoesNotExist: if the gateway returns ``None`` for ``id``.
        """
        res = self.gateway.get(id)
        if res is None:
            raise DoesNotExist("object", id)
        else:
            return self.entity(**res)

    def add(self, item: Union[T, Json]) -> T:
        """Store ``item`` and return the entity built from the stored record.

        A dict ``item`` is first turned into an entity via ``entity.create``.
        """
        if isinstance(item, dict):
            item = self.entity.create(**item)
        created = self.gateway.add(item.model_dump())
        return self.entity(**created)

    def update(self, id: Id, values: Json) -> T:
        """Apply ``values`` to the entity with ``id`` and return the result.

        Empty ``values`` skip the write and return the current entity (which
        raises ``DoesNotExist`` if it is missing). Otherwise the gateway's
        ``update_transactional`` receives a function that rebuilds the current
        record as an entity, applies ``entity.update(**values)`` and dumps it.
        """
        if not values:
            return self.get(id)
        updated = self.gateway.update_transactional(
            id, lambda x: self.entity(**x).update(**values).model_dump()
        )
        return self.entity(**updated)

    def upsert(self, item: T) -> T:
        """Insert or replace ``item`` through the gateway; return the result."""
        values = item.model_dump()
        upserted = self.gateway.upsert(values)
        return self.entity(**upserted)

    def remove(self, id: Id) -> bool:
        """Remove the record with ``id``; return the gateway's boolean result."""
        return self.gateway.remove(id)

    def count(self, filters: List[Filter]) -> int:
        """Return the gateway's count of records matching ``filters``."""
        return self.gateway.count(filters)

    def exists(self, filters: List[Filter]) -> bool:
        """Return whether the gateway reports a record matching ``filters``."""
        return self.gateway.exists(filters)