Pārlūkot izejas kodu

Revert changes in 0.9.1 for CCTokenGateway (#41)

Casper van der Wel 1 gadu atpakaļ
vecāks
revīzija
d9ec377ab5

+ 3 - 1
CHANGES.md

@@ -4,7 +4,9 @@
 0.9.2 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Revert changes done in 0.9.1 in CCTokenGateway.
+
+- Added CCTokenGateway.fetch_headers()
 
 
 0.9.1 (2023-11-23)

+ 28 - 18
clean_python/oauth2/client_credentials.py

@@ -2,6 +2,7 @@ import base64
 import json
 import time
 from functools import lru_cache
+from typing import Dict
 
 from aiohttp import BasicAuth
 from async_lru import alru_cache
@@ -38,6 +39,10 @@ def is_token_usable(token: str, leeway: int) -> bool:
     return refresh_on >= int(time.time())
 
 
+def get_auth_headers(client_id: str, client_secret: str) -> Dict[str, str]:
+    return {"Authorization": BasicAuth(client_id, client_secret).encode()}
+
+
 class OAuth2CCSettings(BaseModel):
     token_url: AnyHttpUrl
     client_id: str
@@ -53,17 +58,18 @@ class CCTokenGateway:
         self.timeout = settings.timeout
         self.leeway = settings.leeway
 
+        auth_headers = get_auth_headers(settings.client_id, settings.client_secret)
+
         async def headers_factory():
-            auth = BasicAuth(settings.client_id, settings.client_secret)
-            return {"Authorization": auth.encode()}
+            return auth_headers
 
         self.provider = ApiProvider(
             url=settings.token_url, headers_factory=headers_factory
         )
         # This binds the cache to the CCTokenGateway instance (and not the class)
-        self.cached_headers_factory = alru_cache(self._headers_factory)
+        self.cached_fetch_token = alru_cache(self._fetch_token)
 
-    async def _headers_factory(self) -> str:
+    async def _fetch_token(self) -> str:
         response = await self.provider.request(
             method="POST",
             path="",
@@ -73,13 +79,16 @@ class CCTokenGateway:
         assert response is not None
         return response["access_token"]
 
-    async def headers_factory(self) -> str:
-        token_str = await self.cached_headers_factory()
+    async def fetch_token(self) -> str:
+        token_str = await self.cached_fetch_token()
         if not is_token_usable(token_str, self.leeway):
-            self.cached_headers_factory.cache_clear()
-            token_str = await self.cached_headers_factory()
+            self.cached_fetch_token.cache_clear()
+            token_str = await self.cached_fetch_token()
         return token_str
 
+    async def fetch_headers(self) -> Dict[str, str]:
+        return {"Authorization": f"Bearer {await self.fetch_token()}"}
+
 
 # Copy-paste of async version:
 
@@ -90,17 +99,15 @@ class SyncCCTokenGateway:
         self.timeout = settings.timeout
         self.leeway = settings.leeway
 
-        def headers_factory():
-            auth = BasicAuth(settings.client_id, settings.client_secret)
-            return {"Authorization": auth.encode()}
+        auth_headers = get_auth_headers(settings.client_id, settings.client_secret)
 
         self.provider = SyncApiProvider(
-            url=settings.token_url, headers_factory=headers_factory
+            url=settings.token_url, headers_factory=lambda: auth_headers
         )
         # This binds the cache to the SyncCCTokenGateway instance (and not the class)
-        self.cached_headers_factory = lru_cache(self._headers_factory)
+        self.cached_fetch_token = lru_cache(self._fetch_token)
 
-    def _headers_factory(self) -> str:
+    def _fetch_token(self) -> str:
         response = self.provider.request(
             method="POST",
             path="",
@@ -110,9 +117,12 @@ class SyncCCTokenGateway:
         assert response is not None
         return response["access_token"]
 
-    def headers_factory(self) -> str:
-        token_str = self.cached_headers_factory()
+    def fetch_token(self) -> str:
+        token_str = self.cached_fetch_token()
         if not is_token_usable(token_str, self.leeway):
-            self.cached_headers_factory.cache_clear()
-            token_str = self.cached_headers_factory()
+            self.cached_fetch_token.cache_clear()
+            token_str = self.cached_fetch_token()
         return token_str
+
+    def fetch_headers(self) -> Dict[str, str]:
+        return {"Authorization": f"Bearer {self.fetch_token()}"}

+ 4 - 4
integration_tests/test_int_client_credentials.py

@@ -22,8 +22,8 @@ def gateway(settings) -> CCTokenGateway:
     return CCTokenGateway(settings)
 
 
-async def test_headers_factory(gateway: CCTokenGateway):
-    response = await gateway._headers_factory()
+async def test_fetch_token(gateway: CCTokenGateway):
+    response = await gateway._fetch_token()
     assert is_token_usable(response, 0)
 
 
@@ -32,6 +32,6 @@ def sync_gateway(settings) -> SyncCCTokenGateway:
     return SyncCCTokenGateway(settings)
 
 
-def test_headers_factory_sync(sync_gateway: SyncCCTokenGateway):
-    response = sync_gateway._headers_factory()
+def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
+    response = sync_gateway._fetch_token()
     assert is_token_usable(response, 0)

+ 24 - 12
tests/oauth2/test_client_credentials.py

@@ -55,10 +55,10 @@ def test_is_token_usable(expires_in, leeway, expected):
     assert is_token_usable(token, leeway) is expected
 
 
-async def test_headers_factory(gateway: CCTokenGateway):
+async def test_fetch_token(gateway: CCTokenGateway):
     gateway.provider.request.return_value = {"access_token": "foo"}
 
-    token = await gateway._headers_factory()
+    token = await gateway._fetch_token()
 
     assert token == "foo"
 
@@ -70,18 +70,18 @@ async def test_headers_factory(gateway: CCTokenGateway):
     )
 
 
-async def test_headers_factory_cache(gateway: CCTokenGateway):
+async def test_fetch_token_cache(gateway: CCTokenGateway):
     # empty cache: provider gets called
     token = get_token({})
     gateway.provider.request.return_value = {"access_token": token}
-    actual = await gateway.headers_factory()
+    actual = await gateway.fetch_token()
     assert actual == token
     assert gateway.provider.request.called
 
     gateway.provider.request.reset_mock()
 
     # cache is filled: provider is not called
-    actual = await gateway.headers_factory()
+    actual = await gateway.fetch_token()
     assert actual == token
     assert not gateway.provider.request.called
 
@@ -89,15 +89,15 @@ async def test_headers_factory_cache(gateway: CCTokenGateway):
 
     # token is not usable so it is refreshed:
     with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
-        actual = await gateway.headers_factory()
+        actual = await gateway.fetch_token()
         assert actual == token
         assert gateway.provider.request.called
 
 
-def test_headers_factory_sync(sync_gateway: SyncCCTokenGateway):
+def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
     sync_gateway.provider.request.return_value = {"access_token": "foo"}
 
-    token = sync_gateway._headers_factory()
+    token = sync_gateway._fetch_token()
 
     assert token == "foo"
 
@@ -109,18 +109,18 @@ def test_headers_factory_sync(sync_gateway: SyncCCTokenGateway):
     )
 
 
-def test_headers_factory_sync_cache(sync_gateway: SyncCCTokenGateway):
+def test_fetch_token_sync_cache(sync_gateway: SyncCCTokenGateway):
     # empty cache: provider gets called
     token = get_token({})
     sync_gateway.provider.request.return_value = {"access_token": token}
-    actual = sync_gateway.headers_factory()
+    actual = sync_gateway.fetch_token()
     assert actual == token
     assert sync_gateway.provider.request.called
 
     sync_gateway.provider.request.reset_mock()
 
     # cache is filled: provider is not called
-    actual = sync_gateway.headers_factory()
+    actual = sync_gateway.fetch_token()
     assert actual == token
     assert not sync_gateway.provider.request.called
 
@@ -128,6 +128,18 @@ def test_headers_factory_sync_cache(sync_gateway: SyncCCTokenGateway):
 
     # token is not usable so it is refreshed:
     with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
-        actual = sync_gateway.headers_factory()
+        actual = sync_gateway.fetch_token()
         assert actual == token
         assert sync_gateway.provider.request.called
+
+
+async def test_fetch_headers(gateway: CCTokenGateway):
+    gateway.provider.request.return_value = {"access_token": "foo"}
+
+    await gateway.fetch_headers() == {"Authorization": "Bearer foo"}
+
+
+def test_fetch_headers_sync(sync_gateway: SyncCCTokenGateway):
+    sync_gateway.provider.request.return_value = {"access_token": "foo"}
+
+    assert sync_gateway.fetch_headers() == {"Authorization": "Bearer foo"}