Prechádzať zdrojové kódy

Add tests for ApiGateway and ApiProvider (#17)

Casper van der Wel 1 rok pred
rodič
commit
7aac476361

+ 36 - 2
clean_python/api_client/api_gateway.py

@@ -1,12 +1,16 @@
+from datetime import datetime
+from http import HTTPStatus
 from typing import Optional
 
 import inject
 
+from clean_python import DoesNotExist
 from clean_python import Id
 from clean_python import Json
 
 from .. import SyncGateway
 from .api_provider import SyncApiProvider
+from .exceptions import ApiException
 
 __all__ = ["SyncApiGateway"]
 
@@ -28,7 +32,12 @@ class SyncApiGateway(SyncGateway):
         return self.provider_override or inject.instance(SyncApiProvider)
 
     def get(self, id: Id) -> Optional[Json]:
-        return self.provider.request("GET", self.path.format(id=id))
+        try:
+            return self.provider.request("GET", self.path.format(id=id))
+        except ApiException as e:
+            if e.status is HTTPStatus.NOT_FOUND:
+                return None
+            raise e
 
     def add(self, item: Json) -> Json:
         result = self.provider.request("POST", self.path.format(id=""), json=item)
@@ -36,4 +45,29 @@ class SyncApiGateway(SyncGateway):
         return result
 
     def remove(self, id: Id) -> bool:
-        return self.provider.request("DELETE", self.path.format(id=id)) is not None
+        try:
+            self.provider.request("DELETE", self.path.format(id=id)) is not None
+        except ApiException as e:
+            if e.status is HTTPStatus.NOT_FOUND:
+                return False
+            raise e
+        else:
+            return True
+
+    def update(
+        self, item: Json, if_unmodified_since: Optional[datetime] = None
+    ) -> Json:
+        if if_unmodified_since is not None:
+            raise NotImplementedError("if_unmodified_since not implemented")
+        item = item.copy()
+        id_ = item.pop("id", None)
+        if id_ is None:
+            raise DoesNotExist("resource", id_)
+        try:
+            result = self.provider.request("PATCH", self.path.format(id=id_), json=item)
+            assert result is not None
+            return result
+        except ApiException as e:
+            if e.status is HTTPStatus.NOT_FOUND:
+                raise DoesNotExist("resource", id_)
+            raise e

+ 31 - 17
clean_python/api_client/api_provider.py

@@ -1,3 +1,5 @@
+import json as json_lib
+import re
 from http import HTTPStatus
 from typing import Callable
 from typing import Optional
@@ -21,6 +23,15 @@ def is_success(status: HTTPStatus) -> bool:
     return (int(status) // 100) == 2
 
 
+JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")
+
+
+def is_json_content_type(content_type: Optional[str]) -> bool:
+    if not content_type:
+        return False
+    return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))
+
+
 def join(url: str, path: str) -> str:
     """Results in a full url without trailing slash"""
     assert url.endswith("/")
@@ -71,31 +82,34 @@ class SyncApiProvider:
         timeout: float = 5.0,
     ) -> Optional[Json]:
         assert ctx.tenant is not None
-        url = join(self._url, path)
-        token = self._fetch_token(self._pool, ctx.tenant.id)
         headers = {}
+        request_kwargs = {
+            "method": method,
+            "url": add_query_params(join(self._url, path), params),
+            "timeout": timeout,
+        }
+        # for urllib3<2, we dump json ourselves
+        if json is not None and fields is not None:
+            raise ValueError("Cannot both specify 'json' and 'fields'")
+        elif json is not None:
+            request_kwargs["body"] = json_lib.dumps(json).encode()
+            headers["Content-Type"] = "application/json"
+        elif fields is not None:
+            request_kwargs["fields"] = fields
+        token = self._fetch_token(self._pool, ctx.tenant.id)
         if token is not None:
             headers["Authorization"] = f"Bearer {token}"
-        response = self._pool.request(
-            method=method,
-            url=add_query_params(url, params),
-            json=json,
-            fields=fields,
-            headers=headers,
-            timeout=timeout,
-        )
+        response = self._pool.request(headers=headers, **request_kwargs)
         status = HTTPStatus(response.status)
         content_type = response.headers.get("Content-Type")
-        if content_type is None and status is HTTPStatus.NO_CONTENT:
-            return {"status": int(status)}  # we have to return something...
-        if content_type != "application/json":
+        if status is HTTPStatus.NO_CONTENT:
+            return None
+        if not is_json_content_type(content_type):
             raise ApiException(
                 f"Unexpected content type '{content_type}'", status=status
             )
-        body = response.json()
-        if status is HTTPStatus.NOT_FOUND:
-            return None
-        elif is_success(status):
+        body = json_lib.loads(response.data.decode())
+        if is_success(status):
             return body
         else:
             raise ApiException(body, status=status)

+ 3 - 0
clean_python/api_client/exceptions.py

@@ -8,3 +8,6 @@ class ApiException(ValueError):
     def __init__(self, obj: Any, status: HTTPStatus):
         self.status = status
         super().__init__(obj)
+
+    def __str__(self):
+        return f"{self.status}: {super().__str__()}"

+ 12 - 0
integration_tests/conftest.py

@@ -1,9 +1,11 @@
 # (c) Nelen & Schuurmans
 
 import asyncio
+import multiprocessing
 import os
 
 import pytest
+import uvicorn
 
 
 def pytest_sessionstart(session):
@@ -39,3 +41,13 @@ async def postgres_url():
 @pytest.fixture(scope="session")
 async def s3_url():
     return os.environ.get("S3_URL", "http://localhost:9000")
+
+
+@pytest.fixture(scope="session")
+async def fastapi_example_app():
+    port = int(os.environ.get("API_PORT", "8005"))
+    config = uvicorn.Config("fastapi_example:app", host="0.0.0.0", port=port)
+    p = multiprocessing.Process(target=uvicorn.Server(config).run)
+    p.start()
+    yield f"http://localhost:{port}"
+    p.terminate()

+ 13 - 0
integration_tests/fastapi_example/__init__.py

@@ -0,0 +1,13 @@
+from clean_python import InMemoryGateway
+from clean_python.fastapi import Service
+
+from .presentation import V1Books
+
+service = Service(V1Books())
+
+app = service.create_app(
+    title="Book service",
+    description="Service for testing clean-python",
+    hostname="testserver",
+    access_logger_gateway=InMemoryGateway([]),
+)

+ 14 - 0
integration_tests/fastapi_example/application.py

@@ -0,0 +1,14 @@
+from typing import Optional
+
+from clean_python import InMemoryGateway
+from clean_python import Manage
+
+from .domain import Book
+from .domain import BookRepository
+
+
+class ManageBook(Manage[Book]):
+    def __init__(self, repo: Optional[BookRepository] = None):
+        if repo is None:
+            repo = BookRepository(InMemoryGateway([]))
+        self.repo = repo

+ 16 - 0
integration_tests/fastapi_example/domain.py

@@ -0,0 +1,16 @@
+from clean_python import Repository
+from clean_python import RootEntity
+from clean_python import ValueObject
+
+
+class Author(ValueObject):
+    name: str
+
+
+class Book(RootEntity):
+    author: Author
+    title: str
+
+
+class BookRepository(Repository[Book]):
+    pass

+ 70 - 0
integration_tests/fastapi_example/presentation.py

@@ -0,0 +1,70 @@
+from http import HTTPStatus
+from typing import Optional
+
+from fastapi import Depends
+from fastapi import Form
+from fastapi import Response
+from fastapi import UploadFile
+
+from clean_python import DoesNotExist
+from clean_python import Page
+from clean_python import ValueObject
+from clean_python.fastapi import delete
+from clean_python.fastapi import get
+from clean_python.fastapi import patch
+from clean_python.fastapi import post
+from clean_python.fastapi import RequestQuery
+from clean_python.fastapi import Resource
+from clean_python.fastapi import v
+
+from .application import ManageBook
+from .domain import Author
+from .domain import Book
+
+
+class BookCreate(ValueObject):
+    author: Author
+    title: str
+
+
+class BookUpdate(ValueObject):
+    author: Optional[Author] = None
+    title: Optional[str] = None
+
+
+class V1Books(Resource, version=v(1), name="books"):
+    def __init__(self):
+        self.manager = ManageBook()
+
+    @get("/books", response_model=Page[Book])
+    async def list(self, q: RequestQuery = Depends()):
+        return await self.manager.filter([], q.as_page_options())
+
+    @post("/books", status_code=HTTPStatus.CREATED, response_model=Book)
+    async def create(self, obj: BookCreate):
+        return await self.manager.create(obj.model_dump())
+
+    @get("/books/{id}", response_model=Book)
+    async def retrieve(self, id: int):
+        return await self.manager.retrieve(id)
+
+    @patch("/books/{id}", response_model=Book)
+    async def update(self, id: int, obj: BookUpdate):
+        return await self.manager.update(id, obj.model_dump(exclude_unset=True))
+
+    @delete("/books/{id}", status_code=HTTPStatus.NO_CONTENT, response_class=Response)
+    async def destroy(self, id: int):
+        if not await self.manager.destroy(id):
+            raise DoesNotExist("object", id)
+
+    @get("/text")
+    async def text(self):
+        return Response("foo", media_type="text/plain")
+
+    @post("/form", response_model=Author)
+    async def form(self, name: str = Form()):
+        return {"name": name}
+
+    @post("/file")
+    async def file(self, file: UploadFile):
+        return {file.filename: (await file.read()).decode()}

+ 62 - 0
integration_tests/test_api_gateway.py

@@ -0,0 +1,62 @@
+import pytest
+
+from clean_python import ctx
+from clean_python import DoesNotExist
+from clean_python import Json
+from clean_python import Tenant
+from clean_python.api_client import SyncApiGateway
+from clean_python.api_client import SyncApiProvider
+
+
+class BooksGateway(SyncApiGateway, path="v1/books/{id}"):
+    pass
+
+
+@pytest.fixture
+def provider(fastapi_example_app) -> SyncApiProvider:
+    ctx.tenant = Tenant(id=2, name="")
+    yield SyncApiProvider(fastapi_example_app + "/", lambda a, b: "token")
+    ctx.tenant = None
+
+
+@pytest.fixture
+def gateway(provider) -> SyncApiGateway:
+    return BooksGateway(provider)
+
+
+@pytest.fixture
+def book(gateway: SyncApiGateway):
+    return gateway.add({"title": "fixture", "author": {"name": "foo"}})
+
+
+def test_add(gateway: SyncApiGateway):
+    response = gateway.add({"title": "test_add", "author": {"name": "foo"}})
+    assert isinstance(response["id"], int)
+    assert response["title"] == "test_add"
+    assert response["author"] == {"name": "foo"}
+    assert response["created_at"] == response["updated_at"]
+
+
+def test_get(gateway: SyncApiGateway, book: Json):
+    response = gateway.get(book["id"])
+    assert response == book
+
+
+def test_remove_and_404(gateway: SyncApiGateway, book: Json):
+    assert gateway.remove(book["id"]) is True
+    assert gateway.get(book["id"]) is None
+    assert gateway.remove(book["id"]) is False
+
+
+def test_update(gateway: SyncApiGateway, book: Json):
+    response = gateway.update({"id": book["id"], "title": "test_update"})
+
+    assert response["id"] == book["id"]
+    assert response["title"] == "test_update"
+    assert response["author"] == {"name": "foo"}
+    assert response["created_at"] != response["updated_at"]
+
+
+def test_update_404(gateway: SyncApiGateway):
+    with pytest.raises(DoesNotExist):
+        gateway.update({"id": 123456, "title": "test_update_404"})

+ 84 - 0
integration_tests/test_api_provider.py

@@ -0,0 +1,84 @@
+from http import HTTPStatus
+
+import pytest
+
+from clean_python import ctx
+from clean_python import Tenant
+from clean_python.api_client import ApiException
+from clean_python.api_client import SyncApiProvider
+
+
+@pytest.fixture
+def provider(fastapi_example_app) -> SyncApiProvider:
+    ctx.tenant = Tenant(id=2, name="")
+    yield SyncApiProvider(fastapi_example_app + "/", lambda a, b: "token")
+    ctx.tenant = None
+
+
+def test_request_params(provider: SyncApiProvider):
+    response = provider.request("GET", "v1/books", params={"limit": 10, "offset": 2})
+
+    assert isinstance(response, dict)
+
+    assert response["limit"] == 10
+    assert response["offset"] == 2
+
+
+def test_request_json_body(provider: SyncApiProvider):
+    response = provider.request(
+        "POST", "v1/books", json={"title": "test_body", "author": {"name": "foo"}}
+    )
+
+    assert isinstance(response, dict)
+    assert response["title"] == "test_body"
+    assert response["author"] == {"name": "foo"}
+
+
+def test_request_form_body(provider: SyncApiProvider):
+    response = provider.request("POST", "v1/form", fields={"name": "foo"})
+
+    assert isinstance(response, dict)
+    assert response["name"] == "foo"
+
+
+def test_request_form_file(provider: SyncApiProvider):
+    response = provider.request("POST", "v1/file", fields={"file": ("x.txt", b"foo")})
+
+    assert isinstance(response, dict)
+    assert response["x.txt"] == "foo"
+
+
+@pytest.fixture
+def book(provider: SyncApiProvider):
+    return provider.request(
+        "POST", "v1/books", json={"title": "fixture", "author": {"name": "foo"}}
+    )
+
+
+def test_no_content(provider: SyncApiProvider, book):
+    response = provider.request("DELETE", f"v1/books/{book['id']}")
+
+    assert response is None
+
+
+def test_not_found(provider: SyncApiProvider):
+    with pytest.raises(ApiException) as e:
+        provider.request("GET", "v1/book")
+
+    assert e.value.status is HTTPStatus.NOT_FOUND
+    assert e.value.args[0] == {"detail": "Not Found"}
+
+
+def test_bad_request(provider: SyncApiProvider):
+    with pytest.raises(ApiException) as e:
+        provider.request("GET", "v1/books", params={"limit": "foo"})
+
+    assert e.value.status is HTTPStatus.BAD_REQUEST
+    assert e.value.args[0]["detail"][0]["loc"] == ["query", "limit"]
+
+
+def test_no_json_response(provider: SyncApiProvider):
+    with pytest.raises(ApiException) as e:
+        provider.request("GET", "v1/text")
+
+    assert e.value.args[0] == "Unexpected content type 'text/plain; charset=utf-8'"

+ 2 - 0
pyproject.toml

@@ -20,6 +20,8 @@ test = [
     "pytest-asyncio",
     "debugpy",
     "httpx",
+    "uvicorn",
+    "python-multipart"
 ]
 dramatiq = ["dramatiq"]
 fastapi = ["fastapi"]

+ 30 - 1
tests/api_client/test_sync_api_gateway.py

@@ -1,7 +1,10 @@
+from http import HTTPStatus
 from unittest import mock
 
 import pytest
 
+from clean_python import DoesNotExist
+from clean_python.api_client import ApiException
 from clean_python.api_client import SyncApiGateway
 from clean_python.api_client import SyncApiProvider
 
@@ -46,6 +49,32 @@ def test_remove(api_gateway: SyncApiGateway):
 
 
 def test_remove_does_not_exist(api_gateway: SyncApiGateway):
-    api_gateway.provider.request.return_value = None
+    api_gateway.provider.request.side_effect = ApiException(
+        {}, status=HTTPStatus.NOT_FOUND
+    )
     actual = api_gateway.remove(2)
     assert actual is False
+
+
+def test_update(api_gateway: SyncApiGateway):
+    actual = api_gateway.update({"id": 2, "foo": "bar"})
+
+    api_gateway.provider.request.assert_called_once_with(
+        "PATCH", "foo/2", json={"foo": "bar"}
+    )
+    assert actual is api_gateway.provider.request.return_value
+
+
+def test_update_no_id(api_gateway: SyncApiGateway):
+    with pytest.raises(DoesNotExist):
+        api_gateway.update({"foo": "bar"})
+
+    assert not api_gateway.provider.request.called
+
+
+def test_update_does_not_exist(api_gateway: SyncApiGateway):
+    api_gateway.provider.request.side_effect = ApiException(
+        {}, status=HTTPStatus.NOT_FOUND
+    )
+    with pytest.raises(DoesNotExist):
+        api_gateway.update({"id": 2, "foo": "bar"})

+ 17 - 16
tests/api_client/test_sync_api_provider.py

@@ -23,6 +23,7 @@ def response():
     response = mock.Mock()
     response.status = int(HTTPStatus.OK)
     response.headers = {"Content-Type": "application/json"}
+    response.data = b'{"foo": 2}'
     return response
 
 
@@ -44,12 +45,10 @@ def test_get(api_provider: SyncApiProvider, response):
     assert api_provider._pool.request.call_args[1] == dict(
         method="GET",
         url="http://testserver/foo",
-        json=None,
-        fields=None,
         headers={"Authorization": "Bearer tenant-2"},
         timeout=5.0,
     )
-    assert actual == response.json.return_value
+    assert actual == {"foo": 2}
 
 
 def test_post_json(api_provider: SyncApiProvider, response):
@@ -62,12 +61,14 @@ def test_post_json(api_provider: SyncApiProvider, response):
     assert api_provider._pool.request.call_args[1] == dict(
         method="POST",
         url="http://testserver/foo/bar",
-        json={"foo": 2},
-        fields=None,
-        headers={"Authorization": "Bearer tenant-2"},
+        body=b'{"foo": 2}',
+        headers={
+            "Content-Type": "application/json",
+            "Authorization": "Bearer tenant-2",
+        },
         timeout=5.0,
     )
-    assert actual == response.json.return_value
+    assert actual == {"foo": 2}
 
 
 @pytest.mark.parametrize(
@@ -103,7 +104,13 @@ def test_unexpected_content_type(api_provider: SyncApiProvider, response, status
         api_provider.request("GET", "bar")
 
     assert e.value.status is status
-    assert str(e.value) == "Unexpected content type 'text/plain'"
+    assert str(e.value) == f"{status}: Unexpected content type 'text/plain'"
+
+
+def test_json_variant_content_type(api_provider: SyncApiProvider, response):
+    response.headers["Content-Type"] = "application/something+json"
+    actual = api_provider.request("GET", "bar")
+    assert actual == {"foo": 2}
 
 
 def test_no_content(api_provider: SyncApiProvider, response):
@@ -111,16 +118,10 @@ def test_no_content(api_provider: SyncApiProvider, response):
     response.headers = {}
 
     actual = api_provider.request("DELETE", "bar/2")
-    assert actual is not None
-
-
-def test_404(api_provider: SyncApiProvider, response):
-    response.status = int(HTTPStatus.NOT_FOUND)
-    actual = api_provider.request("GET", "bar")
     assert actual is None
 
 
-@pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.FORBIDDEN])
+@pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND])
 def test_error_response(api_provider: SyncApiProvider, response, status):
     response.status = int(status)
 
@@ -128,7 +129,7 @@ def test_error_response(api_provider: SyncApiProvider, response, status):
         api_provider.request("GET", "bar")
 
     assert e.value.status is status
-    assert str(e.value) == str(response.json())
+    assert str(e.value) == str(int(status)) + ": {'foo': 2}"
 
 
 @mock.patch(MODULE + ".PoolManager", new=mock.Mock())

+ 1 - 1
tests/api_client/test_sync_files.py

@@ -94,7 +94,7 @@ def test_download_fileobj_no_multipart(pool, responses_single):
         download_fileobj("some-url", None, chunk_size=64, pool=pool)
 
     assert e.value.status == 200
-    assert str(e.value) == "The file server does not support multipart downloads."
+    assert str(e.value) == "200: The file server does not support multipart downloads."
 
 
 def test_download_fileobj_forbidden(pool, responses_single):