# (c) Nelen & Schuurmans
import logging
import socket
from typing import Any
from typing import Dict
from typing import FrozenSet
from typing import List
from typing import Optional

import jwt
from jwt import PyJWKClient
from jwt.exceptions import PyJWTError
from pydantic import AnyHttpUrl
from pydantic import BaseModel
from pydantic import ValidationError

from clean_python import PermissionDenied
from clean_python import Unauthorized
from clean_python import User

from .token import Token

__all__ = [
    "BaseTokenVerifier",
    "TokenVerifier",
    "NoAuthTokenVerifier",
    "TokenVerifierSettings",
    "OAuth2SPAClientSettings",
]

logger = logging.getLogger(__name__)

# Prefix of an RFC 6750 bearer Authorization header: the case-sensitive word
# "Bearer" followed by exactly one space.
_BEARER_PREFIX = "Bearer "


# Settings for TokenVerifier:
# - issuer: expected 'iss' claim; signing keys are fetched from
#   "{issuer}/.well-known/jwks.json".
# - algorithms: signature algorithms accepted when decoding the token.
# - scope: if set, this scope must be present in the token's scopes.
# - admin_users: if set, only these user ids ('sub' claim) are authorized.
# - jwks_timeout: socket timeout in seconds used while obtaining the signing key.
class TokenVerifierSettings(BaseModel):
    issuer: str
    algorithms: List[str] = ["RS256"]
    # optional additional checks:
    scope: Optional[str] = None
    admin_users: Optional[List[str]] = None  # 'sub' whitelist
    jwks_timeout: float = 1.0


# Settings describing an OAuth2 single-page-application (SPA) client.
# Not referenced by the verifiers in this module.
class OAuth2SPAClientSettings(BaseModel):
    client_id: str
    token_url: AnyHttpUrl
    authorization_url: AnyHttpUrl


class BaseTokenVerifier:
    """Interface for callables that turn an Authorization header into a Token."""

    def force(self, token: Token) -> None:
        """Override the token to return; unsupported verifiers raise NotImplementedError."""
        raise NotImplementedError()

    def __call__(self, authorization: Optional[str]) -> Token:
        """Verify an Authorization header value and return the resulting Token."""
        raise NotImplementedError()


class NoAuthTokenVerifier(BaseTokenVerifier):
    """Verifier that performs no checks and always returns a stored token.

    The initial token has the claims sub="DEV", username="dev" and
    scope="superuser"; it can be replaced with :meth:`force`.
    """

    def __init__(self) -> None:
        self.token = Token(
            claims={"sub": "DEV", "username": "dev", "scope": "superuser"}
        )

    def force(self, token: Token) -> None:
        """Replace the token returned by subsequent calls."""
        self.token = token

    def __call__(self, authorization: Optional[str]) -> Token:
        """Return the stored token; ``authorization`` is ignored."""
        return self.token


class TokenVerifier(BaseTokenVerifier):
    """A class for verifying OAuth2 Access Tokens from AWS Cognito

    The verification steps followed are documented here:

    https://docs.aws.amazon.com/cognito/latest/developerguide/amazon- ⏎
    cognito-user-pools-using-tokens-verifying-a-jwt.html

    Args:
        settings: issuer, accepted algorithms and optional scope/admin checks.
        logger: accepted for backwards compatibility; currently unused.
    """

    # allow 2 minutes leeway for verifying token expiry:
    LEEWAY = 120

    def __init__(
        self, settings: TokenVerifierSettings, logger: Optional[logging.Logger] = None
    ):
        self.settings = settings
        # Signing keys are published by the issuer at this well-known URL.
        self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")

    def __call__(self, authorization: Optional[str]) -> Token:
        """Verify a bearer Authorization header and return its Token.

        Raises:
            Unauthorized: the header is missing or not a bearer header, the
                token cannot be decoded/verified (signature, 'exp', 'iss',
                required claims), 'token_use' is not "access", the claims do
                not form a valid Token, or the configured scope is missing.
            PermissionDenied: ``admin_users`` is configured and the token's
                user id is not in it.
        """
        # Step 0: retrieve the token from the Authorization header
        # See https://tools.ietf.org/html/rfc6750#section-2.1,
        # Bearer is case-sensitive and there is exactly 1 separator after.
        # The full prefix (including the space) is checked so that e.g.
        # "BearerX<token>" is rejected instead of having a character dropped.
        if authorization is None:
            raise Unauthorized("Missing Authorization header")
        if not authorization.startswith(_BEARER_PREFIX):
            raise Unauthorized("Authorization does not start with 'Bearer '")
        jwt_str = authorization[len(_BEARER_PREFIX) :]

        # Step 1: obtain the JSON Web Key matching the token's key id. Any
        # PyJWT error here (e.g. a malformed token) is reported as Unauthorized.
        try:
            key = self.get_key(jwt_str, self.settings.jwks_timeout)
        except PyJWTError as e:
            raise Unauthorized(f"Token is invalid: {e}") from e

        # Step 2: Validate the JWT signature and standard claims
        try:
            claims = jwt.decode(
                jwt_str,
                key.key,
                algorithms=self.settings.algorithms,
                issuer=self.settings.issuer,
                leeway=self.LEEWAY,
                options={
                    "require": ["exp", "iss", "sub", "scope", "token_use"],
                },
            )
        except PyJWTError as e:
            raise Unauthorized(f"Token is invalid: {e}") from e

        # Step 3: Verify additional claims. The signature has been verified at
        # this point, so the claims can be trusted.
        self.verify_token_use(claims)
        try:
            token = Token(claims=claims)
        except ValidationError as e:
            raise Unauthorized(f"Token is invalid: {e}") from e
        self.verify_scope(token.scope)

        # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
        self.authorize_user(token.user)
        return token

    def get_key(self, token: str, timeout: float = 1.0) -> jwt.PyJWK:
        """Return the JSON Web Key (JWK) corresponding to the token's kid.

        Args:
            token: the encoded JWT.
            timeout: socket timeout in seconds applied while the key is
                obtained (which may involve fetching the JWKS).
        """
        # NB: pyjwt does not allow timeouts, but we can set it using the
        # global value. This default is process-wide, so sockets created by
        # other threads during this call get the same timeout.
        # NOTE(review): newer PyJWT releases accept a `timeout` argument on
        # PyJWKClient; consider switching once the minimum version allows.
        old_timeout = socket.getdefaulttimeout()
        try:
            socket.setdefaulttimeout(timeout)
            return self.jwk_client.get_signing_key_from_jwt(token)
        finally:
            socket.setdefaulttimeout(old_timeout)

    def verify_token_use(self, claims: Dict[str, Any]) -> None:
        """Raise Unauthorized unless the 'token_use' claim equals "access"."""
        if claims["token_use"] != "access":
            raise Unauthorized(
                f"Token has invalid token_use claim: {claims['token_use']}"
            )

    def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
        """Raise Unauthorized if the configured scope is absent from ``claims_scope``.

        No check is done when ``settings.scope`` is None.
        """
        if self.settings.scope is None:
            return
        if self.settings.scope not in claims_scope:
            raise Unauthorized(f"Token is missing '{self.settings.scope}' scope")

    def authorize_user(self, user: User) -> None:
        """Raise PermissionDenied if ``user.id`` is not in ``settings.admin_users``.

        No check is done when ``settings.admin_users`` is None.
        """
        if self.settings.admin_users is None:
            return
        if user.id not in self.settings.admin_users:
            raise PermissionDenied()