Parcourir la source

Revert "Solve aiohttp 'Unclosed client session' warning (#43)"

This reverts commit 48b89993d739f6f6abd695027eca971b14bd0bc9.
Casper van der Wel il y a 1 an
Parent
commit
232336eed4
3 fichiers modifiés avec 37 ajouts et 59 suppressions
  1. 1 1
      CHANGES.md
  2. 5 14
      clean_python/api_client/api_provider.py
  3. 31 44
      tests/api_client/test_api_provider.py

+ 1 - 1
CHANGES.md

@@ -4,7 +4,7 @@
 0.9.3 (unreleased)
 ------------------
 
-- Solved aiohttp 'Unclosed client session' warning.
+- Nothing changed yet.
 
 
 0.9.2 (2023-11-23)

+ 5 - 14
clean_python/api_client/api_provider.py

@@ -113,15 +113,7 @@ class ApiProvider:
         self._retries = retries
         self._backoff_factor = backoff_factor
         self._trailing_slash = trailing_slash
-
-    @property
-    def _session(self) -> ClientSession:
-        # There seems to be an issue if the ClientSession is instantiated before
-        # the event loop runs. So we do that delayed in a property. Use this property
-        # in a context manager.
-        # TODO It is more efficient to reuse the connection / connection pools. One idea
-        # is to expose .session as a context manager (like with the SQLProvider.transaction)
-        return ClientSession()
+        self._session = ClientSession()
 
     async def _request_with_retry(
         self,
@@ -156,11 +148,10 @@ class ApiProvider:
                 await asyncio.sleep(backoff)
 
             try:
-                async with self._session as session:
-                    response = await session.request(
-                        headers=actual_headers, **request_kwargs
-                    )
-                    await response.read()
+                response = await self._session.request(
+                    headers=actual_headers, **request_kwargs
+                )
+                await response.read()
             except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
                 if attempt == self._retries - 1:
                     raise  # propagate ClientError in case no retries left

+ 31 - 44
tests/api_client/test_api_provider.py

@@ -40,31 +40,22 @@ def response():
 
 
 @pytest.fixture
-def api_provider_no_mock() -> mock.AsyncMock:
-    return ApiProvider(
-        url="http://testserver/foo/",
-        headers_factory=fake_token,
-    )
-
-
-@pytest.fixture
-def request_m() -> mock.AsyncMock:
+def api_provider(tenant, response) -> ApiProvider:
     request = mock.AsyncMock()
     with mock.patch.object(ClientSession, "request", new=request):
-        yield request
-
-
-@pytest.fixture
-def api_provider(api_provider_no_mock, tenant, response, request_m) -> ApiProvider:
-    request_m.return_value = response
-    return api_provider_no_mock
+        api_provider = ApiProvider(
+            url="http://testserver/foo/",
+            headers_factory=fake_token,
+        )
+        api_provider._session.request.return_value = response
+        yield api_provider
 
 
-async def test_get(api_provider: ApiProvider, request_m):
+async def test_get(api_provider: ApiProvider, response):
     actual = await api_provider.request("GET", "")
 
-    assert request_m.call_count == 1
-    assert request_m.call_args[1] == dict(
+    assert api_provider._session.request.call_count == 1
+    assert api_provider._session.request.call_args[1] == dict(
         method="GET",
         url="http://testserver/foo",
         headers={"Authorization": "Bearer tenant-2"},
@@ -75,14 +66,14 @@ async def test_get(api_provider: ApiProvider, request_m):
     assert actual == {"foo": 2}
 
 
-async def test_post_json(api_provider: ApiProvider, response, request_m):
+async def test_post_json(api_provider: ApiProvider, response):
     response.status == int(HTTPStatus.CREATED)
-    request_m.return_value = response
+    api_provider._session.request.return_value = response
     actual = await api_provider.request("POST", "bar", json={"foo": 2})
 
-    assert request_m.call_count == 1
+    assert api_provider._session.request.call_count == 1
 
-    assert request_m.call_args[1] == dict(
+    assert api_provider._session.request.call_args[1] == dict(
         method="POST",
         url="http://testserver/foo/bar",
         data=None,
@@ -108,14 +99,14 @@ async def test_post_json(api_provider: ApiProvider, response, request_m):
         ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
     ],
 )
-async def test_url(api_provider: ApiProvider, path, params, expected_url, request_m):
+async def test_url(api_provider: ApiProvider, path, params, expected_url):
     await api_provider.request("GET", path, params=params)
-    assert request_m.call_args[1]["url"] == expected_url
+    assert api_provider._session.request.call_args[1]["url"] == expected_url
 
 
-async def test_timeout(api_provider: ApiProvider, request_m):
+async def test_timeout(api_provider: ApiProvider):
     await api_provider.request("POST", "bar", timeout=2.1)
-    assert request_m.call_args[1]["timeout"] == 2.1
+    assert api_provider._session.request.call_args[1]["timeout"] == 2.1
 
 
 @pytest.mark.parametrize(
@@ -156,10 +147,10 @@ async def test_error_response(api_provider: ApiProvider, response, status):
     assert str(e.value) == str(int(status)) + ": {'foo': 2}"
 
 
-async def test_no_token(api_provider: ApiProvider, request_m):
+async def test_no_token(api_provider: ApiProvider):
     api_provider._headers_factory = no_token
     await api_provider.request("GET", "")
-    assert request_m.call_args[1]["headers"] == {}
+    assert api_provider._session.request.call_args[1]["headers"] == {}
 
 
 @pytest.mark.parametrize(
@@ -172,12 +163,15 @@ async def test_no_token(api_provider: ApiProvider, request_m):
     ],
 )
 async def test_trailing_slash(
-    api_provider: ApiProvider, path, trailing_slash, expected, request_m
+    api_provider: ApiProvider, path, trailing_slash, expected
 ):
     api_provider._trailing_slash = trailing_slash
     await api_provider.request("GET", path)
 
-    assert request_m.call_args[1]["url"] == "http://testserver/foo/" + expected
+    assert (
+        api_provider._session.request.call_args[1]["url"]
+        == "http://testserver/foo/" + expected
+    )
 
 
 async def test_conflict(api_provider: ApiProvider, response):
@@ -195,23 +189,16 @@ async def test_conflict_with_message(api_provider: ApiProvider, response):
         await api_provider.request("GET", "bar")
 
 
-async def test_custom_header(api_provider: ApiProvider, request_m):
+async def test_custom_header(api_provider: ApiProvider):
     await api_provider.request("POST", "bar", headers={"foo": "bar"})
-    assert request_m.call_args[1]["headers"] == {
+    assert api_provider._session.request.call_args[1]["headers"] == {
         "foo": "bar",
         **(await api_provider._headers_factory()),
     }
 
 
-async def test_custom_header_precedes(api_provider: ApiProvider, request_m):
+async def test_custom_header_precedes(api_provider: ApiProvider):
     await api_provider.request("POST", "bar", headers={"Authorization": "bar"})
-    assert request_m.call_args[1]["headers"]["Authorization"] == "bar"
-
-
-async def test_session_closed(api_provider: ApiProvider, request_m):
-    with mock.patch.object(
-        ClientSession, "close", new_callable=mock.AsyncMock
-    ) as close_m:
-        await api_provider.request("GET", "")
-
-    close_m.assert_awaited_once()
+    assert (
+        api_provider._session.request.call_args[1]["headers"]["Authorization"] == "bar"
+    )