Parcourir la source

Add async equivalents for API interfacing (#20)

Casper van der Wel il y a 1 an
Parent
commit
dc421efae4

+ 3 - 1
CHANGES.md

@@ -4,7 +4,9 @@
 0.6.5 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Added async `ApiProvider` and `ApiGateway`.
+
+- Added `request_raw` to `ApiProvider` for handling arbitrary responses.
 
 
 0.6.4 (2023-10-03)

+ 1 - 0
clean_python/api_client/__init__.py

@@ -2,3 +2,4 @@ from .api_gateway import *  # NOQA
 from .api_provider import *  # NOQA
 from .exceptions import *  # NOQA
 from .files import *  # NOQA
+from .sync_api_provider import *  # NOQA

+ 72 - 3
clean_python/api_client/api_gateway.py

@@ -5,15 +5,84 @@ from typing import Optional
 import inject
 
 from clean_python import DoesNotExist
+from clean_python import Gateway
 from clean_python import Id
 from clean_python import Json
 from clean_python import Mapper
+from clean_python import SyncGateway
 
-from .. import SyncGateway
-from .api_provider import SyncApiProvider
+from .api_provider import ApiProvider
 from .exceptions import ApiException
+from .sync_api_provider import SyncApiProvider
 
-__all__ = ["SyncApiGateway"]
+__all__ = ["ApiGateway", "SyncApiGateway"]
+
+
+class ApiGateway(Gateway):
+    path: str
+    mapper = Mapper()
+
+    def __init__(self, provider_override: Optional[ApiProvider] = None):
+        self.provider_override = provider_override
+
+    def __init_subclass__(cls, path: str) -> None:
+        assert not path.startswith("/")
+        assert "{id}" in path
+        cls.path = path
+        super().__init_subclass__()
+
+    @property
+    def provider(self) -> ApiProvider:
+        return self.provider_override or inject.instance(ApiProvider)
+
+    async def get(self, id: Id) -> Optional[Json]:
+        try:
+            result = await self.provider.request("GET", self.path.format(id=id))
+            assert result is not None
+            return self.mapper.to_internal(result)
+        except ApiException as e:
+            if e.status is HTTPStatus.NOT_FOUND:
+                return None
+            raise e
+
+    async def add(self, item: Json) -> Json:
+        item = self.mapper.to_external(item)
+        result = await self.provider.request("POST", self.path.format(id=""), json=item)
+        assert result is not None
+        return self.mapper.to_internal(result)
+
+    async def remove(self, id: Id) -> bool:
+        try:
+            await self.provider.request("DELETE", self.path.format(id=id)) is not None
+        except ApiException as e:
+            if e.status is HTTPStatus.NOT_FOUND:
+                return False
+            raise e
+        else:
+            return True
+
+    async def update(
+        self, item: Json, if_unmodified_since: Optional[datetime] = None
+    ) -> Json:
+        if if_unmodified_since is not None:
+            raise NotImplementedError("if_unmodified_since not implemented")
+        item = self.mapper.to_external(item)
+        id_ = item.pop("id", None)
+        if id_ is None:
+            raise DoesNotExist("resource", id_)
+        try:
+            result = await self.provider.request(
+                "PATCH", self.path.format(id=id_), json=item
+            )
+            assert result is not None
+            return self.mapper.to_internal(result)
+        except ApiException as e:
+            if e.status is HTTPStatus.NOT_FOUND:
+                raise DoesNotExist("resource", id_)
+            raise e
+
+
+# This is a copy-paste of ApiGateway:
 
 
 class SyncApiGateway(SyncGateway):

+ 74 - 24
clean_python/api_client/api_provider.py

@@ -1,4 +1,4 @@
-import json as json_lib
+import asyncio
 import re
 from http import HTTPStatus
 from typing import Callable
@@ -7,16 +7,21 @@ from urllib.parse import quote
 from urllib.parse import urlencode
 from urllib.parse import urljoin
 
+import aiohttp
+from aiohttp import ClientResponse
+from aiohttp import ClientSession
 from pydantic import AnyHttpUrl
-from urllib3 import PoolManager
-from urllib3 import Retry
 
 from clean_python import ctx
 from clean_python import Json
 
 from .exceptions import ApiException
+from .response import Response
 
-__all__ = ["SyncApiProvider"]
+__all__ = ["ApiProvider"]
+
+
+RETRY_STATUSES = frozenset({413, 429, 503})  # like in urllib3
 
 
 def is_success(status: HTTPStatus) -> bool:
@@ -49,7 +54,7 @@ def add_query_params(url: str, params: Optional[Json]) -> str:
     return url + "?" + urlencode(params, doseq=True)
 
 
-class SyncApiProvider:
+class ApiProvider:
     """Basic JSON API provider with retry policy and bearer tokens.
 
     The default retry policy has 3 retries with 1, 2, 4 second intervals.
@@ -64,43 +69,70 @@ class SyncApiProvider:
     def __init__(
         self,
         url: AnyHttpUrl,
-        fetch_token: Callable[[PoolManager, int], Optional[str]],
+        fetch_token: Callable[[ClientSession, int], Optional[str]],
         retries: int = 3,
         backoff_factor: float = 1.0,
     ):
         self._url = str(url)
         assert self._url.endswith("/")
         self._fetch_token = fetch_token
-        self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))
+        assert retries > 0
+        self._retries = retries
+        self._backoff_factor = backoff_factor
+        self._session = ClientSession()
 
-    def request(
+    async def _request_with_retry(
         self,
         method: str,
         path: str,
-        params: Optional[Json] = None,
-        json: Optional[Json] = None,
-        fields: Optional[Json] = None,
-        timeout: float = 5.0,
-    ) -> Optional[Json]:
+        params: Optional[Json],
+        json: Optional[Json],
+        fields: Optional[Json],
+        timeout: float,
+    ) -> ClientResponse:
         assert ctx.tenant is not None
         headers = {}
         request_kwargs = {
             "method": method,
             "url": add_query_params(join(self._url, quote(path)), params),
             "timeout": timeout,
+            "json": json,
+            "data": fields,
         }
-        # for urllib3<2, we dump json ourselves
-        if json is not None and fields is not None:
-            raise ValueError("Cannot both specify 'json' and 'fields'")
-        elif json is not None:
-            request_kwargs["body"] = json_lib.dumps(json).encode()
-            headers["Content-Type"] = "application/json"
-        elif fields is not None:
-            request_kwargs["fields"] = fields
-        token = self._fetch_token(self._pool, ctx.tenant.id)
+        token = self._fetch_token(self._session, ctx.tenant.id)
         if token is not None:
             headers["Authorization"] = f"Bearer {token}"
-        response = self._pool.request(headers=headers, **request_kwargs)
+        for attempt in range(self._retries):
+            if attempt > 0:
+                backoff = self._backoff_factor * 2 ** (attempt - 1)
+                await asyncio.sleep(backoff)
+
+            try:
+                response = await self._session.request(
+                    headers=headers, **request_kwargs
+                )
+                await response.read()
+            except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
+                if attempt == self._retries - 1:
+                    raise  # propagate ClientError in case no retries left
+            else:
+                if response.status not in RETRY_STATUSES:
+                    return response  # on all non-retry statuses: return response
+
+        return response  # retries exceeded; return the (possibly error) response
+
+    async def request(
+        self,
+        method: str,
+        path: str,
+        params: Optional[Json] = None,
+        json: Optional[Json] = None,
+        fields: Optional[Json] = None,
+        timeout: float = 5.0,
+    ) -> Optional[Json]:
+        response = await self._request_with_retry(
+            method, path, params, json, fields, timeout
+        )
         status = HTTPStatus(response.status)
         content_type = response.headers.get("Content-Type")
         if status is HTTPStatus.NO_CONTENT:
@@ -109,8 +141,26 @@ class SyncApiProvider:
             raise ApiException(
                 f"Unexpected content type '{content_type}'", status=status
             )
-        body = json_lib.loads(response.data.decode())
+        body = await response.json()
         if is_success(status):
             return body
         else:
             raise ApiException(body, status=status)
+
+    async def request_raw(
+        self,
+        method: str,
+        path: str,
+        params: Optional[Json] = None,
+        json: Optional[Json] = None,
+        fields: Optional[Json] = None,
+        timeout: float = 5.0,
+    ) -> Response:
+        response = await self._request_with_retry(
+            method, path, params, json, fields, timeout
+        )
+        return Response(
+            status=response.status,
+            data=await response.read(),
+            content_type=response.headers.get("Content-Type"),
+        )

+ 12 - 0
clean_python/api_client/response.py

@@ -0,0 +1,12 @@
+from http import HTTPStatus
+from typing import Optional
+
+from clean_python import ValueObject
+
+__all__ = ["Response"]
+
+
+class Response(ValueObject):
+    status: HTTPStatus
+    data: bytes
+    content_type: Optional[str]

+ 115 - 0
clean_python/api_client/sync_api_provider.py

@@ -0,0 +1,115 @@
+import json as json_lib
+from http import HTTPStatus
+from typing import Callable
+from typing import Optional
+from urllib.parse import quote
+
+from pydantic import AnyHttpUrl
+from urllib3 import PoolManager
+from urllib3 import Retry
+
+from clean_python import ctx
+from clean_python import Json
+
+from .api_provider import add_query_params
+from .api_provider import is_json_content_type
+from .api_provider import is_success
+from .api_provider import join
+from .exceptions import ApiException
+from .response import Response
+
+__all__ = ["SyncApiProvider"]
+
+
+class SyncApiProvider:
+    """Basic JSON API provider with retry policy and bearer tokens.
+
+    The default retry policy has 3 retries with 1, 2, 4 second intervals.
+
+    Args:
+        url: The url of the API (with trailing slash)
+        fetch_token: Callable that returns a token for a tenant id
+        retries: Total number of retries per request
+        backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
+    """
+
+    def __init__(
+        self,
+        url: AnyHttpUrl,
+        fetch_token: Callable[[PoolManager, int], Optional[str]],
+        retries: int = 3,
+        backoff_factor: float = 1.0,
+    ):
+        self._url = str(url)
+        assert self._url.endswith("/")
+        self._fetch_token = fetch_token
+        self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))
+
+    def _request(
+        self,
+        method: str,
+        path: str,
+        params: Optional[Json],
+        json: Optional[Json],
+        fields: Optional[Json],
+        timeout: float,
+    ):
+        assert ctx.tenant is not None
+        headers = {}
+        request_kwargs = {
+            "method": method,
+            "url": add_query_params(join(self._url, quote(path)), params),
+            "timeout": timeout,
+        }
+        # for urllib3<2, we dump json ourselves
+        if json is not None and fields is not None:
+            raise ValueError("Cannot both specify 'json' and 'fields'")
+        elif json is not None:
+            request_kwargs["body"] = json_lib.dumps(json).encode()
+            headers["Content-Type"] = "application/json"
+        elif fields is not None:
+            request_kwargs["fields"] = fields
+        token = self._fetch_token(self._pool, ctx.tenant.id)
+        if token is not None:
+            headers["Authorization"] = f"Bearer {token}"
+        return self._pool.request(headers=headers, **request_kwargs)
+
+    def request(
+        self,
+        method: str,
+        path: str,
+        params: Optional[Json] = None,
+        json: Optional[Json] = None,
+        fields: Optional[Json] = None,
+        timeout: float = 5.0,
+    ) -> Optional[Json]:
+        response = self._request(method, path, params, json, fields, timeout)
+        status = HTTPStatus(response.status)
+        content_type = response.headers.get("Content-Type")
+        if status is HTTPStatus.NO_CONTENT:
+            return None
+        if not is_json_content_type(content_type):
+            raise ApiException(
+                f"Unexpected content type '{content_type}'", status=status
+            )
+        body = json_lib.loads(response.data.decode())
+        if is_success(status):
+            return body
+        else:
+            raise ApiException(body, status=status)
+
+    def request_raw(
+        self,
+        method: str,
+        path: str,
+        params: Optional[Json] = None,
+        json: Optional[Json] = None,
+        fields: Optional[Json] = None,
+        timeout: float = 5.0,
+    ) -> Response:
+        response = self._request(method, path, params, json, fields, timeout)
+        return Response(
+            status=response.status,
+            data=response.data,
+            content_type=response.headers.get("Content-Type"),
+        )

+ 62 - 0
integration_tests/test_int_api_gateway.py

@@ -0,0 +1,62 @@
+import pytest
+
+from clean_python import ctx
+from clean_python import DoesNotExist
+from clean_python import Json
+from clean_python import Tenant
+from clean_python.api_client import ApiGateway
+from clean_python.api_client import ApiProvider
+
+
+class BooksGateway(ApiGateway, path="v1/books/{id}"):
+    pass
+
+
+@pytest.fixture
+def provider(fastapi_example_app) -> ApiProvider:
+    ctx.tenant = Tenant(id=2, name="")
+    yield ApiProvider(fastapi_example_app + "/", lambda a, b: "token")
+    ctx.tenant = None
+
+
+@pytest.fixture
+def gateway(provider) -> ApiGateway:
+    return BooksGateway(provider)
+
+
+@pytest.fixture
+async def book(gateway: ApiGateway):
+    return await gateway.add({"title": "fixture", "author": {"name": "foo"}})
+
+
+async def test_add(gateway: ApiGateway):
+    response = await gateway.add({"title": "test_add", "author": {"name": "foo"}})
+    assert isinstance(response["id"], int)
+    assert response["title"] == "test_add"
+    assert response["author"] == {"name": "foo"}
+    assert response["created_at"] == response["updated_at"]
+
+
+async def test_get(gateway: ApiGateway, book: Json):
+    response = await gateway.get(book["id"])
+    assert response == book
+
+
+async def test_remove_and_404(gateway: ApiGateway, book: Json):
+    assert await gateway.remove(book["id"]) is True
+    assert await gateway.get(book["id"]) is None
+    assert await gateway.remove(book["id"]) is False
+
+
+async def test_update(gateway: ApiGateway, book: Json):
+    response = await gateway.update({"id": book["id"], "title": "test_update"})
+
+    assert response["id"] == book["id"]
+    assert response["title"] == "test_update"
+    assert response["author"] == {"name": "foo"}
+    assert response["created_at"] != response["updated_at"]
+
+
+async def test_update_404(gateway: ApiGateway):
+    with pytest.raises(DoesNotExist):
+        await gateway.update({"id": 123456, "title": "test_update_404"})

+ 103 - 0
integration_tests/test_int_api_provider.py

@@ -0,0 +1,103 @@
+from http import HTTPStatus
+
+import pytest
+
+from clean_python import ctx
+from clean_python import Tenant
+from clean_python.api_client import ApiException
+from clean_python.api_client import ApiProvider
+
+
+@pytest.fixture
+def provider(fastapi_example_app) -> ApiProvider:
+    ctx.tenant = Tenant(id=2, name="")
+    yield ApiProvider(fastapi_example_app + "/", lambda a, b: "token")
+    ctx.tenant = None
+
+
+async def test_request_params(provider: ApiProvider):
+    response = await provider.request(
+        "GET", "v1/books", params={"limit": 10, "offset": 2}
+    )
+
+    assert isinstance(response, dict)
+
+    assert response["limit"] == 10
+    assert response["offset"] == 2
+
+
+async def test_request_json_body(provider: ApiProvider):
+    response = await provider.request(
+        "POST", "v1/books", json={"title": "test_body", "author": {"name": "foo"}}
+    )
+
+    assert isinstance(response, dict)
+    assert response["title"] == "test_body"
+    assert response["author"] == {"name": "foo"}
+
+
+async def test_request_form_body(provider: ApiProvider):
+    response = await provider.request("POST", "v1/form", fields={"name": "foo"})
+
+    assert isinstance(response, dict)
+    assert response["name"] == "foo"
+
+
+# files are not supported (yet)
+#
+# async def test_request_form_file(provider: ApiProvider):
+#     response = await provider.request("POST", "v1/file", fields={"file": ("x.txt", b"foo")})
+
+#     assert isinstance(response, dict)
+#     assert response["x.txt"] == "foo"
+
+
+@pytest.fixture
+async def book(provider: ApiProvider):
+    return await provider.request(
+        "POST", "v1/books", json={"title": "fixture", "author": {"name": "foo"}}
+    )
+
+
+async def test_no_content(provider: ApiProvider, book):
+    response = await provider.request("DELETE", f"v1/books/{book['id']}")
+
+    assert response is None
+
+
+async def test_not_found(provider: ApiProvider):
+    with pytest.raises(ApiException) as e:
+        await provider.request("GET", "v1/book")
+
+    assert e.value.status is HTTPStatus.NOT_FOUND
+    assert e.value.args[0] == {"detail": "Not Found"}
+
+
+async def test_bad_request(provider: ApiProvider):
+    with pytest.raises(ApiException) as e:
+        await provider.request("GET", "v1/books", params={"limit": "foo"})
+
+    assert e.value.status is HTTPStatus.BAD_REQUEST
+    assert e.value.args[0]["detail"][0]["loc"] == ["query", "limit"]
+
+
+async def test_no_json_response(provider: ApiProvider):
+    with pytest.raises(ApiException) as e:
+        await provider.request("GET", "v1/text")
+
+    assert e.value.args[0] == "Unexpected content type 'text/plain; charset=utf-8'"
+
+
+async def test_urlencode(provider: ApiProvider):
+    response = await provider.request("PUT", "v1/urlencode/x?")
+
+    assert isinstance(response, dict)
+    assert response["name"] == "x?"
+
+
+async def test_request_raw(provider: ApiProvider, book):
+    response = await provider.request_raw("GET", f"v1/books/{book['id']}")
+
+    assert response.status is HTTPStatus.OK
+    assert len(response.data) > 0
+    assert response.content_type == "application/json"

+ 2 - 0
integration_tests/test_api_gateway.py → integration_tests/test_int_sync_api_gateway.py

@@ -1,3 +1,5 @@
+# This module is a copy paste of test_int_api_gateway.py
+
 import pytest
 
 from clean_python import ctx

+ 10 - 0
integration_tests/test_api_provider.py → integration_tests/test_int_sync_api_provider.py

@@ -1,3 +1,5 @@
+# This module is a copy paste of test_int_api_provider.py
+
 from http import HTTPStatus
 
 import pytest
@@ -89,3 +91,11 @@ def test_urlencode(provider: SyncApiProvider):
 
     assert isinstance(response, dict)
     assert response["name"] == "x?"
+
+
+def test_request_raw(provider: SyncApiProvider, book):
+    response = provider.request_raw("GET", f"v1/books/{book['id']}")
+
+    assert response.status is HTTPStatus.OK
+    assert len(response.data) > 0
+    assert response.content_type == "application/json"

+ 1 - 1
pyproject.toml

@@ -30,7 +30,7 @@ celery = ["pika"]
 fluentbit = ["fluent-logger"]
 sql = ["sqlalchemy==2.*", "asyncpg"]
 s3 = ["aioboto3", "boto3"]
-api_client = ["urllib3"]
+api_client = ["aiohttp", "urllib3"]
 profiler = ["yappi"]
 debugger = ["debugpy"]
 

+ 131 - 0
tests/api_client/test_api_gateway.py

@@ -0,0 +1,131 @@
+from http import HTTPStatus
+from unittest import mock
+
+import pytest
+
+from clean_python import DoesNotExist
+from clean_python import Json
+from clean_python import Mapper
+from clean_python.api_client import ApiException
+from clean_python.api_client import ApiGateway
+from clean_python.api_client import ApiProvider
+
+
+class TstApiGateway(ApiGateway, path="foo/{id}"):
+    pass
+
+
+@pytest.fixture
+def api_provider():
+    return mock.MagicMock(spec_set=ApiProvider)
+
+
+@pytest.fixture
+def api_gateway(api_provider) -> ApiGateway:
+    return TstApiGateway(api_provider)
+
+
+async def test_get(api_gateway: ApiGateway):
+    actual = await api_gateway.get(14)
+
+    api_gateway.provider.request.assert_called_once_with("GET", "foo/14")
+    assert actual is api_gateway.provider.request.return_value
+
+
+async def test_add(api_gateway: ApiGateway):
+    actual = await api_gateway.add({"foo": 2})
+
+    api_gateway.provider.request.assert_called_once_with(
+        "POST", "foo/", json={"foo": 2}
+    )
+    assert actual is api_gateway.provider.request.return_value
+
+
+async def test_remove(api_gateway: ApiGateway):
+    actual = await api_gateway.remove(2)
+
+    api_gateway.provider.request.assert_called_once_with("DELETE", "foo/2")
+    assert actual is True
+
+
+async def test_remove_does_not_exist(api_gateway: ApiGateway):
+    api_gateway.provider.request.side_effect = ApiException(
+        {}, status=HTTPStatus.NOT_FOUND
+    )
+    actual = await api_gateway.remove(2)
+    assert actual is False
+
+
+async def test_update(api_gateway: ApiGateway):
+    actual = await api_gateway.update({"id": 2, "foo": "bar"})
+
+    api_gateway.provider.request.assert_called_once_with(
+        "PATCH", "foo/2", json={"foo": "bar"}
+    )
+    assert actual is api_gateway.provider.request.return_value
+
+
+async def test_update_no_id(api_gateway: ApiGateway):
+    with pytest.raises(DoesNotExist):
+        await api_gateway.update({"foo": "bar"})
+
+    assert not api_gateway.provider.request.called
+
+
+async def test_update_does_not_exist(api_gateway: ApiGateway):
+    api_gateway.provider.request.side_effect = ApiException(
+        {}, status=HTTPStatus.NOT_FOUND
+    )
+    with pytest.raises(DoesNotExist):
+        await api_gateway.update({"id": 2, "foo": "bar"})
+
+
+class TstMapper(Mapper):
+    def to_external(self, internal: Json) -> Json:
+        result = {}
+        if internal.get("id") is not None:
+            result["id"] = internal["id"]
+        if internal.get("name") is not None:
+            result["name"] = internal["name"].upper()
+        return result
+
+    def to_internal(self, external: Json) -> Json:
+        return {"id": external["id"], "name": external["name"].lower()}
+
+
+class TstMappedApiGateway(ApiGateway, path="foo/{id}"):
+    mapper = TstMapper()
+
+
+@pytest.fixture
+def mapped_api_gateway(api_provider) -> ApiGateway:
+    return TstMappedApiGateway(api_provider)
+
+
+async def test_get_with_mapper(mapped_api_gateway: ApiGateway):
+    mapped_api_gateway.provider.request.return_value = {"id": 14, "name": "FOO"}
+
+    assert await mapped_api_gateway.get(14) == {"id": 14, "name": "foo"}
+
+
+async def test_add_with_mapper(mapped_api_gateway: ApiGateway):
+    mapped_api_gateway.provider.request.return_value = {"id": 3, "name": "FOO"}
+
+    assert await mapped_api_gateway.add({"name": "foo"}) == {"id": 3, "name": "foo"}
+
+    mapped_api_gateway.provider.request.assert_called_once_with(
+        "POST", "foo/", json={"name": "FOO"}
+    )
+
+
+async def test_update_with_mapper(mapped_api_gateway: ApiGateway):
+    mapped_api_gateway.provider.request.return_value = {"id": 2, "name": "BAR"}
+
+    assert await mapped_api_gateway.update({"id": 2, "name": "bar"}) == {
+        "id": 2,
+        "name": "bar",
+    }
+
+    mapped_api_gateway.provider.request.assert_called_once_with(
+        "PATCH", "foo/2", json={"name": "BAR"}
+    )

+ 144 - 0
tests/api_client/test_api_provider.py

@@ -0,0 +1,144 @@
+from http import HTTPStatus
+from unittest import mock
+
+import pytest
+from aiohttp import ClientSession
+
+from clean_python import ctx
+from clean_python import Tenant
+from clean_python.api_client import ApiException
+from clean_python.api_client import ApiProvider
+
+MODULE = "clean_python.api_client.api_provider"
+
+
+@pytest.fixture
+def tenant() -> Tenant:
+    ctx.tenant = Tenant(id=2, name="")
+    yield ctx.tenant
+    ctx.tenant = None
+
+
+@pytest.fixture
+def response():
+    # this mocks the aiohttp.ClientResponse:
+    response = mock.Mock()
+    response.status = int(HTTPStatus.OK)
+    response.headers = {"Content-Type": "application/json"}
+    response.json = mock.AsyncMock(return_value={"foo": 2})
+    response.read = mock.AsyncMock()
+    return response
+
+
+@pytest.fixture
+def api_provider(tenant, response) -> ApiProvider:
+    request = mock.AsyncMock()
+    with mock.patch.object(ClientSession, "request", new=request):
+        api_provider = ApiProvider(
+            url="http://testserver/foo/",
+            fetch_token=lambda a, b: f"tenant-{b}",
+        )
+        api_provider._session.request.return_value = response
+        yield api_provider
+
+
+async def test_get(api_provider: ApiProvider, response):
+    actual = await api_provider.request("GET", "")
+
+    assert api_provider._session.request.call_count == 1
+    assert api_provider._session.request.call_args[1] == dict(
+        method="GET",
+        url="http://testserver/foo",
+        headers={"Authorization": "Bearer tenant-2"},
+        timeout=5.0,
+        data=None,
+        json=None,
+    )
+    assert actual == {"foo": 2}
+
+
+async def test_post_json(api_provider: ApiProvider, response):
+    response.status == int(HTTPStatus.CREATED)
+    api_provider._session.request.return_value = response
+    actual = await api_provider.request("POST", "bar", json={"foo": 2})
+
+    assert api_provider._session.request.call_count == 1
+
+    assert api_provider._session.request.call_args[1] == dict(
+        method="POST",
+        url="http://testserver/foo/bar",
+        data=None,
+        json={"foo": 2},
+        headers={
+            "Authorization": "Bearer tenant-2",
+        },
+        timeout=5.0,
+    )
+    assert actual == {"foo": 2}
+
+
+@pytest.mark.parametrize(
+    "path,params,expected_url",
+    [
+        ("", None, "http://testserver/foo"),
+        ("bar", None, "http://testserver/foo/bar"),
+        ("bar/", None, "http://testserver/foo/bar"),
+        ("", {"a": 2}, "http://testserver/foo?a=2"),
+        ("bar", {"a": 2}, "http://testserver/foo/bar?a=2"),
+        ("bar/", {"a": 2}, "http://testserver/foo/bar?a=2"),
+        ("", {"a": [1, 2]}, "http://testserver/foo?a=1&a=2"),
+        ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
+    ],
+)
+async def test_url(api_provider: ApiProvider, path, params, expected_url):
+    await api_provider.request("GET", path, params=params)
+    assert api_provider._session.request.call_args[1]["url"] == expected_url
+
+
+async def test_timeout(api_provider: ApiProvider):
+    await api_provider.request("POST", "bar", timeout=2.1)
+    assert api_provider._session.request.call_args[1]["timeout"] == 2.1
+
+
+@pytest.mark.parametrize(
+    "status", [HTTPStatus.OK, HTTPStatus.NOT_FOUND, HTTPStatus.INTERNAL_SERVER_ERROR]
+)
+async def test_unexpected_content_type(api_provider: ApiProvider, response, status):
+    response.status = int(status)
+    response.headers["Content-Type"] = "text/plain"
+    with pytest.raises(ApiException) as e:
+        await api_provider.request("GET", "bar")
+
+    assert e.value.status is status
+    assert str(e.value) == f"{status}: Unexpected content type 'text/plain'"
+
+
+async def test_json_variant_content_type(api_provider: ApiProvider, response):
+    response.headers["Content-Type"] = "application/something+json"
+    actual = await api_provider.request("GET", "bar")
+    assert actual == {"foo": 2}
+
+
+async def test_no_content(api_provider: ApiProvider, response):
+    response.status = int(HTTPStatus.NO_CONTENT)
+    response.headers = {}
+
+    actual = await api_provider.request("DELETE", "bar/2")
+    assert actual is None
+
+
+@pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND])
+async def test_error_response(api_provider: ApiProvider, response, status):
+    response.status = int(status)
+
+    with pytest.raises(ApiException) as e:
+        await api_provider.request("GET", "bar")
+
+    assert e.value.status is status
+    assert str(e.value) == str(int(status)) + ": {'foo': 2}"
+
+
+async def test_no_token(api_provider: ApiProvider):
+    api_provider._fetch_token = lambda a, b: None
+    await api_provider.request("GET", "")
+    assert api_provider._session.request.call_args[1]["headers"] == {}

+ 2 - 2
tests/api_client/test_sync_api_gateway.py

@@ -1,3 +1,5 @@
+# This module is a copy paste of test_api_gateway.py
+
 from http import HTTPStatus
 from unittest import mock
 
@@ -10,8 +12,6 @@ from clean_python.api_client import ApiException
 from clean_python.api_client import SyncApiGateway
 from clean_python.api_client import SyncApiProvider
 
-MODULE = "clean_python.api_client.api_provider"
-
 
 class TstSyncApiGateway(SyncApiGateway, path="foo/{id}"):
     pass

+ 3 - 1
tests/api_client/test_sync_api_provider.py

@@ -1,3 +1,5 @@
+# This module is a copy paste of test_api_provider.py
+
 from http import HTTPStatus
 from unittest import mock
 
@@ -8,7 +10,7 @@ from clean_python import Tenant
 from clean_python.api_client import ApiException
 from clean_python.api_client import SyncApiProvider
 
-MODULE = "clean_python.api_client.api_provider"
+MODULE = "clean_python.api_client.sync_api_provider"
 
 
 @pytest.fixture