oauth2.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. import logging
  4. from typing import Dict
  5. from typing import List
  6. import jwt
  7. from jwt import PyJWKClient
  8. from jwt.exceptions import PyJWTError
  9. from pydantic import AnyHttpUrl
  10. from pydantic import BaseModel
  11. from clean_python.base.domain.exceptions import PermissionDenied
  12. from clean_python.base.domain.exceptions import Unauthorized
# Public API of this module.
__all__ = ["OAuth2Settings", "OAuth2AccessTokenVerifier"]
# Module logger; the verifier below logs rejected tokens at INFO level.
logger = logging.getLogger(__name__)
class OAuth2Settings(BaseModel):
    """Settings for OAuth2 access-token handling, validated by pydantic.

    NOTE(review): the consumer of this model is not visible in this module.
    The field comments assume the values are passed to the same-named
    parameters of OAuth2AccessTokenVerifier; confirm against the caller.
    """

    # NOTE(review): not a parameter of OAuth2AccessTokenVerifier in this module.
    client_id: str
    # Presumably the expected "iss" claim of access tokens.
    issuer: str
    # Presumably the prefix of the required scope inside the "scope" claim.
    resource_server_id: str
    # OAuth2 endpoint URLs (pydantic AnyHttpUrl validation).
    token_url: AnyHttpUrl
    authorization_url: AnyHttpUrl
    # Accepted JWT signature algorithms; defaults to RS256.
    algorithms: List[str] = ["RS256"]
    # Whitelist of "sub" claims that are authorized.
    admin_users: List[str]
  23. class OAuth2AccessTokenVerifier:
  24. """A class for verifying OAuth2 Access Tokens from AWS Cognito
  25. The verification steps followed are documented here:
  26. https://docs.aws.amazon.com/cognito/latest/developerguide/amazon- ⏎
  27. cognito-user-pools-using-tokens-verifying-a-jwt.html
  28. """
  29. # allow 2 minutes leeway for verifying token expiry:
  30. LEEWAY = 120
  31. def __init__(
  32. self,
  33. scope: str,
  34. issuer: str,
  35. resource_server_id: str,
  36. algorithms: List[str],
  37. admin_users: List[str],
  38. ):
  39. self.scope = scope
  40. self.issuer = issuer
  41. self.algorithms = algorithms
  42. self.resource_server_id = resource_server_id
  43. self.admin_users = admin_users
  44. self.jwk_client = PyJWKClient(f"{issuer}/.well-known/jwks.json")
  45. def __call__(self, token: str) -> Dict:
  46. # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
  47. # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
  48. try:
  49. key = self.get_key(token) # JSON Web Key
  50. except PyJWTError as e:
  51. logger.info("Token is invalid: %s", e)
  52. raise Unauthorized()
  53. # Step 2: Validate the JWT signature and standard claims
  54. try:
  55. claims = jwt.decode(
  56. token,
  57. key.key,
  58. algorithms=self.algorithms,
  59. issuer=self.issuer,
  60. leeway=self.LEEWAY,
  61. options={
  62. "require": ["exp", "iss", "sub", "scope", "token_use"],
  63. },
  64. )
  65. except PyJWTError as e:
  66. logger.info("Token is invalid: %s", e)
  67. raise Unauthorized()
  68. # Step 3: Verify additional claims. At this point, we have passed
  69. # verification, so unverified claims may be used safely.
  70. self.verify_token_use(claims)
  71. self.verify_scope(claims)
  72. # Step 4: Authorization: we currently work with a hardcoded
  73. # list of users ('sub' claims)
  74. self.authorize(claims)
  75. return claims
  76. def get_key(self, token) -> jwt.PyJWK:
  77. """Return the JSON Web KEY (JWK) corresponding to kid."""
  78. return self.jwk_client.get_signing_key_from_jwt(token)
  79. def verify_token_use(self, claims):
  80. """Check the token_use claim."""
  81. if claims["token_use"] != "access":
  82. logger.info("Token has invalid token_use claim: %s", claims["token_use"])
  83. raise Unauthorized()
  84. def verify_scope(self, claims):
  85. """Check scope claim.
  86. Cognito includes the resource server id inside the scope, like this:
  87. raster.lizard.net/*.readwrite
  88. """
  89. if f"{self.resource_server_id}{self.scope}" not in claims["scope"].split(" "):
  90. logger.info("Token has invalid scope claim: %s", claims["scope"])
  91. raise Unauthorized()
  92. def authorize(self, claims):
  93. """The subject (sub) claim should be in a hard-coded whitelist."""
  94. if claims.get("sub") not in self.admin_users:
  95. logger.info("User with sub %s is not authorized", claims.get("sub"))
  96. raise PermissionDenied()