Bladeren bron

Add headers parameter to ApiProvider.request (#42)

Casper van der Wel 1 jaar geleden
bovenliggende
commit
750819f5aa

+ 2 - 0
CHANGES.md

@@ -8,6 +8,8 @@
 
 - Added CCTokenGateway.fetch_headers()
 
+- Added optional 'headers' parameter to ApiProvider.
+
 
 0.9.1 (2023-11-23)
 ------------------

+ 12 - 4
clean_python/api_client/api_provider.py

@@ -123,6 +123,7 @@ class ApiProvider:
         json: Optional[Json],
         fields: Optional[Json],
         file: Optional[FileFormPost],
+        headers: Optional[Dict[str, str]],
         timeout: float,
     ) -> ClientResponse:
         if file is not None:
@@ -136,15 +137,20 @@ class ApiProvider:
             "json": json,
             "data": fields,
         }
+        actual_headers = {}
         if self._headers_factory is not None:
-            request_kwargs["headers"] = await self._headers_factory()
+            actual_headers.update(await self._headers_factory())
+        if headers:
+            actual_headers.update(headers)
         for attempt in range(self._retries):
             if attempt > 0:
                 backoff = self._backoff_factor * 2 ** (attempt - 1)
                 await asyncio.sleep(backoff)
 
             try:
-                response = await self._session.request(**request_kwargs)
+                response = await self._session.request(
+                    headers=actual_headers, **request_kwargs
+                )
                 await response.read()
             except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
                 if attempt == self._retries - 1:
@@ -163,10 +169,11 @@ class ApiProvider:
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
         file: Optional[FileFormPost] = None,
+        headers: Optional[Dict[str, str]] = None,
         timeout: float = 5.0,
     ) -> Optional[Json]:
         response = await self._request_with_retry(
-            method, path, params, json, fields, file, timeout
+            method, path, params, json, fields, file, headers, timeout
         )
         status = HTTPStatus(response.status)
         content_type = response.headers.get("Content-Type")
@@ -188,10 +195,11 @@ class ApiProvider:
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
         file: Optional[FileFormPost] = None,
+        headers: Optional[Dict[str, str]] = None,
         timeout: float = 5.0,
     ) -> Response:
         response = await self._request_with_retry(
-            method, path, params, json, fields, file, timeout
+            method, path, params, json, fields, file, headers, timeout
         )
         return Response(
             status=response.status,

+ 16 - 7
clean_python/api_client/sync_api_provider.py

@@ -58,9 +58,14 @@ class SyncApiProvider:
         json: Optional[Json],
         fields: Optional[Json],
         file: Optional[FileFormPost],
+        headers: Optional[Dict[str, str]],
         timeout: float,
     ):
-        headers = {}
+        actual_headers = {}
+        if self._headers_factory is not None:
+            actual_headers.update(self._headers_factory())
+        if headers:
+            actual_headers.update(headers)
         request_kwargs = {
             "method": method,
             "url": add_query_params(
@@ -75,7 +80,7 @@ class SyncApiProvider:
             raise ValueError("Cannot both specify 'json' and 'file'")
         elif json is not None:
             request_kwargs["body"] = json_lib.dumps(json).encode()
-            headers["Content-Type"] = "application/json"
+            actual_headers["Content-Type"] = "application/json"
         elif fields is not None and file is None:
             request_kwargs["fields"] = fields
             request_kwargs["encode_multipart"] = False
@@ -90,9 +95,7 @@ class SyncApiProvider:
             }
             request_kwargs["encode_multipart"] = True
 
-        if self._headers_factory is not None:
-            headers.update(self._headers_factory())
-        return self._pool.request(headers=headers, **request_kwargs)
+        return self._pool.request(headers=actual_headers, **request_kwargs)
 
     def request(
         self,
@@ -102,9 +105,12 @@ class SyncApiProvider:
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
         file: Optional[FileFormPost] = None,
+        headers: Optional[Dict[str, str]] = None,
         timeout: float = 5.0,
     ) -> Optional[Json]:
-        response = self._request(method, path, params, json, fields, file, timeout)
+        response = self._request(
+            method, path, params, json, fields, file, headers, timeout
+        )
         status = HTTPStatus(response.status)
         content_type = response.headers.get("Content-Type")
         if status is HTTPStatus.NO_CONTENT:
@@ -125,9 +131,12 @@ class SyncApiProvider:
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
         file: Optional[FileFormPost] = None,
+        headers: Optional[Dict[str, str]] = None,
         timeout: float = 5.0,
     ) -> Response:
-        response = self._request(method, path, params, json, fields, file, timeout)
+        response = self._request(
+            method, path, params, json, fields, file, headers, timeout
+        )
         return Response(
             status=response.status,
             data=response.data,

+ 15 - 0
tests/api_client/test_api_provider.py

@@ -187,3 +187,18 @@ async def test_conflict_with_message(api_provider: ApiProvider, response):
 
     with pytest.raises(Conflict, match="foo"):
         await api_provider.request("GET", "bar")
+
+
+async def test_custom_header(api_provider: ApiProvider):
+    await api_provider.request("POST", "bar", headers={"foo": "bar"})
+    assert api_provider._session.request.call_args[1]["headers"] == {
+        "foo": "bar",
+        **(await api_provider._headers_factory()),
+    }
+
+
+async def test_custom_header_precedes(api_provider: ApiProvider):
+    await api_provider.request("POST", "bar", headers={"Authorization": "bar"})
+    assert (
+        api_provider._session.request.call_args[1]["headers"]["Authorization"] == "bar"
+    )

+ 13 - 0
tests/api_client/test_sync_api_provider.py

@@ -221,3 +221,16 @@ def test_conflict_with_message(api_provider: SyncApiProvider, response):
 
     with pytest.raises(Conflict, match="foo"):
         api_provider.request("GET", "bar")
+
+
+def test_custom_header(api_provider: SyncApiProvider):
+    api_provider.request("POST", "bar", headers={"foo": "bar"})
+    assert api_provider._pool.request.call_args[1]["headers"] == {
+        "foo": "bar",
+        **api_provider._headers_factory(),
+    }
+
+
+def test_custom_header_precedes(api_provider: SyncApiProvider):
+    api_provider.request("POST", "bar", headers={"Authorization": "bar"})
+    assert api_provider._pool.request.call_args[1]["headers"]["Authorization"] == "bar"