# test_sync_api_provider.py
# This module is a copy of test_api_provider.py, adapted to test SyncApiProvider.
from http import HTTPStatus
from typing import Iterator
from unittest import mock

import pytest

from clean_python import Conflict
from clean_python import ctx
from clean_python import Tenant
from clean_python.api_client import ApiException
from clean_python.api_client import FileFormPost
from clean_python.api_client import SyncApiProvider
  11. MODULE = "clean_python.api_client.sync_api_provider"
  12. @pytest.fixture
  13. def tenant() -> Tenant:
  14. ctx.tenant = Tenant(id=2, name="")
  15. yield ctx.tenant
  16. ctx.tenant = None
  17. @pytest.fixture
  18. def response():
  19. response = mock.Mock()
  20. response.status = int(HTTPStatus.OK)
  21. response.headers = {"Content-Type": "application/json"}
  22. response.data = b'{"foo": 2}'
  23. return response
  24. @pytest.fixture
  25. def api_provider(tenant, response) -> SyncApiProvider:
  26. with mock.patch(MODULE + ".PoolManager"):
  27. api_provider = SyncApiProvider(
  28. url="http://testserver/foo/",
  29. headers_factory=lambda: {"Authorization": f"Bearer tenant-{ctx.tenant.id}"},
  30. )
  31. api_provider._pool.request.return_value = response
  32. yield api_provider
  33. def test_get(api_provider: SyncApiProvider, response):
  34. actual = api_provider.request("GET", "")
  35. assert api_provider._pool.request.call_count == 1
  36. assert api_provider._pool.request.call_args[1] == dict(
  37. method="GET",
  38. url="http://testserver/foo",
  39. headers={"Authorization": "Bearer tenant-2"},
  40. timeout=5.0,
  41. )
  42. assert actual == {"foo": 2}
  43. def test_post_json(api_provider: SyncApiProvider, response):
  44. response.status == int(HTTPStatus.CREATED)
  45. api_provider._pool.request.return_value = response
  46. actual = api_provider.request("POST", "bar", json={"foo": 2})
  47. assert api_provider._pool.request.call_count == 1
  48. assert api_provider._pool.request.call_args[1] == dict(
  49. method="POST",
  50. url="http://testserver/foo/bar",
  51. body=b'{"foo": 2}',
  52. headers={
  53. "Content-Type": "application/json",
  54. "Authorization": "Bearer tenant-2",
  55. },
  56. timeout=5.0,
  57. )
  58. assert actual == {"foo": 2}
  59. @pytest.mark.parametrize(
  60. "path,params,expected_url",
  61. [
  62. ("", None, "http://testserver/foo"),
  63. ("bar", None, "http://testserver/foo/bar"),
  64. ("bar/", None, "http://testserver/foo/bar"),
  65. ("", {"a": 2}, "http://testserver/foo?a=2"),
  66. ("bar", {"a": 2}, "http://testserver/foo/bar?a=2"),
  67. ("bar/", {"a": 2}, "http://testserver/foo/bar?a=2"),
  68. ("", {"a": [1, 2]}, "http://testserver/foo?a=1&a=2"),
  69. ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
  70. ],
  71. )
  72. def test_url(api_provider: SyncApiProvider, path, params, expected_url):
  73. api_provider.request("GET", path, params=params)
  74. assert api_provider._pool.request.call_args[1]["url"] == expected_url
  75. def test_timeout(api_provider: SyncApiProvider):
  76. api_provider.request("POST", "bar", timeout=2.1)
  77. assert api_provider._pool.request.call_args[1]["timeout"] == 2.1
  78. @pytest.mark.parametrize(
  79. "status", [HTTPStatus.OK, HTTPStatus.NOT_FOUND, HTTPStatus.INTERNAL_SERVER_ERROR]
  80. )
  81. def test_unexpected_content_type(api_provider: SyncApiProvider, response, status):
  82. response.status = int(status)
  83. response.headers["Content-Type"] = "text/plain"
  84. with pytest.raises(ApiException) as e:
  85. api_provider.request("GET", "bar")
  86. assert e.value.status is status
  87. assert str(e.value) == f"{status}: Unexpected content type 'text/plain'"
  88. def test_json_variant_content_type(api_provider: SyncApiProvider, response):
  89. response.headers["Content-Type"] = "application/something+json"
  90. actual = api_provider.request("GET", "bar")
  91. assert actual == {"foo": 2}
  92. def test_no_content(api_provider: SyncApiProvider, response):
  93. response.status = int(HTTPStatus.NO_CONTENT)
  94. response.headers = {}
  95. actual = api_provider.request("DELETE", "bar/2")
  96. assert actual is None
  97. @pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND])
  98. def test_error_response(api_provider: SyncApiProvider, response, status):
  99. response.status = int(status)
  100. with pytest.raises(ApiException) as e:
  101. api_provider.request("GET", "bar")
  102. assert e.value.status is status
  103. assert str(e.value) == str(int(status)) + ": {'foo': 2}"
  104. @mock.patch(MODULE + ".PoolManager", new=mock.Mock())
  105. def test_no_token(response, tenant):
  106. api_provider = SyncApiProvider(
  107. url="http://testserver/foo/", headers_factory=lambda: {}
  108. )
  109. api_provider._pool.request.return_value = response
  110. api_provider.request("GET", "")
  111. assert api_provider._pool.request.call_args[1]["headers"] == {}
  112. @pytest.mark.parametrize(
  113. "path,trailing_slash,expected",
  114. [
  115. ("bar", False, "bar"),
  116. ("bar", True, "bar/"),
  117. ("bar/", False, "bar"),
  118. ("bar/", True, "bar/"),
  119. ],
  120. )
  121. def test_trailing_slash(api_provider: SyncApiProvider, path, trailing_slash, expected):
  122. api_provider._trailing_slash = trailing_slash
  123. api_provider.request("GET", path)
  124. assert (
  125. api_provider._pool.request.call_args[1]["url"]
  126. == "http://testserver/foo/" + expected
  127. )
  128. def test_post_file(api_provider: SyncApiProvider):
  129. api_provider.request(
  130. "POST",
  131. "bar",
  132. file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
  133. )
  134. assert api_provider._pool.request.call_count == 1
  135. assert api_provider._pool.request.call_args[1] == dict(
  136. method="POST",
  137. url="http://testserver/foo/bar",
  138. fields={"x": ("test.zip", b"foo", "application/octet-stream")},
  139. headers={
  140. "Authorization": "Bearer tenant-2",
  141. },
  142. timeout=5.0,
  143. encode_multipart=True,
  144. )
  145. def test_post_file_with_fields(api_provider: SyncApiProvider):
  146. api_provider.request(
  147. "POST",
  148. "bar",
  149. fields={"a": "b"},
  150. file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
  151. )
  152. assert api_provider._pool.request.call_count == 1
  153. assert api_provider._pool.request.call_args[1] == dict(
  154. method="POST",
  155. url="http://testserver/foo/bar",
  156. fields={"a": "b", "x": ("test.zip", b"foo", "application/octet-stream")},
  157. headers={
  158. "Authorization": "Bearer tenant-2",
  159. },
  160. timeout=5.0,
  161. encode_multipart=True,
  162. )
  163. def test_conflict(api_provider: SyncApiProvider, response):
  164. response.status = HTTPStatus.CONFLICT
  165. with pytest.raises(Conflict):
  166. api_provider.request("GET", "bar")
  167. def test_conflict_with_message(api_provider: SyncApiProvider, response):
  168. response.status = HTTPStatus.CONFLICT
  169. response.json.return_value = {"message": "foo"}
  170. with pytest.raises(Conflict, match="foo"):
  171. api_provider.request("GET", "bar")
  172. def test_custom_header(api_provider: SyncApiProvider):
  173. api_provider.request("POST", "bar", headers={"foo": "bar"})
  174. assert api_provider._pool.request.call_args[1]["headers"] == {
  175. "foo": "bar",
  176. **api_provider._headers_factory(),
  177. }
  178. def test_custom_header_precedes(api_provider: SyncApiProvider):
  179. api_provider.request("POST", "bar", headers={"Authorization": "bar"})
  180. assert api_provider._pool.request.call_args[1]["headers"]["Authorization"] == "bar"