Przeglądaj źródła

Inject headers into download/upload functions (#40)

Casper van der Wel 1 rok temu
rodzic
commit
125c5893ef

+ 6 - 1
CHANGES.md

@@ -4,7 +4,12 @@
 0.9.1 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Renamed 'fetch_token' parameter in api client to 'headers_factory' and
+  made it optional.
+
+- Added 'headers_factory' to upload/download functions.
+
+- Allow 201 "CREATED" status code in upload_file.
 
 
 0.9.0 (2023-11-22)

+ 6 - 4
clean_python/api_client/api_provider.py

@@ -91,15 +91,16 @@ class ApiProvider:
 
     Args:
         url: The url of the API (with trailing slash)
-        fetch_token: Coroutine that returns headers for authorization
+        headers_factory: Coroutine that returns headers (e.g. for authorization)
         retries: Total number of retries per request
         backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
+        trailing_slash: Whether to automatically add or remove trailing slashes.
     """
 
     def __init__(
         self,
         url: AnyHttpUrl,
-        fetch_token: Callable[[], Awaitable[Dict[str, str]]],
+        headers_factory: Optional[Callable[[], Awaitable[Dict[str, str]]]] = None,
         retries: int = 3,
         backoff_factor: float = 1.0,
         trailing_slash: bool = False,
@@ -107,7 +108,7 @@ class ApiProvider:
         self._url = str(url)
         if not self._url.endswith("/"):
             self._url += "/"
-        self._fetch_token = fetch_token
+        self._headers_factory = headers_factory
         assert retries > 0
         self._retries = retries
         self._backoff_factor = backoff_factor
@@ -134,8 +135,9 @@ class ApiProvider:
             "timeout": timeout,
             "json": json,
             "data": fields,
-            "headers": await self._fetch_token(),
         }
+        if self._headers_factory is not None:
+            request_kwargs["headers"] = await self._headers_factory()
         for attempt in range(self._retries):
             if attempt > 0:
                 backoff = self._backoff_factor * 2 ** (attempt - 1)

+ 24 - 4
clean_python/api_client/files.py

@@ -3,9 +3,11 @@ import hashlib
 import logging
 import os
 import re
+from http import HTTPStatus
 from pathlib import Path
 from typing import BinaryIO
 from typing import Callable
+from typing import Dict
 from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -58,6 +60,7 @@ def download_file(
     timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
     pool: Optional[urllib3.PoolManager] = None,
     callback_func: Optional[Callable[[int, int], None]] = None,
+    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
 ) -> Tuple[Path, int]:
     """Download a file to a specified path on disk.
 
@@ -75,6 +78,7 @@ def download_file(
             created with a retry policy of 3 retries after 1, 2, 4 seconds.
         callback_func: optional function used to receive: bytes_downloaded, total_bytes
             for example: def callback(bytes_downloaded: int, total_bytes: int) -> None
+        headers_factory: optional function to inject headers
 
     Returns:
         Tuple of file path, total number of downloaded bytes.
@@ -105,6 +109,7 @@ def download_file(
                 timeout=timeout,
                 pool=pool,
                 callback_func=callback_func,
+                headers_factory=headers_factory,
             )
     except Exception:
         # Clean up a partially downloaded file
@@ -124,6 +129,7 @@ def download_fileobj(
     timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
     pool: Optional[urllib3.PoolManager] = None,
     callback_func: Optional[Callable[[int, int], None]] = None,
+    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
 ) -> int:
     """Download a url to a file object using multiple requests.
 
@@ -139,6 +145,7 @@ def download_fileobj(
             created with a retry policy of 3 retries after 1, 2, 4 seconds.
         callback_func: optional function used to receive: bytes_downloaded, total_bytes
             for example: def callback(bytes_downloaded: int, total_bytes: int) -> None
+        headers_factory: optional function to inject headers
 
     Returns:
         The total number of downloaded bytes.
@@ -156,6 +163,12 @@ def download_fileobj(
     """
     if pool is None:
         pool = get_pool()
+    if headers_factory is not None:
+        base_headers = headers_factory()
+        if any(x.lower() == "range" for x in base_headers):
+            raise ValueError("Cannot set the Range header through header_factory")
+    else:
+        base_headers = {}
 
     # Our strategy here is to just start downloading chunks while monitoring
     # the Content-Range header to check if we're done. Although we could get
@@ -165,7 +178,7 @@ def download_fileobj(
     while True:
         # download a chunk
         stop = start + chunk_size - 1
-        headers = {"Range": "bytes={}-{}".format(start, stop)}
+        headers = {"Range": "bytes={}-{}".format(start, stop), **base_headers}
 
         response = pool.request(
             "GET",
@@ -173,12 +186,12 @@ def download_fileobj(
             headers=headers,
             timeout=timeout,
         )
-        if response.status == 200:
+        if response.status == HTTPStatus.OK:
             raise ApiException(
                 "The file server does not support multipart downloads.",
                 status=response.status,
             )
-        elif response.status != 206:
+        elif response.status != HTTPStatus.PARTIAL_CONTENT:
             raise ApiException("Unexpected status", status=response.status)
 
         # write to file
@@ -210,6 +223,7 @@ def upload_file(
     pool: Optional[urllib3.PoolManager] = None,
     md5: Optional[bytes] = None,
     callback_func: Optional[Callable[[int, int], None]] = None,
+    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
 ) -> int:
     """Upload a file at specified file path to a url.
 
@@ -230,6 +244,7 @@ def upload_file(
             should be included in the signing procedure.
         callback_func: optional function used to receive: bytes_uploaded, total_bytes
             for example: def callback(bytes_uploaded: int, total_bytes: int) -> None
+        headers_factory: optional function to inject headers
 
     Returns:
         The total number of uploaded bytes.
@@ -257,6 +272,7 @@ def upload_file(
             pool=pool,
             md5=md5,
             callback_func=callback_func,
+            headers_factory=headers_factory,
         )
 
     return size
@@ -311,6 +327,7 @@ def upload_fileobj(
     pool: Optional[urllib3.PoolManager] = None,
     md5: Optional[bytes] = None,
     callback_func: Optional[Callable[[int, int], None]] = None,
+    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
 ) -> int:
     """Upload a file object to a url.
 
@@ -331,6 +348,7 @@ def upload_fileobj(
             should be included in the signing procedure.
         callback_func: optional function used to receive: bytes_uploaded, total_bytes
             for example: def callback(bytes_uploaded: int, total_bytes: int) -> None
+        headers_factory: optional function to inject headers
 
     Returns:
         The total number of uploaded bytes.
@@ -384,6 +402,8 @@ def upload_fileobj(
     }
     if md5 is not None:
         headers["Content-MD5"] = base64.b64encode(md5).decode()
+    if headers_factory is not None:
+        headers.update(headers_factory())
     response = pool.request(
         "PUT",
         url,
@@ -391,7 +411,7 @@ def upload_fileobj(
         headers=headers,
         timeout=DEFAULT_UPLOAD_TIMEOUT if timeout is None else timeout,
     )
-    if response.status != 200:
+    if response.status not in {HTTPStatus.OK, HTTPStatus.CREATED}:
         raise ApiException("Unexpected status", status=response.status)
 
     return file_size

+ 6 - 4
clean_python/api_client/sync_api_provider.py

@@ -29,15 +29,16 @@ class SyncApiProvider:
 
     Args:
         url: The url of the API (with trailing slash)
-        fetch_token: Callable that returns a token for a tenant id
+        headers_factory: Callable that returns headers (e.g. for authorization)
         retries: Total number of retries per request
         backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
+        trailing_slash: Whether to automatically add or remove trailing slashes.
     """
 
     def __init__(
         self,
         url: AnyHttpUrl,
-        fetch_token: Callable[[], Dict[str, str]],
+        headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
         retries: int = 3,
         backoff_factor: float = 1.0,
         trailing_slash: bool = False,
@@ -45,7 +46,7 @@ class SyncApiProvider:
         self._url = str(url)
         if not self._url.endswith("/"):
             self._url += "/"
-        self._fetch_token = fetch_token
+        self._headers_factory = headers_factory
         self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))
         self._trailing_slash = trailing_slash
 
@@ -89,7 +90,8 @@ class SyncApiProvider:
             }
             request_kwargs["encode_multipart"] = True
 
-        headers.update(self._fetch_token())
+        if self._headers_factory is not None:
+            headers.update(self._headers_factory())
         return self._pool.request(headers=headers, **request_kwargs)
 
     def request(

+ 20 - 16
clean_python/oauth2/client_credentials.py

@@ -53,15 +53,17 @@ class CCTokenGateway:
         self.timeout = settings.timeout
         self.leeway = settings.leeway
 
-        async def fetch_token():
+        async def headers_factory():
             auth = BasicAuth(settings.client_id, settings.client_secret)
             return {"Authorization": auth.encode()}
 
-        self.provider = ApiProvider(url=settings.token_url, fetch_token=fetch_token)
+        self.provider = ApiProvider(
+            url=settings.token_url, headers_factory=headers_factory
+        )
         # This binds the cache to the CCTokenGateway instance (and not the class)
-        self.cached_fetch_token = alru_cache(self._fetch_token)
+        self.cached_headers_factory = alru_cache(self._headers_factory)
 
-    async def _fetch_token(self) -> str:
+    async def _headers_factory(self) -> str:
         response = await self.provider.request(
             method="POST",
             path="",
@@ -71,11 +73,11 @@ class CCTokenGateway:
         assert response is not None
         return response["access_token"]
 
-    async def fetch_token(self) -> str:
-        token_str = await self.cached_fetch_token()
+    async def headers_factory(self) -> str:
+        token_str = await self.cached_headers_factory()
         if not is_token_usable(token_str, self.leeway):
-            self.cached_fetch_token.cache_clear()
-            token_str = await self.cached_fetch_token()
+            self.cached_headers_factory.cache_clear()
+            token_str = await self.cached_headers_factory()
         return token_str
 
 
@@ -88,15 +90,17 @@ class SyncCCTokenGateway:
         self.timeout = settings.timeout
         self.leeway = settings.leeway
 
-        def fetch_token():
+        def headers_factory():
             auth = BasicAuth(settings.client_id, settings.client_secret)
             return {"Authorization": auth.encode()}
 
-        self.provider = SyncApiProvider(url=settings.token_url, fetch_token=fetch_token)
+        self.provider = SyncApiProvider(
+            url=settings.token_url, headers_factory=headers_factory
+        )
         # This binds the cache to the SyncCCTokenGateway instance (and not the class)
-        self.cached_fetch_token = lru_cache(self._fetch_token)
+        self.cached_headers_factory = lru_cache(self._headers_factory)
 
-    def _fetch_token(self) -> str:
+    def _headers_factory(self) -> str:
         response = self.provider.request(
             method="POST",
             path="",
@@ -106,9 +110,9 @@ class SyncCCTokenGateway:
         assert response is not None
         return response["access_token"]
 
-    def fetch_token(self) -> str:
-        token_str = self.cached_fetch_token()
+    def headers_factory(self) -> str:
+        token_str = self.cached_headers_factory()
         if not is_token_usable(token_str, self.leeway):
-            self.cached_fetch_token.cache_clear()
-            token_str = self.cached_fetch_token()
+            self.cached_headers_factory.cache_clear()
+            token_str = self.cached_headers_factory()
         return token_str

+ 4 - 4
integration_tests/test_int_client_credentials.py

@@ -22,8 +22,8 @@ def gateway(settings) -> CCTokenGateway:
     return CCTokenGateway(settings)
 
 
-async def test_fetch_token(gateway: CCTokenGateway):
-    response = await gateway._fetch_token()
+async def test_headers_factory(gateway: CCTokenGateway):
+    response = await gateway._headers_factory()
     assert is_token_usable(response, 0)
 
 
@@ -32,6 +32,6 @@ def sync_gateway(settings) -> SyncCCTokenGateway:
     return SyncCCTokenGateway(settings)
 
 
-def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
-    response = sync_gateway._fetch_token()
+def test_headers_factory_sync(sync_gateway: SyncCCTokenGateway):
+    response = sync_gateway._headers_factory()
     assert is_token_usable(response, 0)

+ 2 - 2
tests/api_client/test_api_provider.py

@@ -45,7 +45,7 @@ def api_provider(tenant, response) -> ApiProvider:
     with mock.patch.object(ClientSession, "request", new=request):
         api_provider = ApiProvider(
             url="http://testserver/foo/",
-            fetch_token=fake_token,
+            headers_factory=fake_token,
         )
         api_provider._session.request.return_value = response
         yield api_provider
@@ -148,7 +148,7 @@ async def test_error_response(api_provider: ApiProvider, response, status):
 
 
 async def test_no_token(api_provider: ApiProvider):
-    api_provider._fetch_token = no_token
+    api_provider._headers_factory = no_token
     await api_provider.request("GET", "")
     assert api_provider._session.request.call_args[1]["headers"] == {}
 

+ 4 - 2
tests/api_client/test_sync_api_provider.py

@@ -36,7 +36,7 @@ def api_provider(tenant, response) -> SyncApiProvider:
     with mock.patch(MODULE + ".PoolManager"):
         api_provider = SyncApiProvider(
             url="http://testserver/foo/",
-            fetch_token=lambda: {"Authorization": f"Bearer tenant-{ctx.tenant.id}"},
+            headers_factory=lambda: {"Authorization": f"Bearer tenant-{ctx.tenant.id}"},
         )
         api_provider._pool.request.return_value = response
         yield api_provider
@@ -138,7 +138,9 @@ def test_error_response(api_provider: SyncApiProvider, response, status):
 
 @mock.patch(MODULE + ".PoolManager", new=mock.Mock())
 def test_no_token(response, tenant):
-    api_provider = SyncApiProvider(url="http://testserver/foo/", fetch_token=lambda: {})
+    api_provider = SyncApiProvider(
+        url="http://testserver/foo/", headers_factory=lambda: {}
+    )
     api_provider._pool.request.return_value = response
     api_provider.request("GET", "")
     assert api_provider._pool.request.call_args[1]["headers"] == {}

+ 53 - 2
tests/api_client/test_sync_files.py

@@ -111,7 +111,12 @@ def test_download_fileobj_forbidden(pool, responses_single):
 @mock.patch(MODULE + ".download_fileobj")
 def test_download_file(download_fileobj, tmp_path):
     download_file(
-        "http://domain/a.b", tmp_path / "c.d", chunk_size=64, timeout=3.0, pool="foo"
+        "http://domain/a.b",
+        tmp_path / "c.d",
+        chunk_size=64,
+        timeout=3.0,
+        pool="foo",
+        headers_factory="bar",
     )
 
     args, kwargs = download_fileobj.call_args
@@ -122,6 +127,7 @@ def test_download_file(download_fileobj, tmp_path):
     assert kwargs["chunk_size"] == 64
     assert kwargs["timeout"] == 3.0
     assert kwargs["pool"] == "foo"
+    assert kwargs["headers_factory"] == "bar"
 
 
 @mock.patch(MODULE + ".download_fileobj")
@@ -231,7 +237,13 @@ def test_upload_file(upload_fileobj, tmp_path):
         f.write(b"X")
 
     upload_file(
-        "http://domain/a.b", path, chunk_size=1234, timeout=3.0, pool="foo", md5=b"abcd"
+        "http://domain/a.b",
+        path,
+        chunk_size=1234,
+        timeout=3.0,
+        pool="foo",
+        md5=b"abcd",
+        headers_factory="bar",
     )
 
     args, kwargs = upload_fileobj.call_args
@@ -243,6 +255,7 @@ def test_upload_file(upload_fileobj, tmp_path):
     assert kwargs["chunk_size"] == 1234
     assert kwargs["pool"] == "foo"
     assert kwargs["md5"] == b"abcd"
+    assert kwargs["headers_factory"] == "bar"
 
 
 def test_seekable_chunk_iterator():
@@ -255,3 +268,41 @@ def test_seekable_chunk_iterator():
     assert list(body) == []
     set_file_position(body, pos)
     assert list(body) == [data]
+
+
+def test_download_fileobj_with_headers(pool, responses_single):
+    pool.request.side_effect = responses_single
+    download_fileobj(
+        "some-url",
+        io.BytesIO(),
+        chunk_size=64,
+        pool=pool,
+        headers_factory=lambda: {"foo": "bar"},
+    )
+
+    pool.request.assert_called_with(
+        "GET",
+        "some-url",
+        headers={"Range": "bytes=0-63", "foo": "bar"},
+        timeout=5.0,
+    )
+
+
+def test_upload_fileobj_with_headers(pool, fileobj, upload_response):
+    pool.request.return_value = upload_response
+    upload_fileobj(
+        "some-url", fileobj, pool=pool, headers_factory=lambda: {"foo": "bar"}
+    )
+
+    _, kwargs = pool.request.call_args
+    assert kwargs["headers"] == {"Content-Length": "39", "foo": "bar"}
+
+
+def test_upload_fileobj_201_response(pool, fileobj):
+    pool.request.return_value = HTTPResponse(status=201)
+    # no error is raised
+    upload_fileobj(
+        "some-url", fileobj, pool=pool, headers_factory=lambda: {"foo": "bar"}
+    )
+
+    pool.request.assert_called_once()

+ 12 - 12
tests/oauth2/test_client_credentials.py

@@ -55,10 +55,10 @@ def test_is_token_usable(expires_in, leeway, expected):
     assert is_token_usable(token, leeway) is expected
 
 
-async def test_fetch_token(gateway: CCTokenGateway):
+async def test_headers_factory(gateway: CCTokenGateway):
     gateway.provider.request.return_value = {"access_token": "foo"}
 
-    token = await gateway._fetch_token()
+    token = await gateway._headers_factory()
 
     assert token == "foo"
 
@@ -70,18 +70,18 @@ async def test_fetch_token(gateway: CCTokenGateway):
     )
 
 
-async def test_fetch_token_cache(gateway: CCTokenGateway):
+async def test_headers_factory_cache(gateway: CCTokenGateway):
     # empty cache: provider gets called
     token = get_token({})
     gateway.provider.request.return_value = {"access_token": token}
-    actual = await gateway.fetch_token()
+    actual = await gateway.headers_factory()
     assert actual == token
     assert gateway.provider.request.called
 
     gateway.provider.request.reset_mock()
 
     # cache is filled: provider is not called
-    actual = await gateway.fetch_token()
+    actual = await gateway.headers_factory()
     assert actual == token
     assert not gateway.provider.request.called
 
@@ -89,15 +89,15 @@ async def test_fetch_token_cache(gateway: CCTokenGateway):
 
     # token is not usable so it is refreshed:
     with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
-        actual = await gateway.fetch_token()
+        actual = await gateway.headers_factory()
         assert actual == token
         assert gateway.provider.request.called
 
 
-def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
+def test_headers_factory_sync(sync_gateway: SyncCCTokenGateway):
     sync_gateway.provider.request.return_value = {"access_token": "foo"}
 
-    token = sync_gateway._fetch_token()
+    token = sync_gateway._headers_factory()
 
     assert token == "foo"
 
@@ -109,18 +109,18 @@ def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
     )
 
 
-def test_fetch_token_sync_cache(sync_gateway: SyncCCTokenGateway):
+def test_headers_factory_sync_cache(sync_gateway: SyncCCTokenGateway):
     # empty cache: provider gets called
     token = get_token({})
     sync_gateway.provider.request.return_value = {"access_token": token}
-    actual = sync_gateway.fetch_token()
+    actual = sync_gateway.headers_factory()
     assert actual == token
     assert sync_gateway.provider.request.called
 
     sync_gateway.provider.request.reset_mock()
 
     # cache is filled: provider is not called
-    actual = sync_gateway.fetch_token()
+    actual = sync_gateway.headers_factory()
     assert actual == token
     assert not sync_gateway.provider.request.called
 
@@ -128,6 +128,6 @@ def test_fetch_token_sync_cache(sync_gateway: SyncCCTokenGateway):
 
     # token is not usable so it is refreshed:
     with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
-        actual = sync_gateway.fetch_token()
+        actual = sync_gateway.headers_factory()
         assert actual == token
         assert sync_gateway.provider.request.called