# token_verifier.py (5.1 KB)
  1. # (c) Nelen & Schuurmans
  2. import logging
  3. from typing import Any
  4. from typing import Dict
  5. from typing import FrozenSet
  6. from typing import List
  7. from typing import Optional
  8. import jwt
  9. from jwt import PyJWKClient
  10. from jwt.exceptions import PyJWTError
  11. from pydantic import AnyHttpUrl
  12. from pydantic import BaseModel
  13. from pydantic import ValidationError
  14. from clean_python import PermissionDenied
  15. from clean_python import Unauthorized
  16. from clean_python import User
  17. from .token import Token
  18. __all__ = [
  19. "BaseTokenVerifier",
  20. "TokenVerifier",
  21. "NoAuthTokenVerifier",
  22. "TokenVerifierSettings",
  23. "OAuth2SPAClientSettings",
  24. ]
  25. logger = logging.getLogger(__name__)
  26. class TokenVerifierSettings(BaseModel):
  27. issuer: str
  28. algorithms: List[str] = ["RS256"]
  29. # optional additional checks:
  30. scope: Optional[str] = None
  31. admin_users: Optional[List[str]] = None # 'sub' whitelist
  32. class OAuth2SPAClientSettings(BaseModel):
  33. client_id: str
  34. token_url: AnyHttpUrl
  35. authorization_url: AnyHttpUrl
  36. class BaseTokenVerifier:
  37. def force(self, token: Token) -> None:
  38. raise NotImplementedError()
  39. def __call__(self, authorization: Optional[str]) -> Token:
  40. raise NotImplementedError()
  41. class NoAuthTokenVerifier(BaseTokenVerifier):
  42. def __init__(self):
  43. self.token = Token(
  44. claims={"sub": "DEV", "username": "dev", "scope": "superuser"}
  45. )
  46. def force(self, token: Token) -> None:
  47. self.token = token
  48. def __call__(self, authorization: Optional[str]) -> Token:
  49. return self.token
  50. class TokenVerifier(BaseTokenVerifier):
  51. """A class for verifying OAuth2 Access Tokens from AWS Cognito
  52. The verification steps followed are documented here:
  53. https://docs.aws.amazon.com/cognito/latest/developerguide/amazon- ⏎
  54. cognito-user-pools-using-tokens-verifying-a-jwt.html
  55. """
  56. # allow 2 minutes leeway for verifying token expiry:
  57. LEEWAY = 120
  58. def __init__(
  59. self, settings: TokenVerifierSettings, logger: Optional[logging.Logger] = None
  60. ):
  61. self.settings = settings
  62. self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
  63. def __call__(self, authorization: Optional[str]) -> Token:
  64. # Step 0: retrieve the token from the Authorization header
  65. # See https://tools.ietf.org/html/rfc6750#section-2.1,
  66. # Bearer is case-sensitive and there is exactly 1 separator after.
  67. if authorization is None:
  68. logger.info("Missing Authorization header")
  69. raise Unauthorized()
  70. jwt_str = authorization[7:] if authorization.startswith("Bearer") else None
  71. if jwt_str is None:
  72. logger.info("Authorization does not start with 'Bearer '")
  73. raise Unauthorized()
  74. # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
  75. # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
  76. try:
  77. key = self.get_key(jwt_str) # JSON Web Key
  78. except PyJWTError as e:
  79. logger.info("Token is invalid: %s", e)
  80. raise Unauthorized()
  81. # Step 2: Validate the JWT signature and standard claims
  82. try:
  83. claims = jwt.decode(
  84. jwt_str,
  85. key.key,
  86. algorithms=self.settings.algorithms,
  87. issuer=self.settings.issuer,
  88. leeway=self.LEEWAY,
  89. options={
  90. "require": ["exp", "iss", "sub", "scope", "token_use"],
  91. },
  92. )
  93. except PyJWTError as e:
  94. logger.info("Token is invalid: %s", e)
  95. raise Unauthorized()
  96. # Step 3: Verify additional claims. At this point, we have passed
  97. # verification, so unverified claims may be used safely.
  98. self.verify_token_use(claims)
  99. try:
  100. token = Token(claims=claims)
  101. except ValidationError as e:
  102. logger.info("Token is invalid: %s", e)
  103. raise Unauthorized()
  104. self.verify_scope(token.scope)
  105. # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
  106. self.authorize_user(token.user)
  107. return token
  108. def get_key(self, token) -> jwt.PyJWK:
  109. """Return the JSON Web KEY (JWK) corresponding to kid."""
  110. return self.jwk_client.get_signing_key_from_jwt(token)
  111. def verify_token_use(self, claims: Dict[str, Any]) -> None:
  112. """Check the token_use claim."""
  113. if claims["token_use"] != "access":
  114. logger.info("Token has invalid token_use claim: %s", claims["token_use"])
  115. raise Unauthorized()
  116. def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
  117. """Parse scopes and optionally check scope claim."""
  118. if self.settings.scope is None:
  119. return
  120. if self.settings.scope not in claims_scope:
  121. logger.info("Token is missing '%s' scope", self.settings.scope)
  122. raise Unauthorized()
  123. def authorize_user(self, user: User) -> None:
  124. if self.settings.admin_users is None:
  125. return
  126. if user.id not in self.settings.admin_users:
  127. raise PermissionDenied()