| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 | # -*- coding: utf-8 -*-# (c) Nelen & Schuurmansfrom copy import deepcopyfrom datetime import datetimefrom typing import Any, Callable, Dict, List, Optionalfrom .exceptions import AlreadyExists, Conflict, DoesNotExistfrom .pagination import PageOptionsfrom .value_object import ValueObject__all__ = ["Gateway", "Json", "Filter", "InMemoryGateway"]Json = Dict[str, Any]class Filter(ValueObject):    field: str    values: List[Any]class Gateway:    async def filter(        self, filters: List[Filter], params: Optional[PageOptions] = None    ) -> List[Json]:        raise NotImplementedError()    async def count(self, filters: List[Filter]) -> int:        return len(await self.filter(filters, params=None))    async def exists(self, filters: List[Filter]) -> bool:        return len(await self.filter(filters, params=PageOptions(limit=1))) > 0    async def get(self, id: int) -> Optional[Json]:        result = await self.filter([Filter(field="id", values=[id])], params=None)        return result[0] if result else None    async def add(self, item: Json) -> Json:        raise NotImplementedError()    async def update(        self, item: Json, if_unmodified_since: Optional[datetime] = None    ) -> Json:        raise NotImplementedError()    async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:        existing = await self.get(id)        if existing is None:            raise DoesNotExist("record", id)        return await self.update(            func(existing), if_unmodified_since=existing["updated_at"]        )    async def upsert(self, item: Json) -> Json:        try:            return await self.update(item)        except DoesNotExist:            return await self.add(item)    async def remove(self, id: int) -> bool:        raise NotImplementedError()class InMemoryGateway(Gateway):    """For testing purposes"""    def __init__(self, data: List[Json]):        self.data = {x["id"]: deepcopy(x) for x in data}    def _get_next_id(self) -> int:        if len(self.data) == 0:            return 1        else:            return max(self.data) + 1    def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:        objs = sorted(            objs,            key=lambda x: (x.get(params.order_by) is None, x.get(params.order_by)),            reverse=not params.ascending,        )        return objs[params.offset : params.offset + params.limit]    async def filter(        self, filters: List[Filter], params: Optional[PageOptions] = None    ) -> List[Json]:        result = []        for x in self.data.values():            for filter in filters:                if x.get(filter.field) not in filter.values:                    break            else:                result.append(deepcopy(x))        if params is not None:            result = self._paginate(result, params)        return result    async def add(self, item: Json) -> Json:        item = item.copy()        id_ = item.pop("id", None)        # autoincrement (like SQL does)        if id_ is None:            id_ = self._get_next_id()        elif id_ in self.data:            raise AlreadyExists(id_)        self.data[id_] = {"id": id_, **item}        return deepcopy(self.data[id_])    async def update(        self, item: Json, if_unmodified_since: Optional[datetime] = None    ) -> Json:        _id = item.get("id")        if _id is None or _id not in self.data:            raise DoesNotExist("item", _id)        existing = self.data[_id]        if if_unmodified_since and existing.get("updated_at") != if_unmodified_since:            raise Conflict()        existing.update(item)        return deepcopy(existing)    async def remove(self, id: int) -> bool:        if id not in self.data:            return False        del self.data[id]        return True
 |