oauth2.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # (c) Nelen & Schuurmans
  2. from typing import Dict
  3. from typing import FrozenSet
  4. from typing import List
  5. from typing import Optional
  6. import jwt
  7. from jwt import PyJWKClient
  8. from jwt.exceptions import PyJWTError
  9. from pydantic import AnyHttpUrl
  10. from pydantic import BaseModel
  11. from clean_python import PermissionDenied
  12. from clean_python import Unauthorized
  13. from .claims import Claims
  14. from .claims import Tenant
  15. from .claims import User
  16. __all__ = ["TokenVerifier", "TokenVerifierSettings", "OAuth2SPAClientSettings"]
  17. class TokenVerifierSettings(BaseModel):
  18. issuer: str
  19. algorithms: List[str] = ["RS256"]
  20. # optional additional checks:
  21. scope: Optional[str] = None
  22. admin_users: Optional[List[str]] = None # 'sub' whitelist
  23. class OAuth2SPAClientSettings(BaseModel):
  24. client_id: str
  25. token_url: AnyHttpUrl
  26. authorization_url: AnyHttpUrl
  27. class TokenVerifier:
  28. """A class for verifying OAuth2 Access Tokens from AWS Cognito
  29. The verification steps followed are documented here:
  30. https://docs.aws.amazon.com/cognito/latest/developerguide/amazon- ⏎
  31. cognito-user-pools-using-tokens-verifying-a-jwt.html
  32. """
  33. # allow 2 minutes leeway for verifying token expiry:
  34. LEEWAY = 120
  35. def __init__(self, settings: TokenVerifierSettings):
  36. self.settings = settings
  37. self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
  38. def __call__(self, authorization: Optional[str]) -> Claims:
  39. # Step 0: retrieve the token from the Authorization header
  40. # See https://tools.ietf.org/html/rfc6750#section-2.1,
  41. # Bearer is case-sensitive and there is exactly 1 separator after.
  42. if authorization is None:
  43. raise Unauthorized()
  44. token = authorization[7:] if authorization.startswith("Bearer") else None
  45. if token is None:
  46. raise Unauthorized()
  47. # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
  48. # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
  49. try:
  50. key = self.get_key(token) # JSON Web Key
  51. except PyJWTError:
  52. # logger.info("Token is invalid: %s", e)
  53. raise Unauthorized()
  54. # Step 2: Validate the JWT signature and standard claims
  55. try:
  56. claims = jwt.decode(
  57. token,
  58. key.key,
  59. algorithms=self.settings.algorithms,
  60. issuer=self.settings.issuer,
  61. leeway=self.LEEWAY,
  62. options={
  63. "require": ["exp", "iss", "sub", "scope", "token_use"],
  64. },
  65. )
  66. except PyJWTError:
  67. # logger.info("Token is invalid: %s", e)
  68. raise Unauthorized()
  69. # Step 3: Verify additional claims. At this point, we have passed
  70. # verification, so unverified claims may be used safely.
  71. user = self.parse_user(claims)
  72. tenant = self.parse_tenant(claims)
  73. scope = self.parse_scope(claims)
  74. self.verify_token_use(claims)
  75. self.verify_scope(scope)
  76. # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
  77. self.authorize_user(user)
  78. return Claims(user=user, scope=scope, tenant=tenant)
  79. def get_key(self, token) -> jwt.PyJWK:
  80. """Return the JSON Web KEY (JWK) corresponding to kid."""
  81. return self.jwk_client.get_signing_key_from_jwt(token)
  82. def verify_token_use(self, claims: Dict) -> None:
  83. """Check the token_use claim."""
  84. if claims["token_use"] != "access":
  85. # logger.info("Token has invalid token_use claim: %s", claims["token_use"])
  86. raise Unauthorized()
  87. def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
  88. """Parse scopes and optionally check scope claim."""
  89. if self.settings.scope is None:
  90. return
  91. if self.settings.scope not in claims_scope:
  92. # logger.info("Token has invalid scope claim: %s", claims["scope"])
  93. raise Unauthorized()
  94. def authorize_user(self, user: User) -> None:
  95. if self.settings.admin_users is None:
  96. return
  97. if user.id not in self.settings.admin_users:
  98. # logger.info("User with sub %s is not authorized", claims.get("sub"))
  99. raise PermissionDenied()
  100. def parse_user(self, claims: Dict) -> User:
  101. return User(id=claims["sub"], name=claims.get("username"))
  102. def parse_scope(self, claims: Dict) -> FrozenSet[str]:
  103. return frozenset(claims["scope"].split(" "))
  104. def parse_tenant(self, claims: Dict) -> Optional[Tenant]:
  105. if claims.get("tenant"):
  106. tenant = Tenant(id=claims["tenant"], name=claims.get("tenant_name", ""))
  107. else:
  108. tenant = None
  109. return tenant