Prechádzať zdrojové kódy

Add sync equivalents of gateway and repository

Casper van der Wel 1 rok pred
rodič
commit
aafba6651f

+ 1 - 1
CHANGES.md

@@ -4,7 +4,7 @@
 0.5.2 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Added `SyncGateway`, `SyncRepository`, and `InMemorySyncGateway`.
 
 
 0.5.1 (2023-09-25)

+ 44 - 1
clean_python/base/domain/gateway.py

@@ -12,7 +12,7 @@ from .pagination import PageOptions
 from .types import Id
 from .types import Json
 
-__all__ = ["Gateway"]
+__all__ = ["Gateway", "SyncGateway"]
 
 
 class Gateway(ABC):
@@ -55,3 +55,46 @@ class Gateway(ABC):
 
     async def remove(self, id: Id) -> bool:
         raise NotImplementedError()
+
+
+# This is a copy-paste from clean_python.Gateway, but with all the async / await removed
+
+
+class SyncGateway:
+    def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> List[Json]:
+        raise NotImplementedError()
+
+    def count(self, filters: List[Filter]) -> int:
+        return len(self.filter(filters, params=None))
+
+    def exists(self, filters: List[Filter]) -> bool:
+        return len(self.filter(filters, params=PageOptions(limit=1))) > 0
+
+    def get(self, id: Id) -> Optional[Json]:
+        result = self.filter([Filter(field="id", values=[id])], params=None)
+        return result[0] if result else None
+
+    def add(self, item: Json) -> Json:
+        raise NotImplementedError()
+
+    def update(
+        self, item: Json, if_unmodified_since: Optional[datetime] = None
+    ) -> Json:
+        raise NotImplementedError()
+
+    def update_transactional(self, id: Id, func: Callable[[Json], Json]) -> Json:
+        existing = self.get(id)
+        if existing is None:
+            raise DoesNotExist("record", id)
+        return self.update(func(existing), if_unmodified_since=existing["updated_at"])
+
+    def upsert(self, item: Json) -> Json:
+        try:
+            return self.update(item)
+        except DoesNotExist:
+            return self.add(item)
+
+    def remove(self, id: Id) -> bool:
+        raise NotImplementedError()

+ 76 - 1
clean_python/base/domain/repository.py

@@ -11,13 +11,14 @@ from typing import Union
 from .exceptions import DoesNotExist
 from .filter import Filter
 from .gateway import Gateway
+from .gateway import SyncGateway
 from .pagination import Page
 from .pagination import PageOptions
 from .root_entity import RootEntity
 from .types import Id
 from .types import Json
 
-__all__ = ["Repository"]
+__all__ = ["Repository", "SyncRepository"]
 
 T = TypeVar("T", bound=RootEntity)
 
@@ -93,3 +94,77 @@ class Repository(Generic[T]):
 
     async def exists(self, filters: List[Filter]) -> bool:
         return await self.gateway.exists(filters)
+
+
+# This is a copy-paste from Repository, but with all the async / await removed
+
+
+class SyncRepository(Generic[T]):
+    entity: Type[T]
+
+    def __init__(self, gateway: SyncGateway):
+        self.gateway = gateway
+
+    def __init_subclass__(cls) -> None:
+        (base,) = cls.__orig_bases__  # type: ignore
+        (entity,) = base.__args__
+        assert issubclass(entity, RootEntity)
+        super().__init_subclass__()
+        cls.entity = entity
+
+    def all(self, params: Optional[PageOptions] = None) -> Page[T]:
+        return self.filter([], params=params)
+
+    def by(self, key: str, value: Any, params: Optional[PageOptions] = None) -> Page[T]:
+        return self.filter([Filter(field=key, values=[value])], params=params)
+
+    def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> Page[T]:
+        records = self.gateway.filter(filters, params=params)
+        total = len(records)
+        # when using pagination, we may need to do a count in the db
+        # except in a typical 'first page' situation with few records
+        if params is not None and not (params.offset == 0 and total < params.limit):
+            total = self.count(filters)
+        return Page(
+            total=total,
+            limit=params.limit if params else None,
+            offset=params.offset if params else None,
+            items=[self.entity(**x) for x in records],
+        )
+
+    def get(self, id: Id) -> T:
+        res = self.gateway.get(id)
+        if res is None:
+            raise DoesNotExist("object", id)
+        else:
+            return self.entity(**res)
+
+    def add(self, item: Union[T, Json]) -> T:
+        if isinstance(item, dict):
+            item = self.entity.create(**item)
+        created = self.gateway.add(item.model_dump())
+        return self.entity(**created)
+
+    def update(self, id: Id, values: Json) -> T:
+        if not values:
+            return self.get(id)
+        updated = self.gateway.update_transactional(
+            id, lambda x: self.entity(**x).update(**values).model_dump()
+        )
+        return self.entity(**updated)
+
+    def upsert(self, item: T) -> T:
+        values = item.model_dump()
+        upserted = self.gateway.upsert(values)
+        return self.entity(**upserted)
+
+    def remove(self, id: Id) -> bool:
+        return self.gateway.remove(id)
+
+    def count(self, filters: List[Filter]) -> int:
+        return self.gateway.count(filters)
+
+    def exists(self, filters: List[Filter]) -> bool:
+        return self.gateway.exists(filters)

+ 70 - 1
clean_python/base/infrastructure/in_memory_gateway.py

@@ -13,8 +13,9 @@ from clean_python.base.domain import Gateway
 from clean_python.base.domain import Id
 from clean_python.base.domain import Json
 from clean_python.base.domain import PageOptions
+from clean_python.base.domain import SyncGateway
 
-__all__ = ["InMemoryGateway"]
+__all__ = ["InMemoryGateway", "InMemorySyncGateway"]
 
 
 class InMemoryGateway(Gateway):
@@ -80,3 +81,71 @@ class InMemoryGateway(Gateway):
             return False
         del self.data[id]
         return True
+
+
+# This is a copy-paste of InMemoryGateway:
+
+
+class InMemorySyncGateway(SyncGateway):
+    """For testing purposes"""
+
+    def __init__(self, data: List[Json]):
+        self.data = {x["id"]: deepcopy(x) for x in data}
+
+    def _get_next_id(self) -> int:
+        if len(self.data) == 0:
+            return 1
+        else:
+            return max(self.data) + 1
+
+    def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:
+        objs = sorted(
+            objs,
+            key=lambda x: (x.get(params.order_by) is None, x.get(params.order_by)),
+            reverse=not params.ascending,
+        )
+        return objs[params.offset : params.offset + params.limit]
+
+    def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> List[Json]:
+        result = []
+        for x in self.data.values():
+            for filter in filters:
+                if x.get(filter.field) not in filter.values:
+                    break
+            else:
+                result.append(deepcopy(x))
+        if params is not None:
+            result = self._paginate(result, params)
+        return result
+
+    def add(self, item: Json) -> Json:
+        item = item.copy()
+        id_ = item.pop("id", None)
+        # autoincrement (like SQL does)
+        if id_ is None:
+            id_ = self._get_next_id()
+        elif id_ in self.data:
+            raise AlreadyExists(id_)
+
+        self.data[id_] = {"id": id_, **item}
+        return deepcopy(self.data[id_])
+
+    def update(
+        self, item: Json, if_unmodified_since: Optional[datetime] = None
+    ) -> Json:
+        _id = item.get("id")
+        if _id is None or _id not in self.data:
+            raise DoesNotExist("item", _id)
+        existing = self.data[_id]
+        if if_unmodified_since and existing.get("updated_at") != if_unmodified_since:
+            raise Conflict()
+        existing.update(item)
+        return deepcopy(existing)
+
+    def remove(self, id: Id) -> bool:
+        if id not in self.data:
+            return False
+        del self.data[id]
+        return True

+ 172 - 0
tests/test_sync_gateway.py

@@ -0,0 +1,172 @@
+# This module is a copy-paste of test_gateway.py
+
+from datetime import datetime
+from datetime import timezone
+from unittest import mock
+
+import pytest
+
+from clean_python import AlreadyExists
+from clean_python import Conflict
+from clean_python import DoesNotExist
+from clean_python import Filter
+from clean_python import InMemorySyncGateway
+from clean_python import PageOptions
+
+
+@pytest.fixture
+def in_memory_gateway():
+    return InMemorySyncGateway(
+        data=[
+            {"id": 1, "name": "a"},
+            {"id": 2, "name": "b"},
+            {"id": 3, "name": "c"},
+        ]
+    )
+
+
+def test_get(in_memory_gateway):
+    actual = in_memory_gateway.get(1)
+    assert actual == in_memory_gateway.data[1]
+
+
+def test_get_none(in_memory_gateway):
+    actual = in_memory_gateway.get(4)
+    assert actual is None
+
+
+def test_add(in_memory_gateway):
+    record = {"id": 5, "name": "d"}
+    in_memory_gateway.add(record)
+    assert in_memory_gateway.data[5] == record
+
+
+def test_add_id_autoincrement(in_memory_gateway):
+    record = {"name": "d"}
+    in_memory_gateway.add(record)
+    assert in_memory_gateway.data[4] == {"id": 4, "name": "d"}
+
+
+def test_add_id_exists(in_memory_gateway):
+    with pytest.raises(AlreadyExists):
+        in_memory_gateway.add({"id": 3})
+
+
+def test_update(in_memory_gateway):
+    record = {"id": 3, "name": "d"}
+    in_memory_gateway.update(record)
+    assert in_memory_gateway.data[3] == record
+
+
+def test_update_no_id(in_memory_gateway):
+    with pytest.raises(DoesNotExist):
+        in_memory_gateway.update({"no": "id"})
+
+
+def test_update_does_not_exist(in_memory_gateway):
+    with pytest.raises(DoesNotExist):
+        in_memory_gateway.update({"id": 4})
+
+
+def test_upsert(in_memory_gateway):
+    record = {"id": 3, "name": "d"}
+    in_memory_gateway.upsert(record)
+    assert in_memory_gateway.data[3] == record
+
+
+def test_upsert_no_id(in_memory_gateway):
+    in_memory_gateway.upsert({"name": "x"})
+    assert in_memory_gateway.data[4] == {"id": 4, "name": "x"}
+
+
+def test_upsert_does_add(in_memory_gateway):
+    in_memory_gateway.upsert({"id": 4, "name": "x"})
+    assert in_memory_gateway.data[4] == {"id": 4, "name": "x"}
+
+
+def test_remove(in_memory_gateway):
+    assert in_memory_gateway.remove(1)
+    assert 1 not in in_memory_gateway.data
+    assert len(in_memory_gateway.data) == 2
+
+
+def test_remove_not_existing(in_memory_gateway):
+    assert not in_memory_gateway.remove(4)
+    assert len(in_memory_gateway.data) == 3
+
+
+def test_updated_if_unmodified_since(in_memory_gateway):
+    existing = {"id": 4, "name": "e", "updated_at": datetime.now(timezone.utc)}
+    new = {"id": 4, "name": "f", "updated_at": datetime.now(timezone.utc)}
+
+    in_memory_gateway.add(existing)
+
+    in_memory_gateway.update(new, if_unmodified_since=existing["updated_at"])
+    assert in_memory_gateway.data[4]["name"] == "f"
+
+
+@pytest.mark.parametrize(
+    "if_unmodified_since", [datetime.now(timezone.utc), datetime(2010, 1, 1)]
+)
+def test_update_if_unmodified_since_not_ok(in_memory_gateway, if_unmodified_since):
+    existing = {"id": 4, "name": "e", "updated_at": datetime.now(timezone.utc)}
+    new = {"id": 4, "name": "f", "updated_at": datetime.now(timezone.utc)}
+
+    in_memory_gateway.add(existing)
+    with pytest.raises(Conflict):
+        in_memory_gateway.update(new, if_unmodified_since=if_unmodified_since)
+
+
+def test_filter_all(in_memory_gateway):
+    actual = in_memory_gateway.filter([])
+    assert actual == sorted(in_memory_gateway.data.values(), key=lambda x: x["id"])
+
+
+def test_filter_all_with_params(in_memory_gateway):
+    actual = in_memory_gateway.filter(
+        [], params=PageOptions(limit=2, offset=1, order_by="id", ascending=False)
+    )
+    assert [x["id"] for x in actual] == [2, 1]
+
+
+def test_filter(in_memory_gateway):
+    actual = in_memory_gateway.filter([Filter(field="name", values=["b"])])
+    assert actual == [in_memory_gateway.data[2]]
+
+
+def test_count_all(in_memory_gateway):
+    actual = in_memory_gateway.count([])
+    assert actual == 3
+
+
+def test_count_with_filter(in_memory_gateway):
+    actual = in_memory_gateway.count([Filter(field="name", values=["b"])])
+    assert actual == 1
+
+
+@mock.patch.object(InMemorySyncGateway, "update")
+def test_update_transactional(update):
+    record = {"id": 3, "name": "d", "updated_at": datetime(2010, 1, 1)}
+    gateway = InMemorySyncGateway([record])
+    gateway.update_transactional(3, lambda x: {"name": x["name"] + "x"})
+
+    update.assert_called_once_with(
+        {"name": "dx"}, if_unmodified_since=datetime(2010, 1, 1)
+    )
+
+
+def test_update_transactional_does_not_exist(in_memory_gateway):
+    with pytest.raises(DoesNotExist):
+        in_memory_gateway.update_transactional(5, lambda x: x)
+
+
+def test_exists_all(in_memory_gateway):
+    assert in_memory_gateway.exists([])
+
+
+def test_exists_with_filter(in_memory_gateway):
+    assert in_memory_gateway.exists([Filter(field="name", values=["b"])])
+
+
+def test_exists_with_filter_not(in_memory_gateway):
+    assert not in_memory_gateway.exists([Filter(field="name", values=["bb"])])

+ 182 - 0
tests/test_sync_repository.py

@@ -0,0 +1,182 @@
+# This module is a copy paste of test_repository.py
+
+from typing import List
+from unittest import mock
+
+import pytest
+
+from clean_python import BadRequest
+from clean_python import DoesNotExist
+from clean_python import Filter
+from clean_python import InMemorySyncGateway
+from clean_python import Page
+from clean_python import PageOptions
+from clean_python import RootEntity
+from clean_python import SyncRepository
+
+
+class User(RootEntity):
+    name: str
+
+
+@pytest.fixture
+def users():
+    return [
+        User.create(id=1, name="a"),
+        User.create(id=2, name="b"),
+        User.create(id=3, name="c"),
+    ]
+
+
+class UserSyncRepository(SyncRepository[User]):
+    pass
+
+
+@pytest.fixture
+def user_repository(users: List[User]):
+    return UserSyncRepository(
+        gateway=InMemorySyncGateway(data=[x.model_dump() for x in users])
+    )
+
+
+@pytest.fixture
+def page_options():
+    return PageOptions(limit=10, offset=0, order_by="id")
+
+
+def test_entity_attr(user_repository):
+    assert user_repository.entity is User
+
+
+def test_get(user_repository):
+    actual = user_repository.get(1)
+    assert actual.name == "a"
+
+
+def test_get_does_not_exist(user_repository):
+    with pytest.raises(DoesNotExist):
+        user_repository.get(4)
+
+
+@mock.patch.object(SyncRepository, "filter")
+def test_all(filter_m, user_repository, page_options):
+    filter_m.return_value = Page(total=0, items=[])
+    assert user_repository.all(page_options) is filter_m.return_value
+
+    filter_m.assert_called_once_with([], params=page_options)
+
+
+def test_add(user_repository: UserSyncRepository):
+    actual = user_repository.add(User.create(name="d"))
+    assert actual.name == "d"
+    assert user_repository.gateway.data[4] == actual.model_dump()
+
+
+def test_add_json(user_repository: UserSyncRepository):
+    actual = user_repository.add({"name": "d"})
+    assert actual.name == "d"
+    assert user_repository.gateway.data[4] == actual.model_dump()
+
+
+def test_add_json_validates(user_repository: UserSyncRepository):
+    with pytest.raises(BadRequest):
+        user_repository.add({"id": "d"})
+
+
+def test_update(user_repository: UserSyncRepository):
+    actual = user_repository.update(id=2, values={"name": "d"})
+    assert actual.name == "d"
+    assert user_repository.gateway.data[2] == actual.model_dump()
+
+
+def test_update_does_not_exist(user_repository: UserSyncRepository):
+    with pytest.raises(DoesNotExist):
+        user_repository.update(id=4, values={"name": "d"})
+
+
+def test_update_validates(user_repository: UserSyncRepository):
+    with pytest.raises(BadRequest):
+        user_repository.update(id=2, values={"id": 6})
+
+
+def test_remove(user_repository: UserSyncRepository):
+    assert user_repository.remove(2)
+    assert 2 not in user_repository.gateway.data
+
+
+def test_remove_does_not_exist(user_repository: UserSyncRepository):
+    assert not user_repository.remove(4)
+
+
+def test_upsert_updates(user_repository: UserSyncRepository):
+    actual = user_repository.upsert(User.create(id=2, name="d"))
+    assert actual.name == "d"
+    assert user_repository.gateway.data[2] == actual.model_dump()
+
+
+def test_upsert_adds(user_repository: UserSyncRepository):
+    actual = user_repository.upsert(User.create(id=4, name="d"))
+    assert actual.name == "d"
+    assert user_repository.gateway.data[4] == actual.model_dump()
+
+
+@mock.patch.object(InMemorySyncGateway, "count")
+def test_filter(count_m, user_repository: UserSyncRepository, users):
+    actual = user_repository.filter([Filter(field="name", values=["b"])])
+    assert actual == Page(total=1, items=[users[1]], limit=None, offest=None)
+    assert not count_m.called
+
+
+@mock.patch.object(InMemorySyncGateway, "count")
+def test_filter_with_pagination(
+    count_m, user_repository: UserSyncRepository, users, page_options
+):
+    actual = user_repository.filter([Filter(field="name", values=["b"])], page_options)
+    assert actual == Page(
+        total=1, items=[users[1]], limit=page_options.limit, offset=page_options.offset
+    )
+    assert not count_m.called
+
+
+@pytest.mark.parametrize(
+    "page_options",
+    [
+        PageOptions(limit=3, offset=0, order_by="id"),
+        PageOptions(limit=10, offset=1, order_by="id"),
+    ],
+)
+@mock.patch.object(InMemorySyncGateway, "count")
+def test_filter_with_pagination_calls_count(
+    count_m, user_repository: UserSyncRepository, users, page_options
+):
+    count_m.return_value = 123
+    actual = user_repository.filter([], page_options)
+    assert actual == Page(
+        total=count_m.return_value,
+        items=users[page_options.offset :],
+        limit=page_options.limit,
+        offset=page_options.offset,
+    )
+    assert count_m.called
+
+
+@mock.patch.object(SyncRepository, "filter")
+def test_by(filter_m, user_repository: UserSyncRepository, page_options):
+    filter_m.return_value = Page(total=0, items=[])
+    assert user_repository.by("name", "b", page_options) is filter_m.return_value
+
+    filter_m.assert_called_once_with(
+        [Filter(field="name", values=["b"])], params=page_options
+    )
+
+
+@mock.patch.object(InMemorySyncGateway, "count")
+def test_count(gateway_count, user_repository):
+    assert user_repository.count("foo") is gateway_count.return_value
+    gateway_count.assert_called_once_with("foo")
+
+
+@mock.patch.object(InMemorySyncGateway, "exists")
+def test_exists(gateway_exists, user_repository):
+    assert user_repository.exists("foo") is gateway_exists.return_value
+    gateway_exists.assert_called_once_with("foo")