# client_credentials.py (3.7 KB)
  1. import base64
  2. import json
  3. import time
  4. from functools import lru_cache
  5. from aiohttp import BasicAuth
  6. from async_lru import alru_cache
  7. from pydantic import AnyHttpUrl
  8. from pydantic import BaseModel
  9. from clean_python.api_client import ApiProvider
  10. from clean_python.api_client import SyncApiProvider
  11. __all__ = ["CCTokenGateway", "SyncCCTokenGateway", "OAuth2CCSettings"]
  12. REFRESH_TIME_DELTA = 5 * 60 # in seconds
  13. def decode_jwt(token):
  14. """Decode a JWT without checking its signature"""
  15. # JWT consists of {header}.{payload}.{signature}
  16. _, payload, _ = token.split(".")
  17. # JWT should be padded with = (base64.b64decode expects this)
  18. payload += "=" * (-len(payload) % 4)
  19. return json.loads(base64.b64decode(payload))
  20. def is_token_usable(token: str, leeway: int) -> bool:
  21. """Determine whether the token has expired"""
  22. try:
  23. claims = decode_jwt(token)
  24. except Exception:
  25. return False
  26. exp = claims["exp"]
  27. refresh_on = exp - leeway
  28. return refresh_on >= int(time.time())
  29. class OAuth2CCSettings(BaseModel):
  30. token_url: AnyHttpUrl
  31. client_id: str
  32. client_secret: str
  33. scope: str
  34. timeout: float = 1.0 # in seconds
  35. leeway: int = 5 * 60 # in seconds
  36. class CCTokenGateway:
  37. def __init__(self, settings: OAuth2CCSettings):
  38. self.scope = settings.scope
  39. self.timeout = settings.timeout
  40. self.leeway = settings.leeway
  41. async def headers_factory():
  42. auth = BasicAuth(settings.client_id, settings.client_secret)
  43. return {"Authorization": auth.encode()}
  44. self.provider = ApiProvider(
  45. url=settings.token_url, headers_factory=headers_factory
  46. )
  47. # This binds the cache to the CCTokenGateway instance (and not the class)
  48. self.cached_headers_factory = alru_cache(self._headers_factory)
  49. async def _headers_factory(self) -> str:
  50. response = await self.provider.request(
  51. method="POST",
  52. path="",
  53. fields={"grant_type": "client_credentials", "scope": self.scope},
  54. timeout=self.timeout,
  55. )
  56. assert response is not None
  57. return response["access_token"]
  58. async def headers_factory(self) -> str:
  59. token_str = await self.cached_headers_factory()
  60. if not is_token_usable(token_str, self.leeway):
  61. self.cached_headers_factory.cache_clear()
  62. token_str = await self.cached_headers_factory()
  63. return token_str
  64. # Copy-paste of async version:
  65. class SyncCCTokenGateway:
  66. def __init__(self, settings: OAuth2CCSettings):
  67. self.scope = settings.scope
  68. self.timeout = settings.timeout
  69. self.leeway = settings.leeway
  70. def headers_factory():
  71. auth = BasicAuth(settings.client_id, settings.client_secret)
  72. return {"Authorization": auth.encode()}
  73. self.provider = SyncApiProvider(
  74. url=settings.token_url, headers_factory=headers_factory
  75. )
  76. # This binds the cache to the SyncCCTokenGateway instance (and not the class)
  77. self.cached_headers_factory = lru_cache(self._headers_factory)
  78. def _headers_factory(self) -> str:
  79. response = self.provider.request(
  80. method="POST",
  81. path="",
  82. fields={"grant_type": "client_credentials", "scope": self.scope},
  83. timeout=self.timeout,
  84. )
  85. assert response is not None
  86. return response["access_token"]
  87. def headers_factory(self) -> str:
  88. token_str = self.cached_headers_factory()
  89. if not is_token_usable(token_str, self.leeway):
  90. self.cached_headers_factory.cache_clear()
  91. token_str = self.cached_headers_factory()
  92. return token_str