import base64
import json
import time
from functools import lru_cache

from aiohttp import BasicAuth
from async_lru import alru_cache
from pydantic import AnyHttpUrl
from pydantic import BaseModel

from clean_python.api_client import ApiProvider
from clean_python.api_client import SyncApiProvider

__all__ = ["CCTokenGateway", "SyncCCTokenGateway", "OAuth2CCSettings"]


REFRESH_TIME_DELTA = 5 * 60  # in seconds


def decode_jwt(token):
    """Decode the payload (claims) of a JWT without checking its signature.

    A JWT consists of ``{header}.{payload}.{signature}``; each segment is
    base64url-encoded with the trailing ``=`` padding stripped.

    Raises ValueError (including binascii.Error / json.JSONDecodeError) when
    the token does not have exactly three segments or the payload is not
    valid base64url-encoded JSON.
    """
    _, payload, _ = token.split(".")
    # JWT strips the "=" padding; base64 decoding expects it to be present.
    payload += "=" * (-len(payload) % 4)
    # JWT uses the URL-safe alphabet ("-" and "_" instead of "+" and "/").
    # The plain b64decode silently discards those characters, corrupting the
    # payload, so the URL-safe decoder must be used.
    return json.loads(base64.urlsafe_b64decode(payload))


def is_token_usable(token: str, leeway: int) -> bool:
    """Return True if *token* remains valid for at least *leeway* seconds.

    The token is usable when its ``exp`` claim minus *leeway* is not in the
    past. A token that cannot be decoded as a JWT, or whose claims carry no
    usable ``exp`` value, is reported as not usable so that the caller
    fetches a fresh one instead of failing.
    """
    try:
        refresh_on = decode_jwt(token)["exp"] - leeway
    except Exception:
        # Deliberate best-effort: any malformed token just counts as unusable.
        return False
    return refresh_on >= int(time.time())


class OAuth2CCSettings(BaseModel):
    """Settings for obtaining tokens via the OAuth2 client-credentials grant."""

    token_url: AnyHttpUrl  # endpoint that issues the access tokens
    client_id: str
    client_secret: str
    scope: str
    timeout: float = 1.0  # token request timeout, in seconds
    leeway: int = REFRESH_TIME_DELTA  # refresh this many seconds before "exp"


def _basic_auth_header(settings: OAuth2CCSettings) -> dict:
    """Build the HTTP Basic ``Authorization`` header for the token endpoint."""
    auth = BasicAuth(settings.client_id, settings.client_secret)
    return {"Authorization": auth.encode()}


class CCTokenGateway:
    """Fetch and cache an access token using the OAuth2 client-credentials grant.

    The token is cached per instance and is re-fetched once it is within
    ``settings.leeway`` seconds of its ``exp`` claim, or when it cannot be
    decoded as a JWT.
    """

    def __init__(self, settings: OAuth2CCSettings):
        self.scope = settings.scope
        self.timeout = settings.timeout
        self.leeway = settings.leeway

        # The token endpoint itself is authenticated with the client id and
        # secret (HTTP Basic auth) rather than with a bearer token.
        async def fetch_token():
            return _basic_auth_header(settings)

        self.provider = ApiProvider(url=settings.token_url, fetch_token=fetch_token)
        # This binds the cache to the CCTokenGateway instance (and not the class)
        self.cached_fetch_token = alru_cache(self._fetch_token)

    async def _fetch_token(self) -> str:
        """Request a new access token from the token endpoint (uncached)."""
        response = await self.provider.request(
            method="POST",
            path="",
            fields={"grant_type": "client_credentials", "scope": self.scope},
            timeout=self.timeout,
        )
        assert response is not None
        return response["access_token"]

    async def fetch_token(self) -> str:
        """Return the cached access token, refreshing it when (nearly) expired."""
        token_str = await self.cached_fetch_token()
        if not is_token_usable(token_str, self.leeway):
            # Drop the stale entry so the next call hits the token endpoint.
            self.cached_fetch_token.cache_clear()
            token_str = await self.cached_fetch_token()
        return token_str


# Copy-paste of async version:
class SyncCCTokenGateway:
    """Synchronous counterpart of CCTokenGateway with the same caching rules."""

    def __init__(self, settings: OAuth2CCSettings):
        self.scope = settings.scope
        self.timeout = settings.timeout
        self.leeway = settings.leeway

        # The token endpoint itself is authenticated with the client id and
        # secret (HTTP Basic auth) rather than with a bearer token.
        def fetch_token():
            return _basic_auth_header(settings)

        self.provider = SyncApiProvider(url=settings.token_url, fetch_token=fetch_token)
        # This binds the cache to the SyncCCTokenGateway instance (and not the class)
        self.cached_fetch_token = lru_cache(self._fetch_token)

    def _fetch_token(self) -> str:
        """Request a new access token from the token endpoint (uncached)."""
        response = self.provider.request(
            method="POST",
            path="",
            fields={"grant_type": "client_credentials", "scope": self.scope},
            timeout=self.timeout,
        )
        assert response is not None
        return response["access_token"]

    def fetch_token(self) -> str:
        """Return the cached access token, refreshing it when (nearly) expired."""
        token_str = self.cached_fetch_token()
        if not is_token_usable(token_str, self.leeway):
            # Drop the stale entry so the next call hits the token endpoint.
            self.cached_fetch_token.cache_clear()
            token_str = self.cached_fetch_token()
        return token_str