Quellcode durchsuchen

Disable multipart encoding in SyncApiProvider (#25)

Casper van der Wel vor 1 Jahr
Ursprung
Commit
d7a75d560b

+ 4 - 2
CHANGES.md

@@ -4,7 +4,9 @@
 0.6.9 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Disable the default multipart encoding in `SyncApiProvider`.
+
+- Added `file` parameter to `ApiProvider` to upload files (async is a TODO).
 
 
 0.6.8 (2023-10-10)
@@ -16,7 +18,7 @@
 0.6.7 (2023-10-09)
 ------------------
 
-- Adapt call signature of the `fetch_token` callable in `ApiProvicer`.
+- Adapt call signature of the `fetch_token` callable in `ApiProvider`.
 
 - Add `clean_python.oauth.client_credentials`.
 

+ 27 - 3
clean_python/api_client/api_provider.py

@@ -1,6 +1,8 @@
 import asyncio
 import re
 from http import HTTPStatus
+from io import BytesIO
+from typing import Any
 from typing import Awaitable
 from typing import Callable
 from typing import Dict
@@ -13,13 +15,15 @@ import aiohttp
 from aiohttp import ClientResponse
 from aiohttp import ClientSession
 from pydantic import AnyHttpUrl
+from pydantic import field_validator
 
 from clean_python import Json
+from clean_python import ValueObject
 
 from .exceptions import ApiException
 from .response import Response
 
-__all__ = ["ApiProvider"]
+__all__ = ["ApiProvider", "FileFormPost"]
 
 
 RETRY_STATUSES = frozenset({413, 429, 503})  # like in urllib3
@@ -57,6 +61,21 @@ def add_query_params(url: str, params: Optional[Json]) -> str:
     return url + "?" + urlencode(params, doseq=True)
 
 
+class FileFormPost(ValueObject):
+    file_name: str
+    file: Any  # typing of BinaryIO / BytesIO is hard!
+    field_name: str = "file"
+    content_type: str = "application/octet-stream"
+
+    @field_validator("file")
+    @classmethod
+    def validate_file(cls, v):
+        if isinstance(v, bytes):
+            return BytesIO(v)
+        assert hasattr(v, "read")  # poor-mans BinaryIO validation
+        return v
+
+
 class ApiProvider:
     """Basic JSON API provider with retry policy and bearer tokens.
 
@@ -94,8 +113,11 @@ class ApiProvider:
         params: Optional[Json],
         json: Optional[Json],
         fields: Optional[Json],
+        file: Optional[FileFormPost],
         timeout: float,
     ) -> ClientResponse:
+        if file is not None:
+            raise NotImplementedError("ApiProvider doesn't yet support file uploads")
         request_kwargs = {
             "method": method,
             "url": add_query_params(
@@ -130,10 +152,11 @@ class ApiProvider:
         params: Optional[Json] = None,
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
+        file: Optional[FileFormPost] = None,
         timeout: float = 5.0,
     ) -> Optional[Json]:
         response = await self._request_with_retry(
-            method, path, params, json, fields, timeout
+            method, path, params, json, fields, file, timeout
         )
         status = HTTPStatus(response.status)
         content_type = response.headers.get("Content-Type")
@@ -156,10 +179,11 @@ class ApiProvider:
         params: Optional[Json] = None,
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
+        file: Optional[FileFormPost] = None,
         timeout: float = 5.0,
     ) -> Response:
         response = await self._request_with_retry(
-            method, path, params, json, fields, timeout
+            method, path, params, json, fields, file, timeout
         )
         return Response(
             status=response.status,

+ 21 - 3
clean_python/api_client/sync_api_provider.py

@@ -12,6 +12,7 @@ from urllib3 import Retry
 from clean_python import Json
 
 from .api_provider import add_query_params
+from .api_provider import FileFormPost
 from .api_provider import is_json_content_type
 from .api_provider import is_success
 from .api_provider import join
@@ -55,6 +56,7 @@ class SyncApiProvider:
         params: Optional[Json],
         json: Optional[Json],
         fields: Optional[Json],
+        file: Optional[FileFormPost],
         timeout: float,
     ):
         headers = {}
@@ -68,11 +70,25 @@ class SyncApiProvider:
         # for urllib3<2, we dump json ourselves
         if json is not None and fields is not None:
             raise ValueError("Cannot both specify 'json' and 'fields'")
+        elif json is not None and file is not None:
+            raise ValueError("Cannot both specify 'json' and 'file'")
         elif json is not None:
             request_kwargs["body"] = json_lib.dumps(json).encode()
             headers["Content-Type"] = "application/json"
-        elif fields is not None:
+        elif fields is not None and file is None:
             request_kwargs["fields"] = fields
+            request_kwargs["encode_multipart"] = False
+        elif file is not None:
+            request_kwargs["fields"] = {
+                file.field_name: (
+                    file.file_name,
+                    file.file.read(),
+                    file.content_type,
+                ),
+                **(fields or {}),
+            }
+            request_kwargs["encode_multipart"] = True
+
         headers.update(self._fetch_token())
         return self._pool.request(headers=headers, **request_kwargs)
 
@@ -83,9 +99,10 @@ class SyncApiProvider:
         params: Optional[Json] = None,
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
+        file: Optional[FileFormPost] = None,
         timeout: float = 5.0,
     ) -> Optional[Json]:
-        response = self._request(method, path, params, json, fields, timeout)
+        response = self._request(method, path, params, json, fields, file, timeout)
         status = HTTPStatus(response.status)
         content_type = response.headers.get("Content-Type")
         if status is HTTPStatus.NO_CONTENT:
@@ -107,9 +124,10 @@ class SyncApiProvider:
         params: Optional[Json] = None,
         json: Optional[Json] = None,
         fields: Optional[Json] = None,
+        file: Optional[FileFormPost] = None,
         timeout: float = 5.0,
     ) -> Response:
-        response = self._request(method, path, params, json, fields, timeout)
+        response = self._request(method, path, params, json, fields, file, timeout)
         return Response(
             status=response.status,
             data=response.data,

+ 18 - 6
integration_tests/fastapi_example/presentation.py

@@ -6,6 +6,7 @@ from typing import Optional
 
 from fastapi import Depends
 from fastapi import Form
+from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
 from fastapi.responses import JSONResponse
@@ -76,8 +77,8 @@ class V1Books(Resource, version=v(1), name="books"):
         return {"name": name}
 
     @post("/file")
-    async def file(self, file: UploadFile):
-        return {file.filename: (await file.read()).decode()}
+    async def file(self, file: UploadFile, description: str = Form()):
+        return {file.filename: (await file.read()).decode(), "description": description}
 
     @put("/urlencode/{name}", response_model=Author)
     async def urlencode(self, name: str):
@@ -86,19 +87,30 @@ class V1Books(Resource, version=v(1), name="books"):
     @post("/token")
     def token(
         self,
+        request: Request,
         grant_type: str = Form(),
         scope: str = Form(),
         credentials: HTTPBasicCredentials = Depends(basic),
     ):
         """For testing client credentials grant"""
+        if request.headers["Content-Type"] != "application/x-www-form-urlencoded":
+            return Response(status_code=HTTPStatus.METHOD_NOT_ALLOWED)
         if grant_type != "client_credentials":
-            return JSONResponse({"error": "invalid_grant"})
+            return JSONResponse(
+                {"error": "invalid_grant"}, status_code=HTTPStatus.BAD_REQUEST
+            )
         if credentials.username != "testclient":
-            return JSONResponse({"error": "invalid_client"})
+            return JSONResponse(
+                {"error": "invalid_client"}, status_code=HTTPStatus.BAD_REQUEST
+            )
         if credentials.password != "supersecret":
-            return JSONResponse({"error": "invalid_client"})
+            return JSONResponse(
+                {"error": "invalid_client"}, status_code=HTTPStatus.BAD_REQUEST
+            )
         if scope != "all":
-            return JSONResponse({"error": "invalid_grant"})
+            return JSONResponse(
+                {"error": "invalid_grant"}, status_code=HTTPStatus.BAD_REQUEST
+            )
         claims = {"user": "foo", "exp": int(time.time()) + 3600}
         payload = base64.b64encode(json.dumps(claims).encode()).decode()
         return {

+ 8 - 2
integration_tests/test_int_sync_api_provider.py

@@ -7,6 +7,7 @@ import pytest
 from clean_python import ctx
 from clean_python import Tenant
 from clean_python.api_client import ApiException
+from clean_python.api_client import FileFormPost
 from clean_python.api_client import SyncApiProvider
 
 
@@ -46,10 +47,15 @@ def test_request_form_body(provider: SyncApiProvider):
 
 
 def test_request_form_file(provider: SyncApiProvider):
-    response = provider.request("POST", "v1/file", fields={"file": ("x.txt", b"foo")})
+    response = provider.request(
+        "POST",
+        "v1/file",
+        fields={"description": "bla"},
+        file=FileFormPost(file_name="x.txt", file=b"foo"),
+    )
 
     assert isinstance(response, dict)
-    assert response["x.txt"] == "foo"
+    assert response == {"x.txt": "foo", "description": "bla"}
 
 
 @pytest.fixture

+ 44 - 0
tests/api_client/test_sync_api_provider.py

@@ -8,6 +8,7 @@ import pytest
 from clean_python import ctx
 from clean_python import Tenant
 from clean_python.api_client import ApiException
+from clean_python.api_client import FileFormPost
 from clean_python.api_client import SyncApiProvider
 
 MODULE = "clean_python.api_client.sync_api_provider"
@@ -159,3 +160,46 @@ def test_trailing_slash(api_provider: SyncApiProvider, path, trailing_slash, exp
         api_provider._pool.request.call_args[1]["url"]
         == "http://testserver/foo/" + expected
     )
+
+
+def test_post_file(api_provider: SyncApiProvider):
+    api_provider.request(
+        "POST",
+        "bar",
+        file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
+    )
+
+    assert api_provider._pool.request.call_count == 1
+
+    assert api_provider._pool.request.call_args[1] == dict(
+        method="POST",
+        url="http://testserver/foo/bar",
+        fields={"x": ("test.zip", b"foo", "application/octet-stream")},
+        headers={
+            "Authorization": "Bearer tenant-2",
+        },
+        timeout=5.0,
+        encode_multipart=True,
+    )
+
+
+def test_post_file_with_fields(api_provider: SyncApiProvider):
+    api_provider.request(
+        "POST",
+        "bar",
+        fields={"a": "b"},
+        file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
+    )
+
+    assert api_provider._pool.request.call_count == 1
+
+    assert api_provider._pool.request.call_args[1] == dict(
+        method="POST",
+        url="http://testserver/foo/bar",
+        fields={"a": "b", "x": ("test.zip", b"foo", "application/octet-stream")},
+        headers={
+            "Authorization": "Bearer tenant-2",
+        },
+        timeout=5.0,
+        encode_multipart=True,
+    )