# test_api_provider.py (5.7 KB)
  1. from http import HTTPStatus
  2. from unittest import mock
  3. import pytest
  4. from aiohttp import ClientSession
  5. from clean_python import Conflict
  6. from clean_python import ctx
  7. from clean_python import Tenant
  8. from clean_python.api_client import ApiException
  9. from clean_python.api_client import ApiProvider
  10. MODULE = "clean_python.api_client.api_provider"
  11. async def fake_token():
  12. return {"Authorization": f"Bearer tenant-{ctx.tenant.id}"}
  13. async def no_token():
  14. return {}
  15. @pytest.fixture
  16. def tenant() -> Tenant:
  17. ctx.tenant = Tenant(id=2, name="")
  18. yield ctx.tenant
  19. ctx.tenant = None
  20. @pytest.fixture
  21. def response():
  22. # this mocks the aiohttp.ClientResponse:
  23. response = mock.Mock()
  24. response.status = int(HTTPStatus.OK)
  25. response.headers = {"Content-Type": "application/json"}
  26. response.json = mock.AsyncMock(return_value={"foo": 2})
  27. response.read = mock.AsyncMock()
  28. return response
  29. @pytest.fixture
  30. def api_provider(tenant, response) -> ApiProvider:
  31. request = mock.AsyncMock()
  32. with mock.patch.object(ClientSession, "request", new=request):
  33. api_provider = ApiProvider(
  34. url="http://testserver/foo/",
  35. headers_factory=fake_token,
  36. )
  37. api_provider._session.request.return_value = response
  38. yield api_provider
  39. async def test_get(api_provider: ApiProvider, response):
  40. actual = await api_provider.request("GET", "")
  41. assert api_provider._session.request.call_count == 1
  42. assert api_provider._session.request.call_args[1] == dict(
  43. method="GET",
  44. url="http://testserver/foo",
  45. headers={"Authorization": "Bearer tenant-2"},
  46. timeout=5.0,
  47. data=None,
  48. json=None,
  49. )
  50. assert actual == {"foo": 2}
  51. async def test_post_json(api_provider: ApiProvider, response):
  52. response.status == int(HTTPStatus.CREATED)
  53. api_provider._session.request.return_value = response
  54. actual = await api_provider.request("POST", "bar", json={"foo": 2})
  55. assert api_provider._session.request.call_count == 1
  56. assert api_provider._session.request.call_args[1] == dict(
  57. method="POST",
  58. url="http://testserver/foo/bar",
  59. data=None,
  60. json={"foo": 2},
  61. headers={
  62. "Authorization": "Bearer tenant-2",
  63. },
  64. timeout=5.0,
  65. )
  66. assert actual == {"foo": 2}
  67. @pytest.mark.parametrize(
  68. "path,params,expected_url",
  69. [
  70. ("", None, "http://testserver/foo"),
  71. ("bar", None, "http://testserver/foo/bar"),
  72. ("bar/", None, "http://testserver/foo/bar"),
  73. ("", {"a": 2}, "http://testserver/foo?a=2"),
  74. ("bar", {"a": 2}, "http://testserver/foo/bar?a=2"),
  75. ("bar/", {"a": 2}, "http://testserver/foo/bar?a=2"),
  76. ("", {"a": [1, 2]}, "http://testserver/foo?a=1&a=2"),
  77. ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
  78. ],
  79. )
  80. async def test_url(api_provider: ApiProvider, path, params, expected_url):
  81. await api_provider.request("GET", path, params=params)
  82. assert api_provider._session.request.call_args[1]["url"] == expected_url
  83. async def test_timeout(api_provider: ApiProvider):
  84. await api_provider.request("POST", "bar", timeout=2.1)
  85. assert api_provider._session.request.call_args[1]["timeout"] == 2.1
  86. @pytest.mark.parametrize(
  87. "status", [HTTPStatus.OK, HTTPStatus.NOT_FOUND, HTTPStatus.INTERNAL_SERVER_ERROR]
  88. )
  89. async def test_unexpected_content_type(api_provider: ApiProvider, response, status):
  90. response.status = int(status)
  91. response.headers["Content-Type"] = "text/plain"
  92. with pytest.raises(ApiException) as e:
  93. await api_provider.request("GET", "bar")
  94. assert e.value.status is status
  95. assert str(e.value) == f"{status}: Unexpected content type 'text/plain'"
  96. async def test_json_variant_content_type(api_provider: ApiProvider, response):
  97. response.headers["Content-Type"] = "application/something+json"
  98. actual = await api_provider.request("GET", "bar")
  99. assert actual == {"foo": 2}
  100. async def test_no_content(api_provider: ApiProvider, response):
  101. response.status = int(HTTPStatus.NO_CONTENT)
  102. response.headers = {}
  103. actual = await api_provider.request("DELETE", "bar/2")
  104. assert actual is None
  105. @pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND])
  106. async def test_error_response(api_provider: ApiProvider, response, status):
  107. response.status = int(status)
  108. with pytest.raises(ApiException) as e:
  109. await api_provider.request("GET", "bar")
  110. assert e.value.status is status
  111. assert str(e.value) == str(int(status)) + ": {'foo': 2}"
  112. async def test_no_token(api_provider: ApiProvider):
  113. api_provider._headers_factory = no_token
  114. await api_provider.request("GET", "")
  115. assert api_provider._session.request.call_args[1]["headers"] == {}
  116. @pytest.mark.parametrize(
  117. "path,trailing_slash,expected",
  118. [
  119. ("bar", False, "bar"),
  120. ("bar", True, "bar/"),
  121. ("bar/", False, "bar"),
  122. ("bar/", True, "bar/"),
  123. ],
  124. )
  125. async def test_trailing_slash(
  126. api_provider: ApiProvider, path, trailing_slash, expected
  127. ):
  128. api_provider._trailing_slash = trailing_slash
  129. await api_provider.request("GET", path)
  130. assert (
  131. api_provider._session.request.call_args[1]["url"]
  132. == "http://testserver/foo/" + expected
  133. )
  134. async def test_conflict(api_provider: ApiProvider, response):
  135. response.status = HTTPStatus.CONFLICT
  136. with pytest.raises(Conflict):
  137. await api_provider.request("GET", "bar")
  138. async def test_conflict_with_message(api_provider: ApiProvider, response):
  139. response.status = HTTPStatus.CONFLICT
  140. response.json.return_value = {"message": "foo"}
  141. with pytest.raises(Conflict, match="foo"):
  142. await api_provider.request("GET", "bar")