|  | @@ -1,7 +1,9 @@
 | 
	
		
			
				|  |  | +from asyncio.exceptions import TimeoutError
 | 
	
		
			
				|  |  |  from http import HTTPStatus
 | 
	
		
			
				|  |  |  from unittest import mock
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import pytest
 | 
	
		
			
				|  |  | +from aiohttp import ClientError
 | 
	
		
			
				|  |  |  from aiohttp import ClientSession
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from clean_python import Conflict
 | 
	
	
		
			
				|  | @@ -44,6 +46,7 @@ def api_provider_no_mock() -> mock.AsyncMock:
 | 
	
		
			
				|  |  |      return ApiProvider(
 | 
	
		
			
				|  |  |          url="http://testserver/foo/",
 | 
	
		
			
				|  |  |          headers_factory=fake_token,
 | 
	
		
			
				|  |  | +        retries=0,
 | 
	
		
			
				|  |  |      )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -215,3 +218,78 @@ async def test_session_closed(api_provider: ApiProvider, request_m):
 | 
	
		
			
				|  |  |          await api_provider.request("GET", "")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      close_m.assert_awaited_once()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@pytest.fixture
 | 
	
		
			
				|  |  | +def retry_provider():
 | 
	
		
			
				|  |  | +    return ApiProvider(url="http://testserver/foo/", retries=1, backoff_factor=0.001)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@pytest.fixture
 | 
	
		
			
				|  |  | +def error_response():
 | 
	
		
			
				|  |  | +    # this mocks the aiohttp.ClientResponse:
 | 
	
		
			
				|  |  | +    response = mock.Mock()
 | 
	
		
			
				|  |  | +    response.status = int(HTTPStatus.SERVICE_UNAVAILABLE)
 | 
	
		
			
				|  |  | +    response.headers = {"Content-Type": "text/html"}
 | 
	
		
			
				|  |  | +    response.read = mock.AsyncMock()
 | 
	
		
			
				|  |  | +    return response
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@pytest.mark.parametrize("error_cls", [ClientError, TimeoutError])
 | 
	
		
			
				|  |  | +@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
 | 
	
		
			
				|  |  | +async def test_retry_client_error(
 | 
	
		
			
				|  |  | +    request_m, retry_provider: ApiProvider, error_cls, response
 | 
	
		
			
				|  |  | +):
 | 
	
		
			
				|  |  | +    request_m.side_effect = (error_cls(), response)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    actual = await retry_provider.request("GET", "")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    assert request_m.call_count == 2
 | 
	
		
			
				|  |  | +    assert actual == {"foo": 2}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
 | 
	
		
			
				|  |  | +async def test_retry_client_error_too_many(request_m, retry_provider: ApiProvider):
 | 
	
		
			
				|  |  | +    request_m.side_effect = (ClientError("bar"), ClientError("foo"))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    with pytest.raises(ClientError, match="foo"):
 | 
	
		
			
				|  |  | +        await retry_provider.request("GET", "")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    assert request_m.call_count == 2
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@pytest.mark.parametrize("error_code", [429, 500, 502, 503, 504])
 | 
	
		
			
				|  |  | +@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
 | 
	
		
			
				|  |  | +async def test_retry_error_response(
 | 
	
		
			
				|  |  | +    request_m, retry_provider: ApiProvider, error_code: int, response, error_response
 | 
	
		
			
				|  |  | +):
 | 
	
		
			
				|  |  | +    error_response.status = error_code
 | 
	
		
			
				|  |  | +    request_m.side_effect = (error_response, response)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    actual = await retry_provider.request("GET", "")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    assert request_m.call_count == 2
 | 
	
		
			
				|  |  | +    assert actual == {"foo": 2}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
 | 
	
		
			
				|  |  | +async def test_retry_error_response_too_many(
 | 
	
		
			
				|  |  | +    request_m, retry_provider: ApiProvider, error_response
 | 
	
		
			
				|  |  | +):
 | 
	
		
			
				|  |  | +    request_m.return_value = error_response
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    with pytest.raises(ApiException) as e:
 | 
	
		
			
				|  |  | +        await retry_provider.request("GET", "")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    assert request_m.call_count == 2
 | 
	
		
			
				|  |  | +    assert e.value.status == 503
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
 | 
	
		
			
				|  |  | +async def test_no_retry_on_post(request_m, retry_provider: ApiProvider):
 | 
	
		
			
				|  |  | +    request_m.side_effect = ClientError()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    with pytest.raises(ClientError):
 | 
	
		
			
				|  |  | +        await retry_provider.request("POST", "")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    assert request_m.call_count == 1
 |