token_verifier.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # (c) Nelen & Schuurmans
  2. import logging
  3. import socket
  4. from typing import Any
  5. from typing import Dict
  6. from typing import FrozenSet
  7. from typing import List
  8. from typing import Optional
  9. import jwt
  10. from jwt import PyJWKClient
  11. from jwt.exceptions import PyJWTError
  12. from pydantic import AnyHttpUrl
  13. from pydantic import BaseModel
  14. from pydantic import ValidationError
  15. from clean_python import PermissionDenied
  16. from clean_python import Unauthorized
  17. from clean_python import User
  18. from .token import Token
  19. __all__ = [
  20. "BaseTokenVerifier",
  21. "TokenVerifier",
  22. "NoAuthTokenVerifier",
  23. "TokenVerifierSettings",
  24. "OAuth2SPAClientSettings",
  25. ]
  26. logger = logging.getLogger(__name__)
  27. class TokenVerifierSettings(BaseModel):
  28. issuer: str
  29. algorithms: List[str] = ["RS256"]
  30. # optional additional checks:
  31. scope: Optional[str] = None
  32. admin_users: Optional[List[str]] = None # 'sub' whitelist
  33. jwks_timeout: float = 1.0
  34. class OAuth2SPAClientSettings(BaseModel):
  35. client_id: str
  36. token_url: AnyHttpUrl
  37. authorization_url: AnyHttpUrl
  38. class BaseTokenVerifier:
  39. def force(self, token: Token) -> None:
  40. raise NotImplementedError()
  41. def __call__(self, authorization: Optional[str]) -> Token:
  42. raise NotImplementedError()
  43. class NoAuthTokenVerifier(BaseTokenVerifier):
  44. def __init__(self):
  45. self.token = Token(
  46. claims={"sub": "DEV", "username": "dev", "scope": "superuser"}
  47. )
  48. def force(self, token: Token) -> None:
  49. self.token = token
  50. def __call__(self, authorization: Optional[str]) -> Token:
  51. return self.token
  52. class TokenVerifier(BaseTokenVerifier):
  53. """A class for verifying OAuth2 Access Tokens from AWS Cognito
  54. The verification steps followed are documented here:
  55. https://docs.aws.amazon.com/cognito/latest/developerguide/amazon- ⏎
  56. cognito-user-pools-using-tokens-verifying-a-jwt.html
  57. """
  58. # allow 2 minutes leeway for verifying token expiry:
  59. LEEWAY = 120
  60. def __init__(
  61. self, settings: TokenVerifierSettings, logger: Optional[logging.Logger] = None
  62. ):
  63. self.settings = settings
  64. self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
  65. def __call__(self, authorization: Optional[str]) -> Token:
  66. # Step 0: retrieve the token from the Authorization header
  67. # See https://tools.ietf.org/html/rfc6750#section-2.1,
  68. # Bearer is case-sensitive and there is exactly 1 separator after.
  69. if authorization is None:
  70. raise Unauthorized("Missing Authorization header")
  71. jwt_str = authorization[7:] if authorization.startswith("Bearer") else None
  72. if jwt_str is None:
  73. raise Unauthorized("Authorization does not start with 'Bearer '")
  74. # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
  75. # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
  76. try:
  77. key = self.get_key(jwt_str, self.settings.jwks_timeout) # JSON Web Key
  78. except PyJWTError as e:
  79. raise Unauthorized(f"Token is invalid: {e}")
  80. # Step 2: Validate the JWT signature and standard claims
  81. try:
  82. claims = jwt.decode(
  83. jwt_str,
  84. key.key,
  85. algorithms=self.settings.algorithms,
  86. issuer=self.settings.issuer,
  87. leeway=self.LEEWAY,
  88. options={
  89. "require": ["exp", "iss", "sub", "scope", "token_use"],
  90. },
  91. )
  92. except PyJWTError as e:
  93. raise Unauthorized(f"Token is invalid: {e}")
  94. # Step 3: Verify additional claims. At this point, we have passed
  95. # verification, so unverified claims may be used safely.
  96. self.verify_token_use(claims)
  97. try:
  98. token = Token(claims=claims)
  99. except ValidationError as e:
  100. raise Unauthorized(f"Token is invalid: {e}")
  101. self.verify_scope(token.scope)
  102. # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
  103. self.authorize_user(token.user)
  104. return token
  105. def get_key(self, token: str, timeout: float = 1.0) -> jwt.PyJWK:
  106. """Return the JSON Web KEY (JWK) corresponding to kid."""
  107. # NB: pyjwt does not allow timeouts, but we can set it using the
  108. # global value
  109. old_timeout = socket.getdefaulttimeout()
  110. try:
  111. socket.setdefaulttimeout(timeout)
  112. return self.jwk_client.get_signing_key_from_jwt(token)
  113. finally:
  114. socket.setdefaulttimeout(old_timeout)
  115. def verify_token_use(self, claims: Dict[str, Any]) -> None:
  116. """Check the token_use claim."""
  117. if claims["token_use"] != "access":
  118. raise Unauthorized(
  119. f"Token has invalid token_use claim: {claims['token_use']}"
  120. )
  121. def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
  122. """Parse scopes and optionally check scope claim."""
  123. if self.settings.scope is None:
  124. return
  125. if self.settings.scope not in claims_scope:
  126. raise Unauthorized(f"Token is missing '{self.settings.scope}' scope")
  127. def authorize_user(self, user: User) -> None:
  128. if self.settings.admin_users is None:
  129. return
  130. if user.id not in self.settings.admin_users:
  131. raise PermissionDenied()