Prechádzať zdrojové kódy

Sharpen api client retry policy (#48)

Casper van der Wel 11 mesiacov pred
rodič
commit
01314bf200

+ 5 - 1
CHANGES.md

@@ -4,7 +4,11 @@
 0.9.5 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- SyncApiProvider: also retry when the Retry-After response header is missing.
+
+- ApiProvider: (sync and async) retry on all methods except POST.
+
+- ApiProvider: (sync and async) retry on 429, 500, 502, 503, 504.
 
 
 0.9.4 (2023-12-07)

+ 21 - 7
clean_python/api_client/api_provider.py

@@ -27,7 +27,20 @@ from .response import Response
 __all__ = ["ApiProvider", "FileFormPost"]
 
 
-RETRY_STATUSES = frozenset({413, 429, 503})  # like in urllib3
+# Retry on 429 and all 5xx errors (because they are mostly temporary)
+RETRY_STATUSES = frozenset(
+    {
+        HTTPStatus.TOO_MANY_REQUESTS,
+        HTTPStatus.INTERNAL_SERVER_ERROR,
+        HTTPStatus.BAD_GATEWAY,
+        HTTPStatus.SERVICE_UNAVAILABLE,
+        HTTPStatus.GATEWAY_TIMEOUT,
+    }
+)
+# PATCH is strictly not idempotent, because you could do advanced
+# JSON operations like 'add an array element'. mostly idempotent.
+# However we never do that and we always make PATCH idempotent.
+RETRY_METHODS = frozenset(["HEAD", "GET", "PATCH", "PUT", "DELETE", "OPTIONS", "TRACE"])
 
 
 def is_success(status: HTTPStatus) -> bool:
@@ -109,7 +122,7 @@ class ApiProvider:
         if not self._url.endswith("/"):
             self._url += "/"
         self._headers_factory = headers_factory
-        assert retries > 0
+        assert retries >= 0
         self._retries = retries
         self._backoff_factor = backoff_factor
         self._trailing_slash = trailing_slash
@@ -150,7 +163,8 @@ class ApiProvider:
             actual_headers.update(await self._headers_factory())
         if headers:
             actual_headers.update(headers)
-        for attempt in range(self._retries):
+        retries = self._retries if method.upper() in RETRY_METHODS else 0
+        for attempt in range(retries + 1):
             if attempt > 0:
                 backoff = self._backoff_factor * 2 ** (attempt - 1)
                 await asyncio.sleep(backoff)
@@ -160,13 +174,13 @@ class ApiProvider:
                     response = await session.request(
                         headers=actual_headers, **request_kwargs
                     )
+                    if response.status in RETRY_STATUSES:
+                        continue
                     await response.read()
+                    return response
             except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
-                if attempt == self._retries - 1:
+                if attempt == retries:
                     raise  # propagate ClientError in case no retries left
-            else:
-                if response.status not in RETRY_STATUSES:
-                    return response  # on all non-retry statuses: return response
 
         return response  # retries exceeded; return the (possibly error) response
 

+ 10 - 1
clean_python/api_client/sync_api_provider.py

@@ -16,6 +16,8 @@ from .api_provider import check_exception
 from .api_provider import FileFormPost
 from .api_provider import is_json_content_type
 from .api_provider import join
+from .api_provider import RETRY_METHODS
+from .api_provider import RETRY_STATUSES
 from .exceptions import ApiException
 from .response import Response
 
@@ -47,7 +49,14 @@ class SyncApiProvider:
         if not self._url.endswith("/"):
             self._url += "/"
         self._headers_factory = headers_factory
-        self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))
+        self._pool = PoolManager(
+            retries=Retry(
+                retries,
+                backoff_factor=backoff_factor,
+                status_forcelist=RETRY_STATUSES,
+                allowed_methods=RETRY_METHODS,
+            )
+        )
         self._trailing_slash = trailing_slash
 
     def _request(

+ 20 - 2
integration_tests/conftest.py

@@ -3,6 +3,9 @@
 import asyncio
 import multiprocessing
 import os
+import time
+from urllib.error import URLError
+from urllib.request import urlopen
 
 import pytest
 import uvicorn
@@ -43,11 +46,26 @@ async def s3_url():
     return os.environ.get("S3_URL", "http://localhost:9000")
 
 
+def wait_until_url_available(url: str, max_tries=10, interval=0.1):
+    # wait for the server to be ready
+    for _ in range(max_tries):
+        try:
+            urlopen(url)
+        except URLError:
+            time.sleep(interval)
+            continue
+        else:
+            break
+
+
 @pytest.fixture(scope="session")
 async def fastapi_example_app():
     port = int(os.environ.get("API_PORT", "8005"))
     config = uvicorn.Config("fastapi_example:app", host="0.0.0.0", port=port)
     p = multiprocessing.Process(target=uvicorn.Server(config).run)
     p.start()
-    yield f"http://localhost:{port}"
-    p.terminate()
+    try:
+        wait_until_url_available(f"http://localhost:{port}/docs")
+        yield f"http://localhost:{port}"
+    finally:
+        p.terminate()

+ 78 - 0
tests/api_client/test_api_provider.py

@@ -1,7 +1,9 @@
+from asyncio.exceptions import TimeoutError
 from http import HTTPStatus
 from unittest import mock
 
 import pytest
+from aiohttp import ClientError
 from aiohttp import ClientSession
 
 from clean_python import Conflict
@@ -44,6 +46,7 @@ def api_provider_no_mock() -> mock.AsyncMock:
     return ApiProvider(
         url="http://testserver/foo/",
         headers_factory=fake_token,
+        retries=0,
     )
 
 
@@ -215,3 +218,78 @@ async def test_session_closed(api_provider: ApiProvider, request_m):
         await api_provider.request("GET", "")
 
     close_m.assert_awaited_once()
+
+
+@pytest.fixture
+def retry_provider():
+    return ApiProvider(url="http://testserver/foo/", retries=1, backoff_factor=0.001)
+
+
+@pytest.fixture
+def error_response():
+    # this mocks the aiohttp.ClientResponse:
+    response = mock.Mock()
+    response.status = int(HTTPStatus.SERVICE_UNAVAILABLE)
+    response.headers = {"Content-Type": "text/html"}
+    response.read = mock.AsyncMock()
+    return response
+
+
+@pytest.mark.parametrize("error_cls", [ClientError, TimeoutError])
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
+async def test_retry_client_error(
+    request_m, retry_provider: ApiProvider, error_cls, response
+):
+    request_m.side_effect = (error_cls(), response)
+
+    actual = await retry_provider.request("GET", "")
+
+    assert request_m.call_count == 2
+    assert actual == {"foo": 2}
+
+
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
+async def test_retry_client_error_too_many(request_m, retry_provider: ApiProvider):
+    request_m.side_effect = (ClientError("bar"), ClientError("foo"))
+
+    with pytest.raises(ClientError, match="foo"):
+        await retry_provider.request("GET", "")
+
+    assert request_m.call_count == 2
+
+
+@pytest.mark.parametrize("error_code", [429, 500, 502, 503, 504])
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
+async def test_retry_error_response(
+    request_m, retry_provider: ApiProvider, error_code: int, response, error_response
+):
+    error_response.status = error_code
+    request_m.side_effect = (error_response, response)
+
+    actual = await retry_provider.request("GET", "")
+
+    assert request_m.call_count == 2
+    assert actual == {"foo": 2}
+
+
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
+async def test_retry_error_response_too_many(
+    request_m, retry_provider: ApiProvider, error_response
+):
+    request_m.return_value = error_response
+
+    with pytest.raises(ApiException) as e:
+        await retry_provider.request("GET", "")
+
+    assert request_m.call_count == 2
+    assert e.value.status == 503
+
+
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
+async def test_no_retry_on_post(request_m, retry_provider: ApiProvider):
+    request_m.side_effect = ClientError()
+
+    with pytest.raises(ClientError):
+        await retry_provider.request("POST", "")
+
+    assert request_m.call_count == 1