Explorar o código

Let ApiProvider raise Conflict on 409 (#39)

Casper van der Wel hai 1 ano
pai
achega
9fe72e224f

+ 10 - 4
clean_python/api_client/api_provider.py

@@ -17,6 +17,7 @@ from aiohttp import ClientSession
 from pydantic import AnyHttpUrl
 from pydantic import field_validator
 
+from clean_python import Conflict
 from clean_python import Json
 from clean_python import ValueObject
 
@@ -34,6 +35,13 @@ def is_success(status: HTTPStatus) -> bool:
     return (int(status) // 100) == 2
 
 
+def check_exception(status: HTTPStatus, body: Json) -> None:
+    if status == HTTPStatus.CONFLICT:
+        raise Conflict(body.get("message", str(body)))
+    elif not is_success(status):
+        raise ApiException(body, status=status)
+
+
 JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")
 
 
@@ -167,10 +175,8 @@ class ApiProvider:
                 f"Unexpected content type '{content_type}'", status=status
             )
         body = await response.json()
-        if is_success(status):
-            return body
-        else:
-            raise ApiException(body, status=status)
+        check_exception(status, body)
+        return body
 
     async def request_raw(
         self,

+ 3 - 5
clean_python/api_client/sync_api_provider.py

@@ -12,9 +12,9 @@ from urllib3 import Retry
 from clean_python import Json
 
 from .api_provider import add_query_params
+from .api_provider import check_exception
 from .api_provider import FileFormPost
 from .api_provider import is_json_content_type
-from .api_provider import is_success
 from .api_provider import join
 from .exceptions import ApiException
 from .response import Response
@@ -112,10 +112,8 @@ class SyncApiProvider:
                 f"Unexpected content type '{content_type}'", status=status
             )
         body = json_lib.loads(response.data.decode())
-        if is_success(status):
-            return body
-        else:
-            raise ApiException(body, status=status)
+        check_exception(status, body)
+        return body
 
     def request_raw(
         self,

+ 16 - 0
tests/api_client/test_api_provider.py

@@ -4,6 +4,7 @@ from unittest import mock
 import pytest
 from aiohttp import ClientSession
 
+from clean_python import Conflict
 from clean_python import ctx
 from clean_python import Tenant
 from clean_python.api_client import ApiException
@@ -171,3 +172,18 @@ async def test_trailing_slash(
         api_provider._session.request.call_args[1]["url"]
         == "http://testserver/foo/" + expected
     )
+
+
+async def test_conflict(api_provider: ApiProvider, response):
+    response.status = HTTPStatus.CONFLICT
+
+    with pytest.raises(Conflict):
+        await api_provider.request("GET", "bar")
+
+
+async def test_conflict_with_message(api_provider: ApiProvider, response):
+    response.status = HTTPStatus.CONFLICT
+    response.json.return_value = {"message": "foo"}
+
+    with pytest.raises(Conflict, match="foo"):
+        await api_provider.request("GET", "bar")

+ 16 - 0
tests/api_client/test_sync_api_provider.py

@@ -5,6 +5,7 @@ from unittest import mock
 
 import pytest
 
+from clean_python import Conflict
 from clean_python import ctx
 from clean_python import Tenant
 from clean_python.api_client import ApiException
@@ -203,3 +204,18 @@ def test_post_file_with_fields(api_provider: SyncApiProvider):
         timeout=5.0,
         encode_multipart=True,
     )
+
+
+def test_conflict(api_provider: SyncApiProvider, response):
+    response.status = HTTPStatus.CONFLICT
+
+    with pytest.raises(Conflict):
+        api_provider.request("GET", "bar")
+
+
+def test_conflict_with_message(api_provider: SyncApiProvider, response):
+    response.status = HTTPStatus.CONFLICT
+    response.json.return_value = {"message": "foo"}
+
+    with pytest.raises(Conflict, match="foo"):
+        api_provider.request("GET", "bar")