oauth2.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. import logging
  4. from typing import Dict, List
  5. import jwt
  6. from jwt import PyJWKClient
  7. from jwt.exceptions import PyJWTError
  8. from pydantic import AnyHttpUrl, BaseModel
  9. from clean_python.base.domain.exceptions import PermissionDenied, Unauthorized
  10. __all__ = ["OAuth2Settings", "OAuth2AccessTokenVerifier"]
  11. logger = logging.getLogger(__name__)
class OAuth2Settings(BaseModel):
    """Settings for OAuth2 access-token authentication.

    NOTE(review): this model only declares the configuration values; how
    each one is consumed (e.g. which are passed on to
    OAuth2AccessTokenVerifier) is decided by callers outside this class —
    confirm against them.
    """

    client_id: str  # OAuth2 client id (not referenced elsewhere in this file)
    issuer: str  # presumably the expected token issuer ("iss" claim) — confirm
    resource_server_id: str  # presumably the prefix of the required scope — confirm
    token_url: AnyHttpUrl  # OAuth2 token endpoint URL
    authorization_url: AnyHttpUrl  # OAuth2 authorization endpoint URL
    algorithms: List[str] = ["RS256"]  # accepted JWT signing algorithms
    admin_users: List[str]  # presumably the whitelisted "sub" claims — confirm
  20. class OAuth2AccessTokenVerifier:
  21. """A class for verifying OAuth2 Access Tokens from AWS Cognito
  22. The verification steps followed are documented here:
  23. https://docs.aws.amazon.com/cognito/latest/developerguide/amazon- ⏎
  24. cognito-user-pools-using-tokens-verifying-a-jwt.html
  25. """
  26. # allow 2 minutes leeway for verifying token expiry:
  27. LEEWAY = 120
  28. def __init__(
  29. self,
  30. scope: str,
  31. issuer: str,
  32. resource_server_id: str,
  33. algorithms: List[str],
  34. admin_users: List[str],
  35. ):
  36. self.scope = scope
  37. self.issuer = issuer
  38. self.algorithms = algorithms
  39. self.resource_server_id = resource_server_id
  40. self.admin_users = admin_users
  41. self.jwk_client = PyJWKClient(f"{issuer}/.well-known/jwks.json")
  42. def __call__(self, token: str) -> Dict:
  43. # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
  44. # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
  45. try:
  46. key = self.get_key(token) # JSON Web Key
  47. except PyJWTError as e:
  48. logger.info("Token is invalid: %s", e)
  49. raise Unauthorized()
  50. # Step 2: Validate the JWT signature and standard claims
  51. try:
  52. claims = jwt.decode(
  53. token,
  54. key.key,
  55. algorithms=self.algorithms,
  56. issuer=self.issuer,
  57. leeway=self.LEEWAY,
  58. options={
  59. "require": ["exp", "iss", "sub", "scope", "token_use"],
  60. },
  61. )
  62. except PyJWTError as e:
  63. logger.info("Token is invalid: %s", e)
  64. raise Unauthorized()
  65. # Step 3: Verify additional claims. At this point, we have passed
  66. # verification, so unverified claims may be used safely.
  67. self.verify_token_use(claims)
  68. self.verify_scope(claims)
  69. # Step 4: Authorization: we currently work with a hardcoded
  70. # list of users ('sub' claims)
  71. self.authorize(claims)
  72. return claims
  73. def get_key(self, token) -> jwt.PyJWK:
  74. """Return the JSON Web KEY (JWK) corresponding to kid."""
  75. return self.jwk_client.get_signing_key_from_jwt(token)
  76. def verify_token_use(self, claims):
  77. """Check the token_use claim."""
  78. if claims["token_use"] != "access":
  79. logger.info("Token has invalid token_use claim: %s", claims["token_use"])
  80. raise Unauthorized()
  81. def verify_scope(self, claims):
  82. """Check scope claim.
  83. Cognito includes the resource server id inside the scope, like this:
  84. raster.lizard.net/*.readwrite
  85. """
  86. if f"{self.resource_server_id}{self.scope}" not in claims["scope"].split(" "):
  87. logger.info("Token has invalid scope claim: %s", claims["scope"])
  88. raise Unauthorized()
  89. def authorize(self, claims):
  90. """The subject (sub) claim should be in a hard-coded whitelist."""
  91. if claims.get("sub") not in self.admin_users:
  92. logger.info("User with sub %s is not authorized", claims.get("sub"))
  93. raise PermissionDenied()