|
@@ -40,31 +40,22 @@ def response():
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
-def api_provider_no_mock() -> mock.AsyncMock:
|
|
|
- return ApiProvider(
|
|
|
- url="http://testserver/foo/",
|
|
|
- headers_factory=fake_token,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-@pytest.fixture
|
|
|
-def request_m() -> mock.AsyncMock:
|
|
|
+def api_provider(tenant, response) -> ApiProvider:
|
|
|
request = mock.AsyncMock()
|
|
|
with mock.patch.object(ClientSession, "request", new=request):
|
|
|
- yield request
|
|
|
-
|
|
|
-
|
|
|
-@pytest.fixture
|
|
|
-def api_provider(api_provider_no_mock, tenant, response, request_m) -> ApiProvider:
|
|
|
- request_m.return_value = response
|
|
|
- return api_provider_no_mock
|
|
|
+ api_provider = ApiProvider(
|
|
|
+ url="http://testserver/foo/",
|
|
|
+ headers_factory=fake_token,
|
|
|
+ )
|
|
|
+ api_provider._session.request.return_value = response
|
|
|
+ yield api_provider
|
|
|
|
|
|
|
|
|
-async def test_get(api_provider: ApiProvider, request_m):
|
|
|
+async def test_get(api_provider: ApiProvider, response):
|
|
|
actual = await api_provider.request("GET", "")
|
|
|
|
|
|
- assert request_m.call_count == 1
|
|
|
- assert request_m.call_args[1] == dict(
|
|
|
+ assert api_provider._session.request.call_count == 1
|
|
|
+ assert api_provider._session.request.call_args[1] == dict(
|
|
|
method="GET",
|
|
|
url="http://testserver/foo",
|
|
|
headers={"Authorization": "Bearer tenant-2"},
|
|
@@ -75,14 +66,14 @@ async def test_get(api_provider: ApiProvider, request_m):
|
|
|
assert actual == {"foo": 2}
|
|
|
|
|
|
|
|
|
-async def test_post_json(api_provider: ApiProvider, response, request_m):
|
|
|
+async def test_post_json(api_provider: ApiProvider, response):
|
|
|
response.status == int(HTTPStatus.CREATED)
|
|
|
- request_m.return_value = response
|
|
|
+ api_provider._session.request.return_value = response
|
|
|
actual = await api_provider.request("POST", "bar", json={"foo": 2})
|
|
|
|
|
|
- assert request_m.call_count == 1
|
|
|
+ assert api_provider._session.request.call_count == 1
|
|
|
|
|
|
- assert request_m.call_args[1] == dict(
|
|
|
+ assert api_provider._session.request.call_args[1] == dict(
|
|
|
method="POST",
|
|
|
url="http://testserver/foo/bar",
|
|
|
data=None,
|
|
@@ -108,14 +99,14 @@ async def test_post_json(api_provider: ApiProvider, response, request_m):
|
|
|
("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
|
|
|
],
|
|
|
)
|
|
|
-async def test_url(api_provider: ApiProvider, path, params, expected_url, request_m):
|
|
|
+async def test_url(api_provider: ApiProvider, path, params, expected_url):
|
|
|
await api_provider.request("GET", path, params=params)
|
|
|
- assert request_m.call_args[1]["url"] == expected_url
|
|
|
+ assert api_provider._session.request.call_args[1]["url"] == expected_url
|
|
|
|
|
|
|
|
|
-async def test_timeout(api_provider: ApiProvider, request_m):
|
|
|
+async def test_timeout(api_provider: ApiProvider):
|
|
|
await api_provider.request("POST", "bar", timeout=2.1)
|
|
|
- assert request_m.call_args[1]["timeout"] == 2.1
|
|
|
+ assert api_provider._session.request.call_args[1]["timeout"] == 2.1
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
@@ -156,10 +147,10 @@ async def test_error_response(api_provider: ApiProvider, response, status):
|
|
|
assert str(e.value) == str(int(status)) + ": {'foo': 2}"
|
|
|
|
|
|
|
|
|
-async def test_no_token(api_provider: ApiProvider, request_m):
|
|
|
+async def test_no_token(api_provider: ApiProvider):
|
|
|
api_provider._headers_factory = no_token
|
|
|
await api_provider.request("GET", "")
|
|
|
- assert request_m.call_args[1]["headers"] == {}
|
|
|
+ assert api_provider._session.request.call_args[1]["headers"] == {}
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
@@ -172,12 +163,15 @@ async def test_no_token(api_provider: ApiProvider, request_m):
|
|
|
],
|
|
|
)
|
|
|
async def test_trailing_slash(
|
|
|
- api_provider: ApiProvider, path, trailing_slash, expected, request_m
|
|
|
+ api_provider: ApiProvider, path, trailing_slash, expected
|
|
|
):
|
|
|
api_provider._trailing_slash = trailing_slash
|
|
|
await api_provider.request("GET", path)
|
|
|
|
|
|
- assert request_m.call_args[1]["url"] == "http://testserver/foo/" + expected
|
|
|
+ assert (
|
|
|
+ api_provider._session.request.call_args[1]["url"]
|
|
|
+ == "http://testserver/foo/" + expected
|
|
|
+ )
|
|
|
|
|
|
|
|
|
async def test_conflict(api_provider: ApiProvider, response):
|
|
@@ -195,23 +189,16 @@ async def test_conflict_with_message(api_provider: ApiProvider, response):
|
|
|
await api_provider.request("GET", "bar")
|
|
|
|
|
|
|
|
|
-async def test_custom_header(api_provider: ApiProvider, request_m):
|
|
|
+async def test_custom_header(api_provider: ApiProvider):
|
|
|
await api_provider.request("POST", "bar", headers={"foo": "bar"})
|
|
|
- assert request_m.call_args[1]["headers"] == {
|
|
|
+ assert api_provider._session.request.call_args[1]["headers"] == {
|
|
|
"foo": "bar",
|
|
|
**(await api_provider._headers_factory()),
|
|
|
}
|
|
|
|
|
|
|
|
|
-async def test_custom_header_precedes(api_provider: ApiProvider, request_m):
|
|
|
+async def test_custom_header_precedes(api_provider: ApiProvider):
|
|
|
await api_provider.request("POST", "bar", headers={"Authorization": "bar"})
|
|
|
- assert request_m.call_args[1]["headers"]["Authorization"] == "bar"
|
|
|
-
|
|
|
-
|
|
|
-async def test_session_closed(api_provider: ApiProvider, request_m):
|
|
|
- with mock.patch.object(
|
|
|
- ClientSession, "close", new_callable=mock.AsyncMock
|
|
|
- ) as close_m:
|
|
|
- await api_provider.request("GET", "")
|
|
|
-
|
|
|
- close_m.assert_awaited_once()
|
|
|
+ assert (
|
|
|
+ api_provider._session.request.call_args[1]["headers"]["Authorization"] == "bar"
|
|
|
+ )
|