|
@@ -1,6 +1,7 @@
|
|
|
# (c) Nelen & Schuurmans
|
|
|
|
|
|
import logging
|
|
|
+import socket
|
|
|
from typing import Any
|
|
|
from typing import Dict
|
|
|
from typing import FrozenSet
|
|
@@ -37,6 +38,7 @@ class TokenVerifierSettings(BaseModel):
|
|
|
# optional additional checks:
|
|
|
scope: Optional[str] = None
|
|
|
admin_users: Optional[List[str]] = None # 'sub' whitelist
|
|
|
+ jwks_timeout: float = 1.0
|
|
|
|
|
|
|
|
|
class OAuth2SPAClientSettings(BaseModel):
|
|
@@ -89,20 +91,17 @@ class TokenVerifier(BaseTokenVerifier):
|
|
|
# See https://tools.ietf.org/html/rfc6750#section-2.1,
|
|
|
# Bearer is case-sensitive and there is exactly 1 separator after.
|
|
|
if authorization is None:
|
|
|
- logger.info("Missing Authorization header")
|
|
|
- raise Unauthorized()
|
|
|
+ raise Unauthorized("Missing Authorization header")
|
|
|
jwt_str = authorization[7:] if authorization.startswith("Bearer") else None
|
|
|
if jwt_str is None:
|
|
|
- logger.info("Authorization does not start with 'Bearer '")
|
|
|
- raise Unauthorized()
|
|
|
+ raise Unauthorized("Authorization does not start with 'Bearer '")
|
|
|
|
|
|
# Step 1: Confirm the structure of the JWT. This check is part of get_kid since
|
|
|
# jwt.get_unverified_header will raise a JWTError if the structure is wrong.
|
|
|
try:
|
|
|
- key = self.get_key(jwt_str) # JSON Web Key
|
|
|
+ key = self.get_key(jwt_str, self.settings.jwks_timeout) # JSON Web Key
|
|
|
except PyJWTError as e:
|
|
|
- logger.info("Token is invalid: %s", e)
|
|
|
- raise Unauthorized()
|
|
|
+ raise Unauthorized(f"Token is invalid: {e}")
|
|
|
# Step 2: Validate the JWT signature and standard claims
|
|
|
try:
|
|
|
claims = jwt.decode(
|
|
@@ -116,38 +115,43 @@ class TokenVerifier(BaseTokenVerifier):
|
|
|
},
|
|
|
)
|
|
|
except PyJWTError as e:
|
|
|
- logger.info("Token is invalid: %s", e)
|
|
|
- raise Unauthorized()
|
|
|
+ raise Unauthorized(f"Token is invalid: {e}")
|
|
|
# Step 3: Verify additional claims. At this point, we have passed
|
|
|
# verification, so unverified claims may be used safely.
|
|
|
self.verify_token_use(claims)
|
|
|
try:
|
|
|
token = Token(claims=claims)
|
|
|
except ValidationError as e:
|
|
|
- logger.info("Token is invalid: %s", e)
|
|
|
- raise Unauthorized()
|
|
|
+ raise Unauthorized(f"Token is invalid: {e}")
|
|
|
self.verify_scope(token.scope)
|
|
|
# Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
|
|
|
self.authorize_user(token.user)
|
|
|
return token
|
|
|
|
|
|
- def get_key(self, token) -> jwt.PyJWK:
|
|
|
+ def get_key(self, token: str, timeout: float = 1.0) -> jwt.PyJWK:
|
|
|
"""Return the JSON Web KEY (JWK) corresponding to kid."""
|
|
|
- return self.jwk_client.get_signing_key_from_jwt(token)
|
|
|
+ # NB: pyjwt does not allow timeouts, but we can set it using the
|
|
|
+ # global value
|
|
|
+ old_timeout = socket.getdefaulttimeout()
|
|
|
+ try:
|
|
|
+ socket.setdefaulttimeout(timeout)
|
|
|
+ return self.jwk_client.get_signing_key_from_jwt(token)
|
|
|
+ finally:
|
|
|
+ socket.setdefaulttimeout(old_timeout)
|
|
|
|
|
|
def verify_token_use(self, claims: Dict[str, Any]) -> None:
|
|
|
"""Check the token_use claim."""
|
|
|
if claims["token_use"] != "access":
|
|
|
- logger.info("Token has invalid token_use claim: %s", claims["token_use"])
|
|
|
- raise Unauthorized()
|
|
|
+ raise Unauthorized(
|
|
|
+ f"Token has invalid token_use claim: {claims['token_use']}"
|
|
|
+ )
|
|
|
|
|
|
def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
|
|
|
"""Parse scopes and optionally check scope claim."""
|
|
|
if self.settings.scope is None:
|
|
|
return
|
|
|
if self.settings.scope not in claims_scope:
|
|
|
- logger.info("Token is missing '%s' scope", self.settings.scope)
|
|
|
- raise Unauthorized()
|
|
|
+ raise Unauthorized(f"Token is missing '{self.settings.scope}' scope")
|
|
|
|
|
|
def authorize_user(self, user: User) -> None:
|
|
|
if self.settings.admin_users is None:
|