Pārlūkot izejas kodu

ApiProvider should provide with and without trailing slash (#24)

Casper van der Wel 1 gadu atpakaļ
vecāks
revīzija
74d208c69d

+ 1 - 1
CHANGES.md

@@ -4,7 +4,7 @@
 0.6.8 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Add `trailing_slash` option to `ApiProvider`.
 
 
 0.6.7 (2023-10-09)

+ 9 - 3
clean_python/api_client/api_provider.py

@@ -39,12 +39,14 @@ def is_json_content_type(content_type: Optional[str]) -> bool:
     return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))
 
 
-def join(url: str, path: str) -> str:
+def join(url: str, path: str, trailing_slash: bool = False) -> str:
     """Results in a full url without trailing slash"""
     assert url.endswith("/")
     assert not path.startswith("/")
     result = urljoin(url, path)
-    if result.endswith("/"):
+    if trailing_slash and not result.endswith("/"):
+        result = result + "/"
+    elif not trailing_slash and result.endswith("/"):
         result = result[:-1]
     return result
 
@@ -73,6 +75,7 @@ class ApiProvider:
         fetch_token: Callable[[], Awaitable[Dict[str, str]]],
         retries: int = 3,
         backoff_factor: float = 1.0,
+        trailing_slash: bool = False,
     ):
         self._url = str(url)
         if not self._url.endswith("/"):
@@ -81,6 +84,7 @@ class ApiProvider:
         assert retries > 0
         self._retries = retries
         self._backoff_factor = backoff_factor
+        self._trailing_slash = trailing_slash
         self._session = ClientSession()
 
     async def _request_with_retry(
@@ -94,7 +98,9 @@ class ApiProvider:
     ) -> ClientResponse:
         request_kwargs = {
             "method": method,
-            "url": add_query_params(join(self._url, quote(path)), params),
+            "url": add_query_params(
+                join(self._url, quote(path), self._trailing_slash), params
+            ),
             "timeout": timeout,
             "json": json,
             "data": fields,

+ 5 - 1
clean_python/api_client/sync_api_provider.py

@@ -39,12 +39,14 @@ class SyncApiProvider:
         fetch_token: Callable[[], Dict[str, str]],
         retries: int = 3,
         backoff_factor: float = 1.0,
+        trailing_slash: bool = False,
     ):
         self._url = str(url)
         if not self._url.endswith("/"):
             self._url += "/"
         self._fetch_token = fetch_token
         self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))
+        self._trailing_slash = trailing_slash
 
     def _request(
         self,
@@ -58,7 +60,9 @@ class SyncApiProvider:
         headers = {}
         request_kwargs = {
             "method": method,
-            "url": add_query_params(join(self._url, quote(path)), params),
+            "url": add_query_params(
+                join(self._url, quote(path), self._trailing_slash), params
+            ),
             "timeout": timeout,
         }
         # for urllib3<2, we dump json ourselves

+ 21 - 0
tests/api_client/test_api_provider.py

@@ -150,3 +150,24 @@ async def test_no_token(api_provider: ApiProvider):
     api_provider._fetch_token = no_token
     await api_provider.request("GET", "")
     assert api_provider._session.request.call_args[1]["headers"] == {}
+
+
+@pytest.mark.parametrize(
+    "path,trailing_slash,expected",
+    [
+        ("bar", False, "bar"),
+        ("bar", True, "bar/"),
+        ("bar/", False, "bar"),
+        ("bar/", True, "bar/"),
+    ],
+)
+async def test_trailing_slash(
+    api_provider: ApiProvider, path, trailing_slash, expected
+):
+    api_provider._trailing_slash = trailing_slash
+    await api_provider.request("GET", path)
+
+    assert (
+        api_provider._session.request.call_args[1]["url"]
+        == "http://testserver/foo/" + expected
+    )

+ 19 - 0
tests/api_client/test_sync_api_provider.py

@@ -140,3 +140,22 @@ def test_no_token(response, tenant):
     api_provider._pool.request.return_value = response
     api_provider.request("GET", "")
     assert api_provider._pool.request.call_args[1]["headers"] == {}
+
+
+@pytest.mark.parametrize(
+    "path,trailing_slash,expected",
+    [
+        ("bar", False, "bar"),
+        ("bar", True, "bar/"),
+        ("bar/", False, "bar"),
+        ("bar/", True, "bar/"),
+    ],
+)
+def test_trailing_slash(api_provider: SyncApiProvider, path, trailing_slash, expected):
+    api_provider._trailing_slash = trailing_slash
+    api_provider.request("GET", path)
+
+    assert (
+        api_provider._pool.request.call_args[1]["url"]
+        == "http://testserver/foo/" + expected
+    )