test_api_provider.py 6.4 KB


from http import HTTPStatus
from typing import Iterator
from unittest import mock

import pytest
from aiohttp import ClientSession

from clean_python import Conflict
from clean_python import ctx
from clean_python import Tenant
from clean_python.api_client import ApiException
from clean_python.api_client import ApiProvider
  10. MODULE = "clean_python.api_client.api_provider"
  11. async def fake_token():
  12. return {"Authorization": f"Bearer tenant-{ctx.tenant.id}"}
  13. async def no_token():
  14. return {}
  15. @pytest.fixture
  16. def tenant() -> Tenant:
  17. ctx.tenant = Tenant(id=2, name="")
  18. yield ctx.tenant
  19. ctx.tenant = None
  20. @pytest.fixture
  21. def response():
  22. # this mocks the aiohttp.ClientResponse:
  23. response = mock.Mock()
  24. response.status = int(HTTPStatus.OK)
  25. response.headers = {"Content-Type": "application/json"}
  26. response.json = mock.AsyncMock(return_value={"foo": 2})
  27. response.read = mock.AsyncMock()
  28. return response
  29. @pytest.fixture
  30. def api_provider_no_mock() -> mock.AsyncMock:
  31. return ApiProvider(
  32. url="http://testserver/foo/",
  33. headers_factory=fake_token,
  34. )
  35. @pytest.fixture
  36. def request_m() -> mock.AsyncMock:
  37. request = mock.AsyncMock()
  38. with mock.patch.object(ClientSession, "request", new=request):
  39. yield request
  40. @pytest.fixture
  41. def api_provider(api_provider_no_mock, tenant, response, request_m) -> ApiProvider:
  42. request_m.return_value = response
  43. return api_provider_no_mock
  44. async def test_get(api_provider: ApiProvider, request_m):
  45. actual = await api_provider.request("GET", "")
  46. assert request_m.call_count == 1
  47. assert request_m.call_args[1] == dict(
  48. method="GET",
  49. url="http://testserver/foo",
  50. headers={"Authorization": "Bearer tenant-2"},
  51. timeout=5.0,
  52. data=None,
  53. json=None,
  54. )
  55. assert actual == {"foo": 2}
  56. async def test_post_json(api_provider: ApiProvider, response, request_m):
  57. response.status == int(HTTPStatus.CREATED)
  58. request_m.return_value = response
  59. actual = await api_provider.request("POST", "bar", json={"foo": 2})
  60. assert request_m.call_count == 1
  61. assert request_m.call_args[1] == dict(
  62. method="POST",
  63. url="http://testserver/foo/bar",
  64. data=None,
  65. json={"foo": 2},
  66. headers={
  67. "Authorization": "Bearer tenant-2",
  68. },
  69. timeout=5.0,
  70. )
  71. assert actual == {"foo": 2}
  72. @pytest.mark.parametrize(
  73. "path,params,expected_url",
  74. [
  75. ("", None, "http://testserver/foo"),
  76. ("bar", None, "http://testserver/foo/bar"),
  77. ("bar/", None, "http://testserver/foo/bar"),
  78. ("", {"a": 2}, "http://testserver/foo?a=2"),
  79. ("bar", {"a": 2}, "http://testserver/foo/bar?a=2"),
  80. ("bar/", {"a": 2}, "http://testserver/foo/bar?a=2"),
  81. ("", {"a": [1, 2]}, "http://testserver/foo?a=1&a=2"),
  82. ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
  83. ],
  84. )
  85. async def test_url(api_provider: ApiProvider, path, params, expected_url, request_m):
  86. await api_provider.request("GET", path, params=params)
  87. assert request_m.call_args[1]["url"] == expected_url
  88. async def test_timeout(api_provider: ApiProvider, request_m):
  89. await api_provider.request("POST", "bar", timeout=2.1)
  90. assert request_m.call_args[1]["timeout"] == 2.1
  91. @pytest.mark.parametrize(
  92. "status", [HTTPStatus.OK, HTTPStatus.NOT_FOUND, HTTPStatus.INTERNAL_SERVER_ERROR]
  93. )
  94. async def test_unexpected_content_type(api_provider: ApiProvider, response, status):
  95. response.status = int(status)
  96. response.headers["Content-Type"] = "text/plain"
  97. with pytest.raises(ApiException) as e:
  98. await api_provider.request("GET", "bar")
  99. assert e.value.status is status
  100. assert str(e.value) == f"{status}: Unexpected content type 'text/plain'"
  101. async def test_json_variant_content_type(api_provider: ApiProvider, response):
  102. response.headers["Content-Type"] = "application/something+json"
  103. actual = await api_provider.request("GET", "bar")
  104. assert actual == {"foo": 2}
  105. async def test_no_content(api_provider: ApiProvider, response):
  106. response.status = int(HTTPStatus.NO_CONTENT)
  107. response.headers = {}
  108. actual = await api_provider.request("DELETE", "bar/2")
  109. assert actual is None
  110. @pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND])
  111. async def test_error_response(api_provider: ApiProvider, response, status):
  112. response.status = int(status)
  113. with pytest.raises(ApiException) as e:
  114. await api_provider.request("GET", "bar")
  115. assert e.value.status is status
  116. assert str(e.value) == str(int(status)) + ": {'foo': 2}"
  117. async def test_no_token(api_provider: ApiProvider, request_m):
  118. api_provider._headers_factory = no_token
  119. await api_provider.request("GET", "")
  120. assert request_m.call_args[1]["headers"] == {}
  121. @pytest.mark.parametrize(
  122. "path,trailing_slash,expected",
  123. [
  124. ("bar", False, "bar"),
  125. ("bar", True, "bar/"),
  126. ("bar/", False, "bar"),
  127. ("bar/", True, "bar/"),
  128. ],
  129. )
  130. async def test_trailing_slash(
  131. api_provider: ApiProvider, path, trailing_slash, expected, request_m
  132. ):
  133. api_provider._trailing_slash = trailing_slash
  134. await api_provider.request("GET", path)
  135. assert request_m.call_args[1]["url"] == "http://testserver/foo/" + expected
  136. async def test_conflict(api_provider: ApiProvider, response):
  137. response.status = HTTPStatus.CONFLICT
  138. with pytest.raises(Conflict):
  139. await api_provider.request("GET", "bar")
  140. async def test_conflict_with_message(api_provider: ApiProvider, response):
  141. response.status = HTTPStatus.CONFLICT
  142. response.json.return_value = {"message": "foo"}
  143. with pytest.raises(Conflict, match="foo"):
  144. await api_provider.request("GET", "bar")
  145. async def test_custom_header(api_provider: ApiProvider, request_m):
  146. await api_provider.request("POST", "bar", headers={"foo": "bar"})
  147. assert request_m.call_args[1]["headers"] == {
  148. "foo": "bar",
  149. **(await api_provider._headers_factory()),
  150. }
  151. async def test_custom_header_precedes(api_provider: ApiProvider, request_m):
  152. await api_provider.request("POST", "bar", headers={"Authorization": "bar"})
  153. assert request_m.call_args[1]["headers"]["Authorization"] == "bar"
  154. async def test_session_closed(api_provider: ApiProvider, request_m):
  155. with mock.patch.object(
  156. ClientSession, "close", new_callable=mock.AsyncMock
  157. ) as close_m:
  158. await api_provider.request("GET", "")
  159. close_m.assert_awaited_once()