client_credentials.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import base64
  2. import json
  3. import time
  4. from functools import lru_cache
  5. from typing import Dict
  6. from aiohttp import BasicAuth
  7. from async_lru import alru_cache
  8. from pydantic import AnyHttpUrl
  9. from pydantic import BaseModel
  10. from clean_python.api_client import ApiProvider
  11. from clean_python.api_client import SyncApiProvider
# Names exported by a star-import of this module.
__all__ = ["CCTokenGateway", "SyncCCTokenGateway", "OAuth2CCSettings"]
# NOTE(review): not referenced in the visible part of this module;
# OAuth2CCSettings.leeway carries the same default value. Confirm whether
# this constant is still used elsewhere before relying on it.
REFRESH_TIME_DELTA = 5 * 60  # in seconds
  14. def decode_jwt(token):
  15. """Decode a JWT without checking its signature"""
  16. # JWT consists of {header}.{payload}.{signature}
  17. _, payload, _ = token.split(".")
  18. # JWT should be padded with = (base64.b64decode expects this)
  19. payload += "=" * (-len(payload) % 4)
  20. return json.loads(base64.b64decode(payload))
  21. def is_token_usable(token: str, leeway: int) -> bool:
  22. """Determine whether the token has expired"""
  23. try:
  24. claims = decode_jwt(token)
  25. except Exception:
  26. return False
  27. exp = claims["exp"]
  28. refresh_on = exp - leeway
  29. return refresh_on >= int(time.time())
  30. def get_auth_headers(client_id: str, client_secret: str) -> Dict[str, str]:
  31. return {"Authorization": BasicAuth(client_id, client_secret).encode()}
  32. class OAuth2CCSettings(BaseModel):
  33. token_url: AnyHttpUrl
  34. client_id: str
  35. client_secret: str
  36. scope: str
  37. timeout: float = 1.0 # in seconds
  38. leeway: int = 5 * 60 # in seconds
  39. class CCTokenGateway:
  40. def __init__(self, settings: OAuth2CCSettings):
  41. self.scope = settings.scope
  42. self.timeout = settings.timeout
  43. self.leeway = settings.leeway
  44. auth_headers = get_auth_headers(settings.client_id, settings.client_secret)
  45. async def headers_factory():
  46. return auth_headers
  47. self.provider = ApiProvider(
  48. url=settings.token_url, headers_factory=headers_factory
  49. )
  50. # This binds the cache to the CCTokenGateway instance (and not the class)
  51. self.cached_fetch_token = alru_cache(self._fetch_token)
  52. async def _fetch_token(self) -> str:
  53. response = await self.provider.request(
  54. method="POST",
  55. path="",
  56. fields={"grant_type": "client_credentials", "scope": self.scope},
  57. timeout=self.timeout,
  58. )
  59. assert response is not None
  60. return response["access_token"]
  61. async def fetch_token(self) -> str:
  62. token_str = await self.cached_fetch_token()
  63. if not is_token_usable(token_str, self.leeway):
  64. self.cached_fetch_token.cache_clear()
  65. token_str = await self.cached_fetch_token()
  66. return token_str
  67. async def fetch_headers(self) -> Dict[str, str]:
  68. return {"Authorization": f"Bearer {await self.fetch_token()}"}
  69. # Copy-paste of async version:
  70. class SyncCCTokenGateway:
  71. def __init__(self, settings: OAuth2CCSettings):
  72. self.scope = settings.scope
  73. self.timeout = settings.timeout
  74. self.leeway = settings.leeway
  75. auth_headers = get_auth_headers(settings.client_id, settings.client_secret)
  76. self.provider = SyncApiProvider(
  77. url=settings.token_url, headers_factory=lambda: auth_headers
  78. )
  79. # This binds the cache to the SyncCCTokenGateway instance (and not the class)
  80. self.cached_fetch_token = lru_cache(self._fetch_token)
  81. def _fetch_token(self) -> str:
  82. response = self.provider.request(
  83. method="POST",
  84. path="",
  85. fields={"grant_type": "client_credentials", "scope": self.scope},
  86. timeout=self.timeout,
  87. )
  88. assert response is not None
  89. return response["access_token"]
  90. def fetch_token(self) -> str:
  91. token_str = self.cached_fetch_token()
  92. if not is_token_usable(token_str, self.leeway):
  93. self.cached_fetch_token.cache_clear()
  94. token_str = self.cached_fetch_token()
  95. return token_str
  96. def fetch_headers(self) -> Dict[str, str]:
  97. return {"Authorization": f"Bearer {self.fetch_token()}"}