|
@@ -2,6 +2,7 @@ import base64
|
|
|
import json
|
|
|
import time
|
|
|
from functools import lru_cache
|
|
|
+from typing import Dict
|
|
|
|
|
|
from aiohttp import BasicAuth
|
|
|
from async_lru import alru_cache
|
|
@@ -38,6 +39,10 @@ def is_token_usable(token: str, leeway: int) -> bool:
|
|
|
return refresh_on >= int(time.time())
|
|
|
|
|
|
|
|
|
+def get_auth_headers(client_id: str, client_secret: str) -> Dict[str, str]:
|
|
|
+ return {"Authorization": BasicAuth(client_id, client_secret).encode()}
|
|
|
+
|
|
|
+
|
|
|
class OAuth2CCSettings(BaseModel):
|
|
|
token_url: AnyHttpUrl
|
|
|
client_id: str
|
|
@@ -53,17 +58,18 @@ class CCTokenGateway:
|
|
|
self.timeout = settings.timeout
|
|
|
self.leeway = settings.leeway
|
|
|
|
|
|
+ auth_headers = get_auth_headers(settings.client_id, settings.client_secret)
|
|
|
+
|
|
|
async def headers_factory():
|
|
|
- auth = BasicAuth(settings.client_id, settings.client_secret)
|
|
|
- return {"Authorization": auth.encode()}
|
|
|
+ return auth_headers
|
|
|
|
|
|
self.provider = ApiProvider(
|
|
|
url=settings.token_url, headers_factory=headers_factory
|
|
|
)
|
|
|
# This binds the cache to the CCTokenGateway instance (and not the class)
|
|
|
- self.cached_headers_factory = alru_cache(self._headers_factory)
|
|
|
+ self.cached_fetch_token = alru_cache(self._fetch_token)
|
|
|
|
|
|
- async def _headers_factory(self) -> str:
|
|
|
+ async def _fetch_token(self) -> str:
|
|
|
response = await self.provider.request(
|
|
|
method="POST",
|
|
|
path="",
|
|
@@ -73,13 +79,16 @@ class CCTokenGateway:
|
|
|
assert response is not None
|
|
|
return response["access_token"]
|
|
|
|
|
|
- async def headers_factory(self) -> str:
|
|
|
- token_str = await self.cached_headers_factory()
|
|
|
+ async def fetch_token(self) -> str:
|
|
|
+ token_str = await self.cached_fetch_token()
|
|
|
if not is_token_usable(token_str, self.leeway):
|
|
|
- self.cached_headers_factory.cache_clear()
|
|
|
- token_str = await self.cached_headers_factory()
|
|
|
+ self.cached_fetch_token.cache_clear()
|
|
|
+ token_str = await self.cached_fetch_token()
|
|
|
return token_str
|
|
|
|
|
|
+ async def fetch_headers(self) -> Dict[str, str]:
|
|
|
+ return {"Authorization": f"Bearer {await self.fetch_token()}"}
|
|
|
+
|
|
|
|
|
|
# Copy-paste of async version:
|
|
|
|
|
@@ -90,17 +99,15 @@ class SyncCCTokenGateway:
|
|
|
self.timeout = settings.timeout
|
|
|
self.leeway = settings.leeway
|
|
|
|
|
|
- def headers_factory():
|
|
|
- auth = BasicAuth(settings.client_id, settings.client_secret)
|
|
|
- return {"Authorization": auth.encode()}
|
|
|
+ auth_headers = get_auth_headers(settings.client_id, settings.client_secret)
|
|
|
|
|
|
self.provider = SyncApiProvider(
|
|
|
- url=settings.token_url, headers_factory=headers_factory
|
|
|
+ url=settings.token_url, headers_factory=lambda: auth_headers
|
|
|
)
|
|
|
# This binds the cache to the SyncCCTokenGateway instance (and not the class)
|
|
|
- self.cached_headers_factory = lru_cache(self._headers_factory)
|
|
|
+ self.cached_fetch_token = lru_cache(self._fetch_token)
|
|
|
|
|
|
- def _headers_factory(self) -> str:
|
|
|
+ def _fetch_token(self) -> str:
|
|
|
response = self.provider.request(
|
|
|
method="POST",
|
|
|
path="",
|
|
@@ -110,9 +117,12 @@ class SyncCCTokenGateway:
|
|
|
assert response is not None
|
|
|
return response["access_token"]
|
|
|
|
|
|
- def headers_factory(self) -> str:
|
|
|
- token_str = self.cached_headers_factory()
|
|
|
+ def fetch_token(self) -> str:
|
|
|
+ token_str = self.cached_fetch_token()
|
|
|
if not is_token_usable(token_str, self.leeway):
|
|
|
- self.cached_headers_factory.cache_clear()
|
|
|
- token_str = self.cached_headers_factory()
|
|
|
+ self.cached_fetch_token.cache_clear()
|
|
|
+ token_str = self.cached_fetch_token()
|
|
|
return token_str
|
|
|
+
|
|
|
+ def fetch_headers(self) -> Dict[str, str]:
|
|
|
+ return {"Authorization": f"Bearer {self.fetch_token()}"}
|