# client_credentials.py (3.5 KB)
  1. import base64
  2. import json
  3. import time
  4. from functools import lru_cache
  5. from aiohttp import BasicAuth
  6. from async_lru import alru_cache
  7. from pydantic import AnyHttpUrl
  8. from pydantic import BaseModel
  9. from clean_python.api_client import ApiProvider
  10. from clean_python.api_client import SyncApiProvider
  11. __all__ = ["CCTokenGateway", "SyncCCTokenGateway", "OAuth2CCSettings"]
  12. REFRESH_TIME_DELTA = 5 * 60 # in seconds
  13. def decode_jwt(token):
  14. """Decode a JWT without checking its signature"""
  15. # JWT consists of {header}.{payload}.{signature}
  16. _, payload, _ = token.split(".")
  17. # JWT should be padded with = (base64.b64decode expects this)
  18. payload += "=" * (-len(payload) % 4)
  19. return json.loads(base64.b64decode(payload))
  20. def is_token_usable(token: str, leeway: int) -> bool:
  21. """Determine whether the token has expired"""
  22. try:
  23. claims = decode_jwt(token)
  24. except Exception:
  25. return False
  26. exp = claims["exp"]
  27. refresh_on = exp - leeway
  28. return refresh_on >= int(time.time())
  29. class OAuth2CCSettings(BaseModel):
  30. token_url: AnyHttpUrl
  31. client_id: str
  32. client_secret: str
  33. scope: str
  34. timeout: float = 1.0 # in seconds
  35. leeway: int = 5 * 60 # in seconds
  36. class CCTokenGateway:
  37. def __init__(self, settings: OAuth2CCSettings):
  38. self.scope = settings.scope
  39. self.timeout = settings.timeout
  40. self.leeway = settings.leeway
  41. async def fetch_token():
  42. auth = BasicAuth(settings.client_id, settings.client_secret)
  43. return {"Authorization": auth.encode()}
  44. self.provider = ApiProvider(url=settings.token_url, fetch_token=fetch_token)
  45. # This binds the cache to the CCTokenGateway instance (and not the class)
  46. self.cached_fetch_token = alru_cache(self._fetch_token)
  47. async def _fetch_token(self) -> str:
  48. response = await self.provider.request(
  49. method="POST",
  50. path="",
  51. fields={"grant_type": "client_credentials", "scope": self.scope},
  52. timeout=self.timeout,
  53. )
  54. assert response is not None
  55. return response["access_token"]
  56. async def fetch_token(self) -> str:
  57. token_str = await self.cached_fetch_token()
  58. if not is_token_usable(token_str, self.leeway):
  59. self.cached_fetch_token.cache_clear()
  60. token_str = await self.cached_fetch_token()
  61. return token_str
# Synchronous counterpart of CCTokenGateway (a copy of the async version):
# keep both classes in step when changing either one.
  63. class SyncCCTokenGateway:
  64. def __init__(self, settings: OAuth2CCSettings):
  65. self.scope = settings.scope
  66. self.timeout = settings.timeout
  67. self.leeway = settings.leeway
  68. def fetch_token():
  69. auth = BasicAuth(settings.client_id, settings.client_secret)
  70. return {"Authorization": auth.encode()}
  71. self.provider = SyncApiProvider(url=settings.token_url, fetch_token=fetch_token)
  72. # This binds the cache to the SyncCCTokenGateway instance (and not the class)
  73. self.cached_fetch_token = lru_cache(self._fetch_token)
  74. def _fetch_token(self) -> str:
  75. response = self.provider.request(
  76. method="POST",
  77. path="",
  78. fields={"grant_type": "client_credentials", "scope": self.scope},
  79. timeout=self.timeout,
  80. )
  81. assert response is not None
  82. return response["access_token"]
  83. def fetch_token(self) -> str:
  84. token_str = self.cached_fetch_token()
  85. if not is_token_usable(token_str, self.leeway):
  86. self.cached_fetch_token.cache_clear()
  87. token_str = self.cached_fetch_token()
  88. return token_str