# test_api_provider.py
from asyncio.exceptions import TimeoutError
from http import HTTPStatus
from typing import Iterator
from unittest import mock

import pytest
from aiohttp import ClientError
from aiohttp import ClientSession

from clean_python import Conflict
from clean_python import ctx
from clean_python import Tenant
from clean_python.api_client import ApiException
from clean_python.api_client import ApiProvider
  12. MODULE = "clean_python.api_client.api_provider"
  13. async def fake_token():
  14. return {"Authorization": f"Bearer tenant-{ctx.tenant.id}"}
  15. async def no_token():
  16. return {}
  17. @pytest.fixture
  18. def tenant() -> Tenant:
  19. ctx.tenant = Tenant(id=2, name="")
  20. yield ctx.tenant
  21. ctx.tenant = None
  22. @pytest.fixture
  23. def response():
  24. # this mocks the aiohttp.ClientResponse:
  25. response = mock.Mock()
  26. response.status = int(HTTPStatus.OK)
  27. response.headers = {"Content-Type": "application/json"}
  28. response.json = mock.AsyncMock(return_value={"foo": 2})
  29. response.read = mock.AsyncMock()
  30. return response
  31. @pytest.fixture
  32. def api_provider_no_mock() -> mock.AsyncMock:
  33. return ApiProvider(
  34. url="http://testserver/foo/",
  35. headers_factory=fake_token,
  36. retries=0,
  37. )
  38. @pytest.fixture
  39. def request_m() -> mock.AsyncMock:
  40. request = mock.AsyncMock()
  41. with mock.patch.object(ClientSession, "request", new=request):
  42. yield request
  43. @pytest.fixture
  44. def api_provider(api_provider_no_mock, tenant, response, request_m) -> ApiProvider:
  45. request_m.return_value = response
  46. return api_provider_no_mock
  47. async def test_get(api_provider: ApiProvider, request_m):
  48. actual = await api_provider.request("GET", "")
  49. assert request_m.call_count == 1
  50. assert request_m.call_args[1] == dict(
  51. method="GET",
  52. url="http://testserver/foo",
  53. headers={"Authorization": "Bearer tenant-2"},
  54. timeout=5.0,
  55. data=None,
  56. json=None,
  57. )
  58. assert actual == {"foo": 2}
  59. async def test_post_json(api_provider: ApiProvider, response, request_m):
  60. response.status == int(HTTPStatus.CREATED)
  61. request_m.return_value = response
  62. actual = await api_provider.request("POST", "bar", json={"foo": 2})
  63. assert request_m.call_count == 1
  64. assert request_m.call_args[1] == dict(
  65. method="POST",
  66. url="http://testserver/foo/bar",
  67. data=None,
  68. json={"foo": 2},
  69. headers={
  70. "Authorization": "Bearer tenant-2",
  71. },
  72. timeout=5.0,
  73. )
  74. assert actual == {"foo": 2}
  75. @pytest.mark.parametrize(
  76. "path,params,expected_url",
  77. [
  78. ("", None, "http://testserver/foo"),
  79. ("bar", None, "http://testserver/foo/bar"),
  80. ("bar/", None, "http://testserver/foo/bar"),
  81. ("", {"a": 2}, "http://testserver/foo?a=2"),
  82. ("bar", {"a": 2}, "http://testserver/foo/bar?a=2"),
  83. ("bar/", {"a": 2}, "http://testserver/foo/bar?a=2"),
  84. ("", {"a": [1, 2]}, "http://testserver/foo?a=1&a=2"),
  85. ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
  86. ],
  87. )
  88. async def test_url(api_provider: ApiProvider, path, params, expected_url, request_m):
  89. await api_provider.request("GET", path, params=params)
  90. assert request_m.call_args[1]["url"] == expected_url
  91. async def test_timeout(api_provider: ApiProvider, request_m):
  92. await api_provider.request("POST", "bar", timeout=2.1)
  93. assert request_m.call_args[1]["timeout"] == 2.1
  94. @pytest.mark.parametrize(
  95. "status", [HTTPStatus.OK, HTTPStatus.NOT_FOUND, HTTPStatus.INTERNAL_SERVER_ERROR]
  96. )
  97. async def test_unexpected_content_type(api_provider: ApiProvider, response, status):
  98. response.status = int(status)
  99. response.headers["Content-Type"] = "text/plain"
  100. with pytest.raises(ApiException) as e:
  101. await api_provider.request("GET", "bar")
  102. assert e.value.status is status
  103. assert str(e.value) == f"{status}: Unexpected content type 'text/plain'"
  104. async def test_json_variant_content_type(api_provider: ApiProvider, response):
  105. response.headers["Content-Type"] = "application/something+json"
  106. actual = await api_provider.request("GET", "bar")
  107. assert actual == {"foo": 2}
  108. async def test_no_content(api_provider: ApiProvider, response):
  109. response.status = int(HTTPStatus.NO_CONTENT)
  110. response.headers = {}
  111. actual = await api_provider.request("DELETE", "bar/2")
  112. assert actual is None
  113. @pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND])
  114. async def test_error_response(api_provider: ApiProvider, response, status):
  115. response.status = int(status)
  116. with pytest.raises(ApiException) as e:
  117. await api_provider.request("GET", "bar")
  118. assert e.value.status is status
  119. assert str(e.value) == str(int(status)) + ": {'foo': 2}"
  120. async def test_no_token(api_provider: ApiProvider, request_m):
  121. api_provider._headers_factory = no_token
  122. await api_provider.request("GET", "")
  123. assert request_m.call_args[1]["headers"] == {}
  124. @pytest.mark.parametrize(
  125. "path,trailing_slash,expected",
  126. [
  127. ("bar", False, "bar"),
  128. ("bar", True, "bar/"),
  129. ("bar/", False, "bar"),
  130. ("bar/", True, "bar/"),
  131. ],
  132. )
  133. async def test_trailing_slash(
  134. api_provider: ApiProvider, path, trailing_slash, expected, request_m
  135. ):
  136. api_provider._trailing_slash = trailing_slash
  137. await api_provider.request("GET", path)
  138. assert request_m.call_args[1]["url"] == "http://testserver/foo/" + expected
  139. async def test_conflict(api_provider: ApiProvider, response):
  140. response.status = HTTPStatus.CONFLICT
  141. with pytest.raises(Conflict):
  142. await api_provider.request("GET", "bar")
  143. async def test_conflict_with_message(api_provider: ApiProvider, response):
  144. response.status = HTTPStatus.CONFLICT
  145. response.json.return_value = {"message": "foo"}
  146. with pytest.raises(Conflict, match="foo"):
  147. await api_provider.request("GET", "bar")
  148. async def test_custom_header(api_provider: ApiProvider, request_m):
  149. await api_provider.request("POST", "bar", headers={"foo": "bar"})
  150. assert request_m.call_args[1]["headers"] == {
  151. "foo": "bar",
  152. **(await api_provider._headers_factory()),
  153. }
  154. async def test_custom_header_precedes(api_provider: ApiProvider, request_m):
  155. await api_provider.request("POST", "bar", headers={"Authorization": "bar"})
  156. assert request_m.call_args[1]["headers"]["Authorization"] == "bar"
  157. async def test_session_closed(api_provider: ApiProvider, request_m):
  158. with mock.patch.object(
  159. ClientSession, "close", new_callable=mock.AsyncMock
  160. ) as close_m:
  161. await api_provider.request("GET", "")
  162. close_m.assert_awaited_once()
  163. @pytest.fixture
  164. def retry_provider():
  165. return ApiProvider(url="http://testserver/foo/", retries=1, backoff_factor=0.001)
  166. @pytest.fixture
  167. def error_response():
  168. # this mocks the aiohttp.ClientResponse:
  169. response = mock.Mock()
  170. response.status = int(HTTPStatus.SERVICE_UNAVAILABLE)
  171. response.headers = {"Content-Type": "text/html"}
  172. response.read = mock.AsyncMock()
  173. return response
  174. @pytest.mark.parametrize("error_cls", [ClientError, TimeoutError])
  175. @mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
  176. async def test_retry_client_error(
  177. request_m, retry_provider: ApiProvider, error_cls, response
  178. ):
  179. request_m.side_effect = (error_cls(), response)
  180. actual = await retry_provider.request("GET", "")
  181. assert request_m.call_count == 2
  182. assert actual == {"foo": 2}
  183. @mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
  184. async def test_retry_client_error_too_many(request_m, retry_provider: ApiProvider):
  185. request_m.side_effect = (ClientError("bar"), ClientError("foo"))
  186. with pytest.raises(ClientError, match="foo"):
  187. await retry_provider.request("GET", "")
  188. assert request_m.call_count == 2
  189. @pytest.mark.parametrize("error_code", [429, 500, 502, 503, 504])
  190. @mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
  191. async def test_retry_error_response(
  192. request_m, retry_provider: ApiProvider, error_code: int, response, error_response
  193. ):
  194. error_response.status = error_code
  195. request_m.side_effect = (error_response, response)
  196. actual = await retry_provider.request("GET", "")
  197. assert request_m.call_count == 2
  198. assert actual == {"foo": 2}
  199. @mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
  200. async def test_retry_error_response_too_many(
  201. request_m, retry_provider: ApiProvider, error_response
  202. ):
  203. request_m.return_value = error_response
  204. with pytest.raises(ApiException) as e:
  205. await retry_provider.request("GET", "")
  206. assert request_m.call_count == 2
  207. assert e.value.status == 503
  208. @mock.patch.object(ClientSession, "request", new_callable=mock.AsyncMock)
  209. async def test_no_retry_on_post(request_m, retry_provider: ApiProvider):
  210. request_m.side_effect = ClientError()
  211. with pytest.raises(ClientError):
  212. await retry_provider.request("POST", "")
  213. assert request_m.call_count == 1