|  | @@ -1,6 +1,7 @@
 | 
	
		
			
				|  |  |  # (c) Nelen & Schuurmans
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  | +import socket
 | 
	
		
			
				|  |  |  from typing import Any
 | 
	
		
			
				|  |  |  from typing import Dict
 | 
	
		
			
				|  |  |  from typing import FrozenSet
 | 
	
	
		
			
				|  | @@ -37,6 +38,7 @@ class TokenVerifierSettings(BaseModel):
 | 
	
		
			
				|  |  |      # optional additional checks:
 | 
	
		
			
				|  |  |      scope: Optional[str] = None
 | 
	
		
			
				|  |  |      admin_users: Optional[List[str]] = None  # 'sub' whitelist
 | 
	
		
			
				|  |  | +    jwks_timeout: float = 1.0
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class OAuth2SPAClientSettings(BaseModel):
 | 
	
	
		
			
				|  | @@ -89,20 +91,17 @@ class TokenVerifier(BaseTokenVerifier):
 | 
	
		
			
				|  |  |          # See https://tools.ietf.org/html/rfc6750#section-2.1,
 | 
	
		
			
				|  |  |          # Bearer is case-sensitive and there is exactly 1 separator after.
 | 
	
		
			
				|  |  |          if authorization is None:
 | 
	
		
			
				|  |  | -            logger.info("Missing Authorization header")
 | 
	
		
			
				|  |  | -            raise Unauthorized()
 | 
	
		
			
				|  |  | +            raise Unauthorized("Missing Authorization header")
 | 
	
		
			
				|  |  |          jwt_str = authorization[7:] if authorization.startswith("Bearer") else None
 | 
	
		
			
				|  |  |          if jwt_str is None:
 | 
	
		
			
				|  |  | -            logger.info("Authorization does not start with 'Bearer '")
 | 
	
		
			
				|  |  | -            raise Unauthorized()
 | 
	
		
			
				|  |  | +            raise Unauthorized("Authorization does not start with 'Bearer '")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
 | 
	
		
			
				|  |  |          # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            key = self.get_key(jwt_str)  # JSON Web Key
 | 
	
		
			
				|  |  | +            key = self.get_key(jwt_str, self.settings.jwks_timeout)  # JSON Web Key
 | 
	
		
			
				|  |  |          except PyJWTError as e:
 | 
	
		
			
				|  |  | -            logger.info("Token is invalid: %s", e)
 | 
	
		
			
				|  |  | -            raise Unauthorized()
 | 
	
		
			
				|  |  | +            raise Unauthorized(f"Token is invalid: {e}")
 | 
	
		
			
				|  |  |          # Step 2: Validate the JWT signature and standard claims
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  |              claims = jwt.decode(
 | 
	
	
		
			
				|  | @@ -116,38 +115,43 @@ class TokenVerifier(BaseTokenVerifier):
 | 
	
		
			
				|  |  |                  },
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  |          except PyJWTError as e:
 | 
	
		
			
				|  |  | -            logger.info("Token is invalid: %s", e)
 | 
	
		
			
				|  |  | -            raise Unauthorized()
 | 
	
		
			
				|  |  | +            raise Unauthorized(f"Token is invalid: {e}")
 | 
	
		
			
				|  |  |          # Step 3: Verify additional claims. At this point, we have passed
 | 
	
		
			
				|  |  |          # verification, so unverified claims may be used safely.
 | 
	
		
			
				|  |  |          self.verify_token_use(claims)
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  |              token = Token(claims=claims)
 | 
	
		
			
				|  |  |          except ValidationError as e:
 | 
	
		
			
				|  |  | -            logger.info("Token is invalid: %s", e)
 | 
	
		
			
				|  |  | -            raise Unauthorized()
 | 
	
		
			
				|  |  | +            raise Unauthorized(f"Token is invalid: {e}")
 | 
	
		
			
				|  |  |          self.verify_scope(token.scope)
 | 
	
		
			
				|  |  |          # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
 | 
	
		
			
				|  |  |          self.authorize_user(token.user)
 | 
	
		
			
				|  |  |          return token
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def get_key(self, token) -> jwt.PyJWK:
 | 
	
		
			
				|  |  | +    def get_key(self, token: str, timeout: float = 1.0) -> jwt.PyJWK:
 | 
	
		
			
				|  |  |          """Return the JSON Web KEY (JWK) corresponding to kid."""
 | 
	
		
			
				|  |  | -        return self.jwk_client.get_signing_key_from_jwt(token)
 | 
	
		
			
				|  |  | +        # NB: pyjwt does not allow timeouts, but we can set it using the
 | 
	
		
			
				|  |  | +        # global value
 | 
	
		
			
				|  |  | +        old_timeout = socket.getdefaulttimeout()
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            socket.setdefaulttimeout(timeout)
 | 
	
		
			
				|  |  | +            return self.jwk_client.get_signing_key_from_jwt(token)
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            socket.setdefaulttimeout(old_timeout)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def verify_token_use(self, claims: Dict[str, Any]) -> None:
 | 
	
		
			
				|  |  |          """Check the token_use claim."""
 | 
	
		
			
				|  |  |          if claims["token_use"] != "access":
 | 
	
		
			
				|  |  | -            logger.info("Token has invalid token_use claim: %s", claims["token_use"])
 | 
	
		
			
				|  |  | -            raise Unauthorized()
 | 
	
		
			
				|  |  | +            raise Unauthorized(
 | 
	
		
			
				|  |  | +                f"Token has invalid token_use claim: {claims['token_use']}"
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
 | 
	
		
			
				|  |  |          """Parse scopes and optionally check scope claim."""
 | 
	
		
			
				|  |  |          if self.settings.scope is None:
 | 
	
		
			
				|  |  |              return
 | 
	
		
			
				|  |  |          if self.settings.scope not in claims_scope:
 | 
	
		
			
				|  |  | -            logger.info("Token is missing '%s' scope", self.settings.scope)
 | 
	
		
			
				|  |  | -            raise Unauthorized()
 | 
	
		
			
				|  |  | +            raise Unauthorized(f"Token is missing '{self.settings.scope}' scope")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def authorize_user(self, user: User) -> None:
 | 
	
		
			
				|  |  |          if self.settings.admin_users is None:
 |