123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import base64
- import json
- import time
- from functools import lru_cache
- from typing import Dict
- from aiohttp import BasicAuth
- from async_lru import alru_cache
- from pydantic import AnyHttpUrl
- from pydantic import BaseModel
- from clean_python.api_client import ApiProvider
- from clean_python.api_client import SyncApiProvider
# Public API of this module.
__all__ = [
    "CCTokenGateway",
    "SyncCCTokenGateway",
    "OAuth2CCSettings",
]

# Five minutes, in seconds. Not used by the gateways below (they read
# OAuth2CCSettings.leeway, which has the same default); presumably kept for
# importers of this module — confirm before removing.
REFRESH_TIME_DELTA = 300
def decode_jwt(token: str):
    """Decode the payload (claims) of a JWT without checking its signature.

    A compact-serialized JWT has the form ``{header}.{payload}.{signature}``.
    Each segment is base64url-encoded (RFC 7515) with its trailing ``=``
    padding stripped.

    Args:
        token: The compact-serialized JWT.

    Returns:
        The JSON-decoded payload segment.

    Raises:
        ValueError: If ``token`` does not have exactly three ``.``-separated
            segments, or if the payload is not valid base64 / JSON
            (``binascii.Error`` and ``json.JSONDecodeError`` both subclass
            ``ValueError``).
    """
    _, payload, _ = token.split(".")
    # JWTs strip the "=" padding, but the base64 decoder requires it.
    payload += "=" * (-len(payload) % 4)
    # JWT segments use the URL-safe alphabet ("-" and "_" instead of "+" and
    # "/"). Plain b64decode silently discards "-" and "_", which corrupts any
    # payload containing them. urlsafe_b64decode maps them back and still
    # accepts "+" and "/" from issuers that use the standard alphabet.
    return json.loads(base64.urlsafe_b64decode(payload))
def is_token_usable(token: str, leeway: int) -> bool:
    """Return whether ``token`` remains valid for at least ``leeway`` more seconds.

    The token is decoded as a JWT (signature not checked), and its ``exp``
    claim is compared against the current time.

    Args:
        token: The access token.
        leeway: Safety margin in seconds. A token whose ``exp`` lies less than
            this far in the future counts as unusable, so that it is refreshed
            before it actually expires.

    Returns:
        True if ``exp - leeway`` is not earlier than the current Unix time.

        False otherwise. Also False when the token cannot be decoded, or when
        its payload has no numeric ``exp`` claim (e.g. an opaque, non-JWT
        token). In those cases the caller should obtain a new token instead of
        crashing.
    """
    try:
        claims = decode_jwt(token)
        # A missing "exp" (KeyError), a non-object payload (TypeError) or a
        # non-numeric "exp" (TypeError on subtraction) is handled like an
        # undecodable token instead of propagating to the caller.
        refresh_on = claims["exp"] - leeway
    except Exception:
        return False
    return refresh_on >= int(time.time())
def get_auth_headers(client_id: str, client_secret: str) -> Dict[str, str]:
    """Build an HTTP Basic ``Authorization`` header from the client credentials."""
    basic = BasicAuth(client_id, client_secret)
    header_value = basic.encode()
    return {"Authorization": header_value}
class OAuth2CCSettings(BaseModel):
    """Settings for obtaining access tokens via the OAuth2 client credentials grant."""

    # Endpoint that issues the access tokens.
    token_url: AnyHttpUrl
    # Client credentials; the gateways send them to token_url as HTTP Basic auth.
    client_id: str
    client_secret: str
    # Sent as the "scope" form field of every token request.
    scope: str
    # Timeout of a single token request, in seconds.
    timeout: float = 1.0
    # Seconds before the token's "exp" claim at which a cached token is
    # considered stale and gets refreshed (300 s = 5 minutes).
    leeway: int = 300
class CCTokenGateway:
    """Obtain OAuth2 access tokens with the client credentials grant (async).

    Tokens are requested from ``settings.token_url`` and cached per instance.
    A cached token is replaced once ``is_token_usable`` rejects it, given
    ``settings.leeway``.
    """

    def __init__(self, settings: OAuth2CCSettings):
        self.scope = settings.scope
        self.timeout = settings.timeout
        self.leeway = settings.leeway

        # The token endpoint is called with HTTP Basic client credentials.
        # The header dict is built once and handed out for every request.
        basic_auth_headers = get_auth_headers(
            settings.client_id, settings.client_secret
        )

        async def basic_auth_factory():
            return basic_auth_headers

        self.provider = ApiProvider(
            url=settings.token_url, headers_factory=basic_auth_factory
        )
        # Wrapping the bound method (rather than decorating the method on the
        # class) gives every CCTokenGateway instance its own cache.
        self.cached_fetch_token = alru_cache(self._fetch_token)

    async def _fetch_token(self) -> str:
        """POST a client_credentials grant and return its ``access_token`` (uncached)."""
        form_fields = {"grant_type": "client_credentials", "scope": self.scope}
        response = await self.provider.request(
            method="POST",
            path="",
            fields=form_fields,
            timeout=self.timeout,
        )
        assert response is not None
        return response["access_token"]

    async def fetch_token(self) -> str:
        """Return the cached access token, fetching a new one if it is not usable.

        A freshly fetched token is returned without being checked again.
        """
        token = await self.cached_fetch_token()
        if is_token_usable(token, self.leeway):
            return token
        # Drop the stale entry so the cached call hits the token endpoint again.
        self.cached_fetch_token.cache_clear()
        return await self.cached_fetch_token()

    async def fetch_headers(self) -> Dict[str, str]:
        """Return an ``Authorization: Bearer <token>`` header for API requests."""
        token = await self.fetch_token()
        return {"Authorization": f"Bearer {token}"}
# Synchronous mirror of CCTokenGateway: the same logic without async/await.
# Keep both implementations in sync.
class SyncCCTokenGateway:
    """Obtain OAuth2 access tokens with the client credentials grant (sync).

    Tokens are requested from ``settings.token_url`` and cached per instance.
    A cached token is replaced once ``is_token_usable`` rejects it, given
    ``settings.leeway``.
    """

    def __init__(self, settings: OAuth2CCSettings):
        self.scope = settings.scope
        self.timeout = settings.timeout
        self.leeway = settings.leeway

        # The token endpoint is called with HTTP Basic client credentials.
        # The header dict is built once and handed out for every request.
        basic_auth_headers = get_auth_headers(
            settings.client_id, settings.client_secret
        )

        def basic_auth_factory():
            return basic_auth_headers

        self.provider = SyncApiProvider(
            url=settings.token_url, headers_factory=basic_auth_factory
        )
        # Wrapping the bound method (rather than decorating the method on the
        # class) gives every SyncCCTokenGateway instance its own cache.
        self.cached_fetch_token = lru_cache(self._fetch_token)

    def _fetch_token(self) -> str:
        """POST a client_credentials grant and return its ``access_token`` (uncached)."""
        form_fields = {"grant_type": "client_credentials", "scope": self.scope}
        response = self.provider.request(
            method="POST",
            path="",
            fields=form_fields,
            timeout=self.timeout,
        )
        assert response is not None
        return response["access_token"]

    def fetch_token(self) -> str:
        """Return the cached access token, fetching a new one if it is not usable.

        A freshly fetched token is returned without being checked again.
        """
        token = self.cached_fetch_token()
        if is_token_usable(token, self.leeway):
            return token
        # Drop the stale entry so the cached call hits the token endpoint again.
        self.cached_fetch_token.cache_clear()
        return self.cached_fetch_token()

    def fetch_headers(self) -> Dict[str, str]:
        """Return an ``Authorization: Bearer <token>`` header for API requests."""
        token = self.fetch_token()
        return {"Authorization": f"Bearer {token}"}
|