|
@@ -1,7 +1,9 @@
|
|
|
+from asyncio.exceptions import TimeoutError
|
|
|
from http import HTTPStatus
|
|
|
from unittest import mock
|
|
|
|
|
|
import pytest
|
|
|
+from aiohttp import ClientError
|
|
|
from aiohttp import ClientSession
|
|
|
|
|
|
from clean_python import Conflict
|
|
@@ -44,6 +46,7 @@ def api_provider_no_mock() -> mock.AsyncMock:
|
|
|
return ApiProvider(
|
|
|
url="http://testserver/foo/",
|
|
|
headers_factory=fake_token,
|
|
|
+ retries=0,
|
|
|
)
|
|
|
|
|
|
|
|
@@ -215,3 +218,78 @@ async def test_session_closed(api_provider: ApiProvider, request_m):
|
|
|
await api_provider.request("GET", "")
|
|
|
|
|
|
close_m.assert_awaited_once()
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def retry_provider():
|
|
|
+ return ApiProvider(url="http://testserver/foo/", retries=1, backoff_factor=0.001)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def error_response():
|
|
|
+ # this mocks the aiohttp.ClientResponse:
|
|
|
+ response = mock.Mock()
|
|
|
+ response.status = int(HTTPStatus.SERVICE_UNAVAILABLE)
|
|
|
+ response.headers = {"Content-Type": "text/html"}
|
|
|
+ response.read = mock.AsyncMock()
|
|
|
+ return response
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize("error_cls", [ClientError, TimeoutError])
|
|
|
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
|
|
|
+async def test_retry_client_error(
|
|
|
+ request_m, retry_provider: ApiProvider, error_cls, response
|
|
|
+):
|
|
|
+ request_m.side_effect = (error_cls(), response)
|
|
|
+
|
|
|
+ actual = await retry_provider.request("GET", "")
|
|
|
+
|
|
|
+ assert request_m.call_count == 2
|
|
|
+ assert actual == {"foo": 2}
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
|
|
|
+async def test_retry_client_error_too_many(request_m, retry_provider: ApiProvider):
|
|
|
+ request_m.side_effect = (ClientError("bar"), ClientError("foo"))
|
|
|
+
|
|
|
+ with pytest.raises(ClientError, match="foo"):
|
|
|
+ await retry_provider.request("GET", "")
|
|
|
+
|
|
|
+ assert request_m.call_count == 2
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize("error_code", [429, 500, 502, 503, 504])
|
|
|
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
|
|
|
+async def test_retry_error_response(
|
|
|
+ request_m, retry_provider: ApiProvider, error_code: int, response, error_response
|
|
|
+):
|
|
|
+ error_response.status = error_code
|
|
|
+ request_m.side_effect = (error_response, response)
|
|
|
+
|
|
|
+ actual = await retry_provider.request("GET", "")
|
|
|
+
|
|
|
+ assert request_m.call_count == 2
|
|
|
+ assert actual == {"foo": 2}
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
|
|
|
+async def test_retry_error_response_too_many(
|
|
|
+ request_m, retry_provider: ApiProvider, error_response
|
|
|
+):
|
|
|
+ request_m.return_value = error_response
|
|
|
+
|
|
|
+ with pytest.raises(ApiException) as e:
|
|
|
+ await retry_provider.request("GET", "")
|
|
|
+
|
|
|
+ assert request_m.call_count == 2
|
|
|
+ assert e.value.status == 503
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
|
|
|
+async def test_no_retry_on_post(request_m, retry_provider: ApiProvider):
|
|
|
+ request_m.side_effect = ClientError()
|
|
|
+
|
|
|
+ with pytest.raises(ClientError):
|
|
|
+ await retry_provider.request("POST", "")
|
|
|
+
|
|
|
+ assert request_m.call_count == 1
|