瀏覽代碼

Solve aiohttp 'Unclosed client session' warning (#43)

Casper van der Wel 1 年之前
父節點
當前提交
48b89993d7
共有 3 個文件被更改,包括 59 次插入37 次删除
  1. 1 1
      CHANGES.md
  2. 14 5
      clean_python/api_client/api_provider.py
  3. 44 31
      tests/api_client/test_api_provider.py

+ 1 - 1
CHANGES.md

@@ -4,7 +4,7 @@
 0.9.3 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Solved aiohttp 'Unclosed client session' warning.
 
 
 0.9.2 (2023-11-23)

+ 14 - 5
clean_python/api_client/api_provider.py

@@ -113,7 +113,15 @@ class ApiProvider:
         self._retries = retries
         self._backoff_factor = backoff_factor
         self._trailing_slash = trailing_slash
-        self._session = ClientSession()
+
+    @property
+    def _session(self) -> ClientSession:
+        # There seems to be an issue if the ClientSession is instantiated before
+        # the event loop runs. So we do that delayed in a property. Use this property
+        # in a context manager.
+        # TODO It is more efficient to reuse the connection / connection pools. One idea
+        # is to expose .session as a context manager (like with the SQLProvider.transaction)
+        return ClientSession()
 
     async def _request_with_retry(
         self,
@@ -148,10 +156,11 @@ class ApiProvider:
                 await asyncio.sleep(backoff)
 
             try:
-                response = await self._session.request(
-                    headers=actual_headers, **request_kwargs
-                )
-                await response.read()
+                async with self._session as session:
+                    response = await session.request(
+                        headers=actual_headers, **request_kwargs
+                    )
+                    await response.read()
             except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
                 if attempt == self._retries - 1:
                     raise  # propagate ClientError in case no retries left

+ 44 - 31
tests/api_client/test_api_provider.py

@@ -40,22 +40,31 @@ def response():
 
 
 @pytest.fixture
-def api_provider(tenant, response) -> ApiProvider:
+def api_provider_no_mock() -> mock.AsyncMock:
+    return ApiProvider(
+        url="http://testserver/foo/",
+        headers_factory=fake_token,
+    )
+
+
+@pytest.fixture
+def request_m() -> mock.AsyncMock:
     request = mock.AsyncMock()
     with mock.patch.object(ClientSession, "request", new=request):
-        api_provider = ApiProvider(
-            url="http://testserver/foo/",
-            headers_factory=fake_token,
-        )
-        api_provider._session.request.return_value = response
-        yield api_provider
+        yield request
+
+
+@pytest.fixture
+def api_provider(api_provider_no_mock, tenant, response, request_m) -> ApiProvider:
+    request_m.return_value = response
+    return api_provider_no_mock
 
 
-async def test_get(api_provider: ApiProvider, response):
+async def test_get(api_provider: ApiProvider, request_m):
     actual = await api_provider.request("GET", "")
 
-    assert api_provider._session.request.call_count == 1
-    assert api_provider._session.request.call_args[1] == dict(
+    assert request_m.call_count == 1
+    assert request_m.call_args[1] == dict(
         method="GET",
         url="http://testserver/foo",
         headers={"Authorization": "Bearer tenant-2"},
@@ -66,14 +75,14 @@ async def test_get(api_provider: ApiProvider, response):
     assert actual == {"foo": 2}
 
 
-async def test_post_json(api_provider: ApiProvider, response):
+async def test_post_json(api_provider: ApiProvider, response, request_m):
     response.status == int(HTTPStatus.CREATED)
-    api_provider._session.request.return_value = response
+    request_m.return_value = response
     actual = await api_provider.request("POST", "bar", json={"foo": 2})
 
-    assert api_provider._session.request.call_count == 1
+    assert request_m.call_count == 1
 
-    assert api_provider._session.request.call_args[1] == dict(
+    assert request_m.call_args[1] == dict(
         method="POST",
         url="http://testserver/foo/bar",
         data=None,
@@ -99,14 +108,14 @@ async def test_post_json(api_provider: ApiProvider, response):
         ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"),
     ],
 )
-async def test_url(api_provider: ApiProvider, path, params, expected_url):
+async def test_url(api_provider: ApiProvider, path, params, expected_url, request_m):
     await api_provider.request("GET", path, params=params)
-    assert api_provider._session.request.call_args[1]["url"] == expected_url
+    assert request_m.call_args[1]["url"] == expected_url
 
 
-async def test_timeout(api_provider: ApiProvider):
+async def test_timeout(api_provider: ApiProvider, request_m):
     await api_provider.request("POST", "bar", timeout=2.1)
-    assert api_provider._session.request.call_args[1]["timeout"] == 2.1
+    assert request_m.call_args[1]["timeout"] == 2.1
 
 
 @pytest.mark.parametrize(
@@ -147,10 +156,10 @@ async def test_error_response(api_provider: ApiProvider, response, status):
     assert str(e.value) == str(int(status)) + ": {'foo': 2}"
 
 
-async def test_no_token(api_provider: ApiProvider):
+async def test_no_token(api_provider: ApiProvider, request_m):
     api_provider._headers_factory = no_token
     await api_provider.request("GET", "")
-    assert api_provider._session.request.call_args[1]["headers"] == {}
+    assert request_m.call_args[1]["headers"] == {}
 
 
 @pytest.mark.parametrize(
@@ -163,15 +172,12 @@ async def test_no_token(api_provider: ApiProvider):
     ],
 )
 async def test_trailing_slash(
-    api_provider: ApiProvider, path, trailing_slash, expected
+    api_provider: ApiProvider, path, trailing_slash, expected, request_m
 ):
     api_provider._trailing_slash = trailing_slash
     await api_provider.request("GET", path)
 
-    assert (
-        api_provider._session.request.call_args[1]["url"]
-        == "http://testserver/foo/" + expected
-    )
+    assert request_m.call_args[1]["url"] == "http://testserver/foo/" + expected
 
 
 async def test_conflict(api_provider: ApiProvider, response):
@@ -189,16 +195,23 @@ async def test_conflict_with_message(api_provider: ApiProvider, response):
         await api_provider.request("GET", "bar")
 
 
-async def test_custom_header(api_provider: ApiProvider):
+async def test_custom_header(api_provider: ApiProvider, request_m):
     await api_provider.request("POST", "bar", headers={"foo": "bar"})
-    assert api_provider._session.request.call_args[1]["headers"] == {
+    assert request_m.call_args[1]["headers"] == {
         "foo": "bar",
         **(await api_provider._headers_factory()),
     }
 
 
-async def test_custom_header_precedes(api_provider: ApiProvider):
+async def test_custom_header_precedes(api_provider: ApiProvider, request_m):
     await api_provider.request("POST", "bar", headers={"Authorization": "bar"})
-    assert (
-        api_provider._session.request.call_args[1]["headers"]["Authorization"] == "bar"
-    )
+    assert request_m.call_args[1]["headers"]["Authorization"] == "bar"
+
+
+async def test_session_closed(api_provider: ApiProvider, request_m):
+    with mock.patch.object(
+        ClientSession, "close", new_callable=mock.AsyncMock
+    ) as close_m:
+        await api_provider.request("GET", "")
+
+    close_m.assert_awaited_once()