Browse Source

Use a timeout when fetching JWKS (#46)

Casper van der Wel 1 year ago
parent
commit
c55f88d8cb

+ 1 - 1
CHANGES.md

@@ -4,7 +4,7 @@
 0.9.4 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Use a timeout for fetching jwks in TokenVerifier.
 
 
 0.9.3 (2023-12-04)

+ 5 - 0
clean_python/fastapi/error_responses.py

@@ -1,5 +1,6 @@
 # (c) Nelen & Schuurmans
 
+import logging
 from typing import List
 from typing import Optional
 from typing import Union
@@ -16,6 +17,8 @@ from clean_python import PermissionDenied
 from clean_python import Unauthorized
 from clean_python import ValueObject
 
+logger = logging.getLogger(__name__)
+
 __all__ = [
     "ValidationErrorResponse",
     "DefaultErrorResponse",
@@ -72,6 +75,8 @@ async def validation_error_handler(request: Request, exc: BadRequest) -> JSONRes
 
 
 async def unauthorized_handler(request: Request, exc: Unauthorized) -> JSONResponse:
+    if exc.args:
+        logger.info(f"unauthorized: {exc}")
     return JSONResponse(
         status_code=status.HTTP_401_UNAUTHORIZED,
         content={"message": "Unauthorized", "detail": None},

+ 21 - 17
clean_python/oauth2/token_verifier.py

@@ -1,6 +1,7 @@
 # (c) Nelen & Schuurmans
 
 import logging
+import socket
 from typing import Any
 from typing import Dict
 from typing import FrozenSet
@@ -37,6 +38,7 @@ class TokenVerifierSettings(BaseModel):
     # optional additional checks:
     scope: Optional[str] = None
     admin_users: Optional[List[str]] = None  # 'sub' whitelist
+    jwks_timeout: float = 1.0
 
 
 class OAuth2SPAClientSettings(BaseModel):
@@ -89,20 +91,17 @@ class TokenVerifier(BaseTokenVerifier):
         # See https://tools.ietf.org/html/rfc6750#section-2.1,
         # Bearer is case-sensitive and there is exactly 1 separator after.
         if authorization is None:
-            logger.info("Missing Authorization header")
-            raise Unauthorized()
+            raise Unauthorized("Missing Authorization header")
         jwt_str = authorization[7:] if authorization.startswith("Bearer") else None
         if jwt_str is None:
-            logger.info("Authorization does not start with 'Bearer '")
-            raise Unauthorized()
+            raise Unauthorized("Authorization does not start with 'Bearer '")
 
         # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
         # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
         try:
-            key = self.get_key(jwt_str)  # JSON Web Key
+            key = self.get_key(jwt_str, self.settings.jwks_timeout)  # JSON Web Key
         except PyJWTError as e:
-            logger.info("Token is invalid: %s", e)
-            raise Unauthorized()
+            raise Unauthorized(f"Token is invalid: {e}")
         # Step 2: Validate the JWT signature and standard claims
         try:
             claims = jwt.decode(
@@ -116,38 +115,43 @@ class TokenVerifier(BaseTokenVerifier):
                 },
             )
         except PyJWTError as e:
-            logger.info("Token is invalid: %s", e)
-            raise Unauthorized()
+            raise Unauthorized(f"Token is invalid: {e}")
         # Step 3: Verify additional claims. At this point, we have passed
         # verification, so unverified claims may be used safely.
         self.verify_token_use(claims)
         try:
             token = Token(claims=claims)
         except ValidationError as e:
-            logger.info("Token is invalid: %s", e)
-            raise Unauthorized()
+            raise Unauthorized(f"Token is invalid: {e}")
         self.verify_scope(token.scope)
         # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
         self.authorize_user(token.user)
         return token
 
-    def get_key(self, token) -> jwt.PyJWK:
+    def get_key(self, token: str, timeout: float = 1.0) -> jwt.PyJWK:
         """Return the JSON Web KEY (JWK) corresponding to kid."""
-        return self.jwk_client.get_signing_key_from_jwt(token)
+        # NB: pyjwt does not allow timeouts, but we can set it using the
+        # global value
+        old_timeout = socket.getdefaulttimeout()
+        try:
+            socket.setdefaulttimeout(timeout)
+            return self.jwk_client.get_signing_key_from_jwt(token)
+        finally:
+            socket.setdefaulttimeout(old_timeout)
 
     def verify_token_use(self, claims: Dict[str, Any]) -> None:
         """Check the token_use claim."""
         if claims["token_use"] != "access":
-            logger.info("Token has invalid token_use claim: %s", claims["token_use"])
-            raise Unauthorized()
+            raise Unauthorized(
+                f"Token has invalid token_use claim: {claims['token_use']}"
+            )
 
     def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
         """Parse scopes and optionally check scope claim."""
         if self.settings.scope is None:
             return
         if self.settings.scope not in claims_scope:
-            logger.info("Token is missing '%s' scope", self.settings.scope)
-            raise Unauthorized()
+            raise Unauthorized(f"Token is missing '{self.settings.scope}' scope")
 
     def authorize_user(self, user: User) -> None:
         if self.settings.admin_users is None:

+ 11 - 2
tests/fastapi/test_error_responses.py

@@ -1,4 +1,5 @@
 import json
+import logging
 from http import HTTPStatus
 
 from pydantic import BaseModel
@@ -44,15 +45,19 @@ async def test_conflict_no_msg():
     assert json.loads(actual.body) == {"message": "Conflict", "detail": None}
 
 
-async def test_unauthorized():
+async def test_unauthorized(caplog):
     actual = await unauthorized_handler(None, Unauthorized())
 
     assert actual.status_code == HTTPStatus.UNAUTHORIZED
     assert json.loads(actual.body) == {"message": "Unauthorized", "detail": None}
     assert actual.headers["WWW-Authenticate"] == "Bearer"
 
+    assert caplog.record_tuples == []
+
+
+async def test_unauthorized_wit_msg(caplog):
+    caplog.set_level(logging.INFO)
 
-async def test_unauthorized_with_msg():
     # message should be ignored
     actual = await unauthorized_handler(None, Unauthorized("foo"))
 
@@ -60,6 +65,10 @@ async def test_unauthorized_with_msg():
     assert json.loads(actual.body) == {"message": "Unauthorized", "detail": None}
     assert actual.headers["WWW-Authenticate"] == "Bearer"
 
+    assert caplog.record_tuples == [
+        ("clean_python.fastapi.error_responses", logging.INFO, "unauthorized: foo")
+    ]
+
 
 async def test_permission_denied():
     actual = await permission_denied_handler(None, PermissionDenied())

+ 8 - 4
tests/oauth2/conftest.py

@@ -1,10 +1,12 @@
+import json
 import time
+import urllib.request
+from io import BytesIO
 from unittest import mock
 
 import jwt
 import pytest
 
-from clean_python.oauth2 import TokenVerifier
 from clean_python.oauth2 import TokenVerifierSettings
 
 
@@ -35,9 +37,11 @@ def public_key(private_key):
 
 @pytest.fixture
 def jwk_patched(public_key):
-    with mock.patch.object(TokenVerifier, "get_key") as f:
-        f.return_value = jwt.PyJWK.from_dict(public_key)
-        yield
+    with mock.patch.object(urllib.request, "urlopen") as urlopen:
+        urlopen.return_value.__enter__.return_value = BytesIO(
+            json.dumps({"keys": [public_key]}).encode()
+        )
+        yield urlopen
 
 
 @pytest.fixture

+ 37 - 3
tests/oauth2/test_verifier.py

@@ -1,7 +1,13 @@
 # (c) Nelen & Schuurmans
 
+import json
+import socket
 import time
+import urllib.request
+from io import BytesIO
+from unittest import mock
 
+import jwt
 import pytest
 
 from clean_python import PermissionDenied
@@ -11,11 +17,11 @@ from clean_python.oauth2 import TokenVerifier
 
 
 @pytest.fixture
-def patched_verifier(jwk_patched, settings):
+def patched_verifier(settings, jwk_patched):
     return TokenVerifier(settings)
 
 
-def test_verifier_ok(patched_verifier, token_generator):
+def test_verifier_ok(patched_verifier, token_generator, jwk_patched):
     token = token_generator()
     verified_token = patched_verifier("Bearer " + token)
 
@@ -24,7 +30,9 @@ def test_verifier_ok(patched_verifier, token_generator):
     assert verified_token.tenant is None
     assert verified_token.scope == {"user"}
 
-    patched_verifier.get_key.assert_called_once_with(token)
+    jwk_patched.assert_called_once_with(
+        "https://some/auth/server/.well-known/jwks.json"
+    )
 
 
 def test_verifier_exp_leeway(patched_verifier, token_generator):
@@ -76,3 +84,29 @@ def test_verifier_bad_header_prefix(patched_verifier, token_generator, prefix):
 def test_verifier_no_header(patched_verifier, header):
     with pytest.raises(Unauthorized):
         patched_verifier(header)
+
+
+@mock.patch.object(urllib.request, "urlopen")
+def test_get_key_timeout(urlopen, patched_verifier, token_generator, public_key):
+    def side_effect():
+        assert socket.getdefaulttimeout() == 0.1
+        return BytesIO(json.dumps({"keys": [public_key]}).encode())
+
+    urlopen.return_value.__enter__.side_effect = side_effect
+
+    assert socket.getdefaulttimeout() is None
+    key = patched_verifier.get_key(token_generator(), timeout=0.1)
+    assert socket.getdefaulttimeout() is None
+
+    assert isinstance(key, jwt.PyJWK)
+    assert key.key_id == public_key["kid"]
+
+
+@mock.patch.object(urllib.request, "urlopen")
+def test_get_key_invalid_kid(urlopen, settings, token_generator, public_key):
+    urlopen.return_value.__enter__.return_value = BytesIO(
+        json.dumps({"keys": []}).encode()
+    )
+
+    with pytest.raises(jwt.exceptions.PyJWTError):
+        TokenVerifier(settings).get_key(token_generator())