Forráskód Böngészése

[DONE] Pydantic2 compat (#3)

jpprins1 1 éve
szülő
commit
e584f1c636

+ 12 - 3
CHANGES.md

@@ -1,10 +1,16 @@
 # Changelog of clean-python
 
 
-0.1.3 (unreleased)
-------------------
+0.2.0 (unreleased)
+--------------------
+
+- Pydantic 2.x support. Drops Pydantic 1.x support, use 0.1.x for Pydantic 1.x.
+  See https://docs.pydantic.dev/latest/migration/
+
+- `BadRequest` is a subclass of `Exception` instead of `ValidationError` / `ValueError`.
 
-- Nothing changed yet.
+- `oauth2.OAuth2Settings` is split into two new objects: `TokenVerifierSettings` and
+  `OAuth2SPAClientSettings`. The associated call signature of `Service` was changed.
 
 
 0.1.2 (2023-07-31)
@@ -13,6 +19,9 @@
 - Added py.typed marker.
 
 
+
+
+
 0.1.1 (2023-07-31)
 ------------------
 

+ 1 - 1
clean_python/__init__.py

@@ -3,5 +3,5 @@
 from .base import *  # NOQA
 
 # fmt: off
-__version__ = '0.1.3.dev0'
+__version__ = '0.2.0b4.dev0'
 # fmt: on

+ 2 - 3
clean_python/base/domain/domain_service.py

@@ -1,11 +1,10 @@
 # (c) Nelen & Schuurmans
 
 from pydantic import BaseModel
+from pydantic import ConfigDict
 
 __all__ = ["DomainService"]
 
 
 class DomainService(BaseModel):
-    class Config:
-        allow_mutation = False
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

+ 14 - 11
clean_python/base/domain/exceptions.py

@@ -1,12 +1,13 @@
 # (c) Nelen & Schuurmans
 
 from typing import Any
+from typing import Dict
+from typing import List
 from typing import Optional
 from typing import Union
 
 from pydantic import create_model
 from pydantic import ValidationError
-from pydantic.error_wrappers import ErrorWrapper
 
 __all__ = [
     "AlreadyExists",
@@ -54,23 +55,25 @@ class PreconditionFailed(Exception):
 request_model = create_model("Request")
 
 
-class BadRequest(ValidationError):
+class BadRequest(Exception):
     def __init__(self, err_or_msg: Union[ValidationError, str]):
-        if isinstance(err_or_msg, ValidationError):
-            errors = err_or_msg.raw_errors
-        else:
-            errors = [ErrorWrapper(ValueError(err_or_msg), "*")]
-        super().__init__(errors, request_model)
+        self._internal_error = err_or_msg
+        super().__init__(err_or_msg)
+
+    def errors(self) -> List[Dict]:
+        if isinstance(self._internal_error, ValidationError):
+            return self._internal_error.errors()
+        return [{"error": self}]
 
     def __str__(self) -> str:
-        errors = self.errors()
-        if len(errors) == 1:
-            error = errors[0]
+        error = self._internal_error
+        if isinstance(error, ValidationError):
+            error = error.errors()[0]
             loc = "'" + ",".join([str(x) for x in error["loc"]]) + "' "
             if loc == "'*' ":
                 loc = ""
             return f"validation error: {loc}{error['msg']}"
-        return super().__str__()
+        return f"validation error: {super().__str__()}"
 
 
 class Unauthorized(Exception):

+ 1 - 2
clean_python/base/domain/pagination.py

@@ -6,7 +6,6 @@ from typing import Sequence
 from typing import TypeVar
 
 from pydantic import BaseModel
-from pydantic.generics import GenericModel
 
 __all__ = ["Page", "PageOptions"]
 
@@ -20,7 +19,7 @@ class PageOptions(BaseModel):
     ascending: bool = True
 
 
-class Page(GenericModel, Generic[T]):
+class Page(BaseModel, Generic[T]):
     total: int
     items: Sequence[T]
     limit: Optional[int] = None

+ 3 - 3
clean_python/base/domain/repository.py

@@ -68,19 +68,19 @@ class Repository(Generic[T]):
     async def add(self, item: Union[T, Json]) -> T:
         if isinstance(item, dict):
             item = self.entity.create(**item)
-        created = await self.gateway.add(item.dict())
+        created = await self.gateway.add(item.model_dump())
         return self.entity(**created)
 
     async def update(self, id: int, values: Json) -> T:
         if not values:
             return await self.get(id)
         updated = await self.gateway.update_transactional(
-            id, lambda x: self.entity(**x).update(**values).dict()
+            id, lambda x: self.entity(**x).update(**values).model_dump()
         )
         return self.entity(**updated)
 
     async def upsert(self, item: T) -> T:
-        values = item.dict()
+        values = item.model_dump()
         upserted = await self.gateway.upsert(values)
         return self.entity(**upserted)
 

+ 4 - 4
clean_python/base/domain/value_object.py

@@ -5,6 +5,7 @@ from typing import Type
 from typing import TypeVar
 
 from pydantic import BaseModel
+from pydantic import ConfigDict
 from pydantic import ValidationError
 
 from .exceptions import BadRequest
@@ -16,12 +17,11 @@ T = TypeVar("T", bound="ValueObject")
 
 
 class ValueObject(BaseModel):
-    class Config:
-        allow_mutation = False
+    model_config = ConfigDict(frozen=True)
 
     def run_validation(self: T) -> T:
         try:
-            return self.__class__(**self.dict())
+            return self.__class__(**self.model_dump())
         except ValidationError as e:
             raise BadRequest(e)
 
@@ -34,7 +34,7 @@ class ValueObject(BaseModel):
 
     def update(self: T, **values) -> T:
         try:
-            return self.__class__(**{**self.dict(), **values})
+            return self.__class__(**{**self.model_dump(), **values})
         except ValidationError as e:
             raise BadRequest(e)
 

+ 1 - 1
clean_python/base/infrastructure/internal_gateway.py

@@ -47,7 +47,7 @@ class InternalGateway(Generic[E, T]):
 
     async def add(self, item: T) -> T:
         try:
-            created = await self.manage.create(item.dict())
+            created = await self.manage.create(item.model_dump())
         except BadRequest as e:
             raise ValueError(e)
         return self._map(created)

+ 1 - 1
clean_python/celery/celery_rmq_broker.py

@@ -27,7 +27,7 @@ class CeleryHeaders(ValueObject):
     origin: Optional[str] = None
 
     def json_dict(self):
-        return json.loads(self.json())
+        return json.loads(self.model_dump_json())
 
 
 class CeleryRmqBroker(Gateway):

+ 11 - 0
clean_python/fastapi/context.py

@@ -4,12 +4,15 @@ from contextvars import ContextVar
 
 from fastapi import Request
 
+from ..oauth2 import Claims
+
 __all__ = ["ctx", "RequestMiddleware"]
 
 
 class Context:
     def __init__(self):
         self._request_value: ContextVar[Request] = ContextVar("request_value")
+        self._claims_value: ContextVar[Claims] = ContextVar("claims_value")
 
     @property
     def request(self) -> Request:
@@ -19,6 +22,14 @@ class Context:
     def request(self, value: Request) -> None:
         self._request_value.set(value)
 
+    @property
+    def claims(self) -> Claims:
+        return self._claims_value.get()
+
+    @claims.setter
+    def claims(self, value: Claims) -> None:
+        self._claims_value.set(value)
+
 
 ctx = Context()
 

+ 1 - 0
clean_python/fastapi/error_responses.py

@@ -73,6 +73,7 @@ async def unauthorized_handler(request: Request, exc: Unauthorized):
     return JSONResponse(
         status_code=status.HTTP_401_UNAUTHORIZED,
         content={"message": "Unauthorized"},
+        headers={"WWW-Authenticate": "Bearer"},
     )
 
 

+ 5 - 5
clean_python/fastapi/request_query.py

@@ -3,7 +3,7 @@
 from typing import List
 
 from fastapi import Query
-from pydantic import validator
+from pydantic import field_validator
 
 from clean_python import Filter
 from clean_python import PageOptions
@@ -19,11 +19,11 @@ class RequestQuery(ValueObject):
         default="id", enum=["id", "-id"], description="Field to order by"
     )
 
-    @validator("order_by")
-    def validate_order_by_enum(cls, v):
+    @field_validator("order_by")
+    def validate_order_by_enum(cls, v, _):
         # the 'enum' parameter doesn't actually do anthing in validation
         # See: https://github.com/tiangolo/fastapi/issues/2910
-        allowed = cls.__fields__["order_by"].field_info.extra["enum"]
+        allowed = cls.model_fields["order_by"].json_schema_extra["enum"]
         if v not in allowed:
             raise ValueError(f"'order_by' must be one of {allowed}")
         return v
@@ -41,7 +41,7 @@ class RequestQuery(ValueObject):
 
     def filters(self) -> List[Filter]:
         result = []
-        for name in self.__fields__:
+        for name in self.model_fields:
             if name in {"limit", "offset", "order_by"}:
                 continue
             value = getattr(self, name)

+ 49 - 41
clean_python/fastapi/service.py

@@ -3,7 +3,6 @@
 import logging
 from typing import Any
 from typing import Callable
-from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Set
@@ -12,11 +11,8 @@ from asgiref.sync import sync_to_async
 from fastapi import Depends
 from fastapi import FastAPI
 from fastapi import Request
-from fastapi.exceptions import HTTPException
 from fastapi.exceptions import RequestValidationError
 from fastapi.security import OAuth2AuthorizationCodeBearer
-from starlette.status import HTTP_401_UNAUTHORIZED
-from starlette.status import HTTP_403_FORBIDDEN
 from starlette.types import ASGIApp
 
 from clean_python import Conflict
@@ -24,9 +20,11 @@ from clean_python import DoesNotExist
 from clean_python import Gateway
 from clean_python import PermissionDenied
 from clean_python import Unauthorized
-from clean_python.oauth2 import OAuth2AccessTokenVerifier
-from clean_python.oauth2 import OAuth2Settings
+from clean_python.oauth2 import OAuth2SPAClientSettings
+from clean_python.oauth2 import TokenVerifier
+from clean_python.oauth2 import TokenVerifierSettings
 
+from .context import ctx
 from .context import RequestMiddleware
 from .error_responses import BadRequest
 from .error_responses import conflict_handler
@@ -47,7 +45,7 @@ logger = logging.getLogger(__name__)
 __all__ = ["Service"]
 
 
-class OAuth2Dependable(OAuth2AuthorizationCodeBearer):
+class OAuth2WithClientDependable(OAuth2AuthorizationCodeBearer):
     """A fastapi 'dependable' configuring OAuth2.
 
     This does two things:
@@ -55,45 +53,53 @@ class OAuth2Dependable(OAuth2AuthorizationCodeBearer):
     - (through FastAPI magic) add the scheme to the OpenAPI spec
     """
 
-    def __init__(self, scope, settings: OAuth2Settings):
-        self.verifier = sync_to_async(
-            OAuth2AccessTokenVerifier(
-                scope,
-                issuer=settings.issuer,
-                resource_server_id=settings.resource_server_id,
-                algorithms=settings.algorithms,
-                admin_users=settings.admin_users,
-            ),
-            thread_sensitive=False,
-        )
+    def __init__(
+        self, settings: TokenVerifierSettings, client: OAuth2SPAClientSettings
+    ):
+        self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
         super().__init__(
-            authorizationUrl=settings.authorization_url,
-            tokenUrl=settings.token_url,
-            scopes={
-                f"{settings.resource_server_id}*:readwrite": "Full read/write access"
-            },
+            authorizationUrl=str(client.authorization_url),
+            tokenUrl=str(client.token_url),
         )
 
     async def __call__(self, request: Request) -> None:
-        token = await super().__call__(request)
-        try:
-            await self.verifier(token)
-        except Unauthorized:
-            raise HTTPException(status_code=HTTP_401_UNAUTHORIZED)
-        except PermissionDenied:
-            raise HTTPException(status_code=HTTP_403_FORBIDDEN)
+        ctx.claims = await self.verifier(request.headers.get("Authorization"))
+
+
+class OAuth2WithoutClientDependable:
+    """A fastapi 'dependable' configuring OAuth2.
+
+    This does one thing:
+    - Verify the token in each request
+    """
+
+    def __init__(self, settings: TokenVerifierSettings):
+        self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
+
+    async def __call__(self, request: Request) -> None:
+        ctx.claims = await self.verifier(request.headers.get("Authorization"))
 
 
-def fastapi_oauth_kwargs(auth: Optional[OAuth2Settings]) -> Dict:
+def get_auth_kwargs(
+    auth: Optional[TokenVerifierSettings],
+    auth_client: Optional[OAuth2SPAClientSettings],
+) -> None:
     if auth is None:
         return {}
-    return {
-        "dependencies": [Depends(OAuth2Dependable(scope="*:readwrite", settings=auth))],
-        "swagger_ui_init_oauth": {
-            "clientId": auth.client_id,
-            "usePkceWithAuthorizationCodeGrant": True,
-        },
-    }
+    if auth_client is None:
+        return {
+            "dependencies": [Depends(OAuth2WithoutClientDependable(settings=auth))],
+        }
+    else:
+        return {
+            "dependencies": [
+                Depends(OAuth2WithClientDependable(settings=auth, client=auth_client))
+            ],
+            "swagger_ui_init_oauth": {
+                "clientId": auth_client.client_id,
+                "usePkceWithAuthorizationCodeGrant": True,
+            },
+        }
 
 
 async def health_check():
@@ -143,7 +149,8 @@ class Service:
         app = FastAPI(
             version=version.prefix,
             tags=sorted(
-                [x.get_openapi_tag().dict() for x in resources], key=lambda x: x["name"]
+                [x.get_openapi_tag().model_dump() for x in resources],
+                key=lambda x: x["name"],
             ),
             **kwargs,
         )
@@ -171,7 +178,8 @@ class Service:
         title: str,
         description: str,
         hostname: str,
-        auth: Optional[OAuth2Settings] = None,
+        auth: Optional[TokenVerifierSettings] = None,
+        auth_client: Optional[OAuth2SPAClientSettings] = None,
         on_startup: Optional[List[Callable[[], Any]]] = None,
         access_logger_gateway: Optional[Gateway] = None,
     ) -> ASGIApp:
@@ -185,7 +193,7 @@ class Service:
         kwargs = {
             "title": title,
             "description": description,
-            **fastapi_oauth_kwargs(auth),
+            **get_auth_kwargs(auth, auth_client),
         }
         versioned_apps = {
             v: self._create_versioned_app(v, **kwargs) for v in self.versions

+ 1 - 0
clean_python/oauth2/__init__.py

@@ -1 +1,2 @@
+from .claims import *  # NOQA
 from .oauth2 import *  # NOQA

+ 16 - 0
clean_python/oauth2/claims.py

@@ -0,0 +1,16 @@
+from typing import FrozenSet
+from typing import Optional
+
+from clean_python import ValueObject
+
+__all__ = ["Claims"]
+
+
+class Tenant(ValueObject):
+    id: int
+    name: str
+
+
+class Claims(ValueObject):
+    scope: FrozenSet[str]
+    tenant: Optional[Tenant]

+ 60 - 43
clean_python/oauth2/oauth2.py

@@ -1,7 +1,9 @@
 # (c) Nelen & Schuurmans
 
 from typing import Dict
+from typing import FrozenSet
 from typing import List
+from typing import Optional
 
 import jwt
 from jwt import PyJWKClient
@@ -9,23 +11,30 @@ from jwt.exceptions import PyJWTError
 from pydantic import AnyHttpUrl
 from pydantic import BaseModel
 
-from clean_python.base.domain.exceptions import PermissionDenied
-from clean_python.base.domain.exceptions import Unauthorized
+from clean_python import PermissionDenied
+from clean_python import Unauthorized
 
-__all__ = ["OAuth2Settings", "OAuth2AccessTokenVerifier"]
+from .claims import Claims
+from .claims import Tenant
 
+__all__ = ["TokenVerifier", "TokenVerifierSettings", "OAuth2SPAClientSettings"]
 
-class OAuth2Settings(BaseModel):
-    client_id: str
+
+class TokenVerifierSettings(BaseModel):
     issuer: str
-    resource_server_id: str
+    algorithms: List[str] = ["RS256"]
+    # optional additional checks:
+    scope: Optional[str] = None
+    admin_users: Optional[List[str]] = None  # 'sub' whitelist
+
+
+class OAuth2SPAClientSettings(BaseModel):
+    client_id: str
     token_url: AnyHttpUrl
     authorization_url: AnyHttpUrl
-    algorithms: List[str] = ["RS256"]
-    admin_users: List[str]
 
 
-class OAuth2AccessTokenVerifier:
+class TokenVerifier:
     """A class for verifying OAuth2 Access Tokens from AWS Cognito
 
     The verification steps followed are documented here:
@@ -37,22 +46,20 @@ class OAuth2AccessTokenVerifier:
     # allow 2 minutes leeway for verifying token expiry:
     LEEWAY = 120
 
-    def __init__(
-        self,
-        scope: str,
-        issuer: str,
-        resource_server_id: str,
-        algorithms: List[str],
-        admin_users: List[str],
-    ):
-        self.scope = scope
-        self.issuer = issuer
-        self.algorithms = algorithms
-        self.resource_server_id = resource_server_id
-        self.admin_users = admin_users
-        self.jwk_client = PyJWKClient(f"{issuer}/.well-known/jwks.json")
-
-    def __call__(self, token: str) -> Dict:
+    def __init__(self, settings: TokenVerifierSettings):
+        self.settings = settings
+        self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
+
+    def __call__(self, authorization: Optional[str]) -> Claims:
+        # Step 0: retrieve the token from the Authorization header
+        # See https://tools.ietf.org/html/rfc6750#section-2.1,
+        # Bearer is case-sensitive and there is exactly 1 separator after.
+        if authorization is None:
+            raise Unauthorized()
+        token = authorization[7:] if authorization.startswith("Bearer") else None
+        if token is None:
+            raise Unauthorized()
+
         # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
         # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
         try:
@@ -65,8 +72,8 @@ class OAuth2AccessTokenVerifier:
             claims = jwt.decode(
                 token,
                 key.key,
-                algorithms=self.algorithms,
-                issuer=self.issuer,
+                algorithms=self.settings.algorithms,
+                issuer=self.settings.issuer,
                 leeway=self.LEEWAY,
                 options={
                     "require": ["exp", "iss", "sub", "scope", "token_use"],
@@ -77,36 +84,46 @@ class OAuth2AccessTokenVerifier:
             raise Unauthorized()
         # Step 3: Verify additional claims. At this point, we have passed
         # verification, so unverified claims may be used safely.
+        scope = self.parse_scope(claims)
+        tenant = self.parse_tenant(claims)
         self.verify_token_use(claims)
-        self.verify_scope(claims)
-        # Step 4: Authorization: we currently work with a hardcoded
-        # list of users ('sub' claims)
-        self.authorize(claims)
-        return claims
+        self.verify_scope(scope)
+        # Step 4: Authorization: verify 'sub' claim against 'admin_users'
+        self.verify_sub(claims)
+        return Claims(scope=scope, tenant=tenant)
 
     def get_key(self, token) -> jwt.PyJWK:
         """Return the JSON Web KEY (JWK) corresponding to kid."""
         return self.jwk_client.get_signing_key_from_jwt(token)
 
-    def verify_token_use(self, claims):
+    def verify_token_use(self, claims: Dict) -> None:
         """Check the token_use claim."""
         if claims["token_use"] != "access":
             # logger.info("Token has invalid token_use claim: %s", claims["token_use"])
             raise Unauthorized()
 
-    def verify_scope(self, claims):
-        """Check scope claim.
-
-        Cognito includes the resource server id inside the scope, like this:
-
-           raster.lizard.net/*.readwrite
-        """
-        if f"{self.resource_server_id}{self.scope}" not in claims["scope"].split(" "):
+    def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
+        """Parse scopes and optionally check scope claim."""
+        if self.settings.scope is None:
+            return
+        if self.settings.scope not in claims_scope:
             # logger.info("Token has invalid scope claim: %s", claims["scope"])
             raise Unauthorized()
 
-    def authorize(self, claims):
+    def verify_sub(self, claims: Dict) -> None:
         """The subject (sub) claim should be in a hard-coded whitelist."""
-        if claims.get("sub") not in self.admin_users:
+        if self.settings.admin_users is None:
+            return
+        if claims.get("sub") not in self.settings.admin_users:
             # logger.info("User with sub %s is not authorized", claims.get("sub"))
             raise PermissionDenied()
+
+    def parse_scope(self, claims: Dict) -> FrozenSet[str]:
+        return frozenset(claims["scope"].split(" "))
+
+    def parse_tenant(self, claims: Dict) -> Optional[Tenant]:
+        if claims.get("tenant"):
+            tenant = Tenant(id=claims["tenant"], name=claims.get("tenant_name", ""))
+        else:
+            tenant = None
+        return tenant

+ 2 - 1
pyproject.toml

@@ -10,7 +10,7 @@ license = {text = "MIT"}
 classifiers = ["Programming Language :: Python"]
 keywords = []
 requires-python = ">=3.7"
-dependencies = ["pydantic==1.*", "inject==4.*", "asgiref", "blinker"]
+dependencies = ["pydantic==2.*", "inject==4.*", "asgiref", "blinker"]
 dynamic = ["version"]
 
 [project.optional-dependencies]
@@ -19,6 +19,7 @@ test = [
     "pytest-cov",
     "pytest-asyncio",
     "debugpy",
+    "httpx",
 ]
 dramatiq = ["dramatiq"]
 fastapi = ["fastapi"]

+ 19 - 69
tests/test_oauth2.py → tests/oauth2/conftest.py

@@ -1,29 +1,16 @@
-# (c) Nelen & Schuurmans
-
 import time
 from unittest import mock
 
 import jwt
 import pytest
 
-from clean_python import PermissionDenied
-from clean_python import Unauthorized
-from clean_python.oauth2 import OAuth2AccessTokenVerifier
-
-
-@pytest.fixture
-def settings():
-    return {
-        "issuer": "https://cognito-idp.region.amazonaws.com/region_abc123",
-        "resource_server_id": "localhost/",
-        "algorithms": ["RS256"],
-        "admin_users": ["foo"],
-    }
+from clean_python.oauth2 import TokenVerifier
+from clean_python.oauth2 import TokenVerifierSettings
 
 
 @pytest.fixture
 def private_key():
-    # this key was generated especially for this test suite; is has no other applications
+    # this key was generated especially for this test suite; it has no other applications
     return {
         "p": "_PgJBxrGEy8I5KvY_nDRT9loaBPqHHn0AUiTa92zBrAX0qA8ZhV66pUkX2JehU3efduel4FOK2xx-W31p7kCLoaGsMtfKAPYC33KptCH9YXkeMQHq1jWfcRgAVXpdXc7M4pQxO8Dh2BU8qhtAzhpbP4tUPoLIGcTUGd-1ieDkqE",  # NOQA
         "kty": "RSA",
@@ -47,21 +34,18 @@ def public_key(private_key):
 
 
 @pytest.fixture
-def patched_verifier(public_key, settings):
-    verifier = OAuth2AccessTokenVerifier(scope="all", **settings)
-    with mock.patch.object(verifier, "jwk_client") as jwk_client:
-        jwk_client.get_signing_key_from_jwt.return_value = jwt.PyJWK.from_dict(
-            public_key
-        )
-        yield verifier
+def jwk_patched(public_key):
+    with mock.patch.object(TokenVerifier, "get_key") as f:
+        f.return_value = jwt.PyJWK.from_dict(public_key)
+        yield
 
 
 @pytest.fixture
-def token_generator(private_key, settings):
+def token_generator(private_key):
     default_claims = {
         "sub": "foo",
-        "iss": settings["issuer"],
-        "scope": f"{settings['resource_server_id']}all",
+        "iss": "https://some/auth/server",
+        "scope": "user",
         "token_use": "access",
         "exp": int(time.time()) + 3600,
         "iat": int(time.time()) - 3600,
@@ -81,46 +65,12 @@ def token_generator(private_key, settings):
     return generate_token
 
 
-def test_verifier_ok(patched_verifier, token_generator):
-    token = token_generator()
-    verified_claims = patched_verifier(token)
-    assert verified_claims == jwt.decode(token, options={"verify_signature": False})
-
-    patched_verifier.jwk_client.get_signing_key_from_jwt.assert_called_once_with(token)
-
-
-def test_verifier_exp_leeway(patched_verifier, token_generator):
-    token = token_generator(exp=int(time.time()) - 60)
-    patched_verifier(token)
-
-
-def test_verifier_multiple_scopes(patched_verifier, token_generator, settings):
-    token = token_generator(scope=f"scope1 {settings['resource_server_id']}all scope3")
-    patched_verifier(token)
-
-
-@pytest.mark.parametrize(
-    "claim_overrides",
-    [
-        {"iss": "https://authserver"},
-        {"iss": None},
-        {"scope": "nothing"},
-        {"scope": None},
-        {"exp": int(time.time()) - 3600},
-        {"exp": None},
-        {"nbf": int(time.time()) + 3600},
-        {"token_use": "id"},
-        {"token_use": None},
-        {"sub": None},
-    ],
-)
-def test_verifier_bad(patched_verifier, token_generator, claim_overrides):
-    token = token_generator(**claim_overrides)
-    with pytest.raises(Unauthorized):
-        patched_verifier(token)
-
-
-def test_verifier_authorize(patched_verifier, token_generator):
-    token = token_generator(sub="bar")
-    with pytest.raises(PermissionDenied):
-        patched_verifier(token)
+@pytest.fixture
+def settings():
+    # settings match the defaults in the token_generator fixture
+    return TokenVerifierSettings(
+        issuer="https://some/auth/server",
+        scope="user",
+        algorithms=["RS256"],
+        admin_users=["foo"],
+    )

+ 87 - 0
tests/oauth2/test_oauth2.py

@@ -0,0 +1,87 @@
+# (c) Nelen & Schuurmans
+
+import time
+
+import pytest
+
+from clean_python import PermissionDenied
+from clean_python import Unauthorized
+from clean_python.oauth2 import TokenVerifier
+
+
+@pytest.fixture
+def patched_verifier(jwk_patched, settings):
+    return TokenVerifier(settings)
+
+
+def test_verifier_ok(patched_verifier, token_generator):
+    token = token_generator()
+    verified_claims = patched_verifier("Bearer " + token)
+    assert verified_claims.tenant is None
+    assert verified_claims.scope == {"user"}
+
+    patched_verifier.get_key.assert_called_once_with(token)
+
+
+def test_verifier_ok_with_tenant(patched_verifier, token_generator):
+    token = token_generator(tenant="15")
+    verified_claims = patched_verifier("Bearer " + token)
+    assert verified_claims.tenant.id == 15
+    assert verified_claims.tenant.name == ""
+
+
+def test_verifier_ok_with_tenant_and_name(patched_verifier, token_generator):
+    token = token_generator(tenant=15, tenant_name="foo")
+    verified_claims = patched_verifier("Bearer " + token)
+    assert verified_claims.tenant.id == 15
+    assert verified_claims.tenant.name == "foo"
+
+
+def test_verifier_exp_leeway(patched_verifier, token_generator):
+    token = token_generator(exp=int(time.time()) - 60)
+    patched_verifier("Bearer " + token)
+
+
+def test_verifier_multiple_scopes(patched_verifier, token_generator, settings):
+    token = token_generator(scope=f"scope1 {settings.scope} scope3")
+    patched_verifier("Bearer " + token)
+
+
+@pytest.mark.parametrize(
+    "claim_overrides",
+    [
+        {"iss": "https://authserver"},
+        {"iss": None},
+        {"scope": "nothing"},
+        {"scope": None},
+        {"exp": int(time.time()) - 3600},
+        {"exp": None},
+        {"nbf": int(time.time()) + 3600},
+        {"token_use": "id"},
+        {"token_use": None},
+        {"sub": None},
+    ],
+)
+def test_verifier_bad(patched_verifier, token_generator, claim_overrides):
+    token = token_generator(**claim_overrides)
+    with pytest.raises(Unauthorized):
+        patched_verifier("Bearer " + token)
+
+
+def test_verifier_authorize(patched_verifier, token_generator):
+    token = token_generator(sub="bar")
+    with pytest.raises(PermissionDenied):
+        patched_verifier("Bearer " + token)
+
+
+@pytest.mark.parametrize("prefix", ["", "foo ", "key ", "bearer ", "Bearer  "])
+def test_verifier_bad_header_prefix(patched_verifier, token_generator, prefix):
+    token = token_generator()
+    with pytest.raises(Unauthorized):
+        patched_verifier(prefix + token)
+
+
+@pytest.mark.parametrize("header", ["", None, " "])
+def test_verifier_no_header(patched_verifier, header):
+    with pytest.raises(Unauthorized):
+        patched_verifier(header)

+ 89 - 0
tests/oauth2/test_service_auth.py

@@ -0,0 +1,89 @@
+from http import HTTPStatus
+
+import pytest
+from fastapi.testclient import TestClient
+
+from clean_python import InMemoryGateway
+from clean_python.fastapi import get
+from clean_python.fastapi import Resource
+from clean_python.fastapi import Service
+from clean_python.fastapi import v
+from clean_python.oauth2 import OAuth2SPAClientSettings
+from clean_python.oauth2 import TokenVerifierSettings
+
+
+class FooResource(Resource, version=v(1), name="testing"):
+    @get("/foo")
+    def testing(self):
+        return "ok"
+
+
+@pytest.fixture
+def app(settings: TokenVerifierSettings):
+    return Service(FooResource()).create_app(
+        title="test",
+        description="testing",
+        hostname="testserver",
+        auth=settings,
+        access_logger_gateway=InMemoryGateway([]),
+    )
+
+
+@pytest.fixture
+def client(app):
+    return TestClient(app)
+
+
+@pytest.mark.usefixtures("jwk_patched")
+def test_no_header(app, client: TestClient):
+    response = client.get(app.url_path_for("v1/testing"))
+
+    assert response.status_code == HTTPStatus.UNAUTHORIZED
+
+
+@pytest.mark.usefixtures("jwk_patched")
+def test_ok(app, client: TestClient, token_generator):
+    response = client.get(
+        app.url_path_for("v1/testing"),
+        headers={"Authorization": "Bearer " + token_generator()},
+    )
+
+    assert response.status_code == HTTPStatus.OK
+
+
+@pytest.fixture
+def app2(settings: TokenVerifierSettings):
+    return Service(FooResource()).create_app(
+        title="test",
+        description="testing",
+        hostname="testserver",
+        auth=settings,
+        auth_client=OAuth2SPAClientSettings(
+            client_id="123",
+            token_url="https://server/token",
+            authorization_url="https://server/token",
+        ),
+        access_logger_gateway=InMemoryGateway([]),
+    )
+
+
+@pytest.fixture
+def client2(app):
+    return TestClient(app)
+
+
+@pytest.mark.usefixtures("jwk_patched")
+def test_no_header2(app2, client2: TestClient):
+    response = client2.get(app2.url_path_for("v1/testing"))
+
+    assert response.status_code == HTTPStatus.UNAUTHORIZED
+
+
+@pytest.mark.usefixtures("jwk_patched")
+def test_ok2(app2, client2: TestClient, token_generator):
+    response = client2.get(
+        app2.url_path_for("v1/testing"),
+        headers={"Authorization": "Bearer " + token_generator()},
+    )
+
+    assert response.status_code == HTTPStatus.OK

+ 1 - 1
tests/test_exceptions.py

@@ -30,4 +30,4 @@ def test_bad_request_from_validation_error():
     except ValidationError as e:
         err = BadRequest(e)
 
-    assert str(err) == "validation error: 'title' field required"
+    assert str(err) == "validation error: 'title' Field required"

+ 24 - 21
tests/test_repository.py

@@ -1,3 +1,4 @@
+from typing import List
 from unittest import mock
 
 import pytest
@@ -30,8 +31,8 @@ class UserRepository(Repository[User]):
 
 
 @pytest.fixture
-def user_repository(users):
-    return UserRepository(gateway=InMemoryGateway(data=[x.dict() for x in users]))
+def user_repository(users: List[User]):
+    return UserRepository(gateway=InMemoryGateway(data=[x.model_dump() for x in users]))
 
 
 @pytest.fixture
@@ -61,69 +62,71 @@ async def test_all(filter_m, user_repository, page_options):
     filter_m.assert_awaited_once_with([], params=page_options)
 
 
-async def test_add(user_repository):
+async def test_add(user_repository: UserRepository):
     actual = await user_repository.add(User.create(name="d"))
     assert actual.name == "d"
-    assert user_repository.gateway.data[4] == actual.dict()
+    assert user_repository.gateway.data[4] == actual.model_dump()
 
 
-async def test_add_json(user_repository):
+async def test_add_json(user_repository: UserRepository):
     actual = await user_repository.add({"name": "d"})
     assert actual.name == "d"
-    assert user_repository.gateway.data[4] == actual.dict()
+    assert user_repository.gateway.data[4] == actual.model_dump()
 
 
-async def test_add_json_validates(user_repository):
+async def test_add_json_validates(user_repository: UserRepository):
     with pytest.raises(BadRequest):
         await user_repository.add({"id": "d"})
 
 
-async def test_update(user_repository):
+async def test_update(user_repository: UserRepository):
     actual = await user_repository.update(id=2, values={"name": "d"})
     assert actual.name == "d"
-    assert user_repository.gateway.data[2] == actual.dict()
+    assert user_repository.gateway.data[2] == actual.model_dump()
 
 
-async def test_update_does_not_exist(user_repository):
+async def test_update_does_not_exist(user_repository: UserRepository):
     with pytest.raises(DoesNotExist):
         await user_repository.update(id=4, values={"name": "d"})
 
 
-async def test_update_validates(user_repository):
+async def test_update_validates(user_repository: UserRepository):
     with pytest.raises(BadRequest):
         await user_repository.update(id=2, values={"id": 6})
 
 
-async def test_remove(user_repository):
+async def test_remove(user_repository: UserRepository):
     assert await user_repository.remove(2)
     assert 2 not in user_repository.gateway.data
 
 
-async def test_remove_does_not_exist(user_repository):
+async def test_remove_does_not_exist(user_repository: UserRepository):
     assert not await user_repository.remove(4)
 
 
-async def test_upsert_updates(user_repository):
+async def test_upsert_updates(user_repository: UserRepository):
     actual = await user_repository.upsert(User.create(id=2, name="d"))
     assert actual.name == "d"
-    assert user_repository.gateway.data[2] == actual.dict()
+    assert user_repository.gateway.data[2] == actual.model_dump()
 
 
-async def test_upsert_adds(user_repository):
+async def test_upsert_adds(user_repository: UserRepository):
     actual = await user_repository.upsert(User.create(id=4, name="d"))
     assert actual.name == "d"
-    assert user_repository.gateway.data[4] == actual.dict()
+    assert user_repository.gateway.data[4] == actual.model_dump()
 
 
 @mock.patch.object(InMemoryGateway, "count")
-async def test_filter(count_m, user_repository, users):
+async def test_filter(count_m, user_repository: UserRepository, users):
     actual = await user_repository.filter([Filter(field="name", values=["b"])])
     assert actual == Page(total=1, items=[users[1]], limit=None, offest=None)
     assert not count_m.called
 
 
 @mock.patch.object(InMemoryGateway, "count")
-async def test_filter_with_pagination(count_m, user_repository, users, page_options):
+async def test_filter_with_pagination(
+    count_m, user_repository: UserRepository, users, page_options
+):
     actual = await user_repository.filter(
         [Filter(field="name", values=["b"])], page_options
     )
@@ -142,7 +145,7 @@ async def test_filter_with_pagination(count_m, user_repository, users, page_opti
 )
 @mock.patch.object(InMemoryGateway, "count")
 async def test_filter_with_pagination_calls_count(
-    count_m, user_repository, users, page_options
+    count_m, user_repository: UserRepository, users, page_options
 ):
     count_m.return_value = 123
     actual = await user_repository.filter([], page_options)
@@ -156,7 +159,7 @@ async def test_filter_with_pagination_calls_count(
 
 
 @mock.patch.object(Repository, "filter")
-async def test_by(filter_m, user_repository, page_options):
+async def test_by(filter_m, user_repository: UserRepository, page_options):
     filter_m.return_value = Page(total=0, items=[])
     assert await user_repository.by("name", "b", page_options) is filter_m.return_value
 

+ 2 - 1
tests/test_root_entity.py

@@ -5,6 +5,7 @@ from unittest import mock
 import pytest
 
 from clean_python import RootEntity
+from clean_python.base.domain.exceptions import BadRequest
 
 SOME_DATETIME = datetime(2023, 1, 1, tzinfo=timezone.utc)
 
@@ -64,7 +65,7 @@ def test_update_including_id(user):
 
 @pytest.mark.parametrize("new_id", [None, 42, "foo"])
 def test_update_with_wrong_id(user, new_id):
-    with pytest.raises(ValueError):
+    with pytest.raises(BadRequest):
         user.update(id=new_id, name="piet")
 
 

+ 4 - 4
tests/test_value_object.py

@@ -1,6 +1,6 @@
 import pytest
+from pydantic import field_validator
 from pydantic import ValidationError
-from pydantic import validator
 
 from clean_python import BadRequest
 from clean_python import ValueObject
@@ -9,8 +9,8 @@ from clean_python import ValueObject
 class Color(ValueObject):
     name: str
 
-    @validator("name")
-    def name_not_empty(cls, v):
+    @field_validator("name")
+    def name_not_empty(cls, v, _):
         assert v != ""
         return v
 
@@ -49,7 +49,7 @@ def test_run_validation(color):
 
 
 def test_run_validation_err():
-    color = Color.construct(name="")
+    color = Color.model_construct(name="")
 
     with pytest.raises(BadRequest):
         color.run_validation()