Bladeren bron

Add optional scope check to resource methods (#6)

Casper van der Wel 1 jaar geleden
bovenliggende
commit
be95f47eb8

+ 7 - 1
CHANGES.md

@@ -4,7 +4,13 @@
 0.2.3 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Add `scope` kwarg to http_method decorators (get, post, etc.)
+
+- Moved the `Context` (`ctx`) to `clean_python.base` and changed its attributes to
+  `path`, `user` and `tenant`.
+
+- The `SQLGateway` can now be constructed with `multitenant=True` which makes it
+  automatically filter the `tenant` column with the current `ctx.tenant`.
 
 
 0.2.2 (2023-08-03)

+ 1 - 0
clean_python/base/domain/__init__.py

@@ -1,3 +1,4 @@
+from .context import *  # NOQA
 from .domain_event import *  # NOQA
 from .domain_service import *  # NOQA
 from .exceptions import *  # NOQA

+ 79 - 0
clean_python/base/domain/context.py

@@ -0,0 +1,79 @@
+# (c) Nelen & Schuurmans
+
+import os
+from contextvars import ContextVar
+from typing import FrozenSet
+from typing import Optional
+
+from pydantic import AnyUrl
+from pydantic import FileUrl
+
+from .value_object import ValueObject
+
+__all__ = ["ctx", "User", "Tenant", "Scope"]
+
+
+class User(ValueObject):
+    id: str
+    name: str
+
+
+Scope = FrozenSet[str]
+
+
+class Tenant(ValueObject):
+    id: int
+    name: str
+
+
+class Context:
+    """Provide global access to some contextual properties.
+
+    The implementation makes use of python's contextvars, which automatically integrates
+    with asyncio tasks (so that each task runs in its own context). This makes sure that
+    every request-response cycle is isolated.
+    """
+
+    def __init__(self):
+        self._path_value: ContextVar[AnyUrl] = ContextVar(
+            "path_value",
+            default=FileUrl.build(scheme="file", host="/", path=os.getcwd()),
+        )
+        self._user_value: ContextVar[User] = ContextVar(
+            "user_value", default=User(id="ANONYMOUS", name="anonymous")
+        )
+        self._tenant_value: ContextVar[Optional[Tenant]] = ContextVar(
+            "tenant_value", default=None
+        )
+
+    def reset(self):
+        self._path_value.reset()
+        self._user_value.reset()
+        self._tenant_value.reset()
+
+    @property
+    def path(self) -> AnyUrl:
+        return self._path_value.get()
+
+    @path.setter
+    def path(self, value: AnyUrl) -> None:
+        self._path_value.set(value)
+
+    @property
+    def user(self) -> User:
+        return self._user_value.get()
+
+    @user.setter
+    def user(self, value: User) -> None:
+        self._user_value.set(value)
+
+    @property
+    def tenant(self) -> Optional[Tenant]:
+        return self._tenant_value.get()
+
+    @tenant.setter
+    def tenant(self, value: Optional[Tenant]) -> None:
+        self._tenant_value.set(value)
+
+
+ctx = Context()

+ 1 - 1
clean_python/fastapi/__init__.py

@@ -1,6 +1,6 @@
-from .context import *  # NOQA
 from .error_responses import *  # NOQA
 from .fastapi_access_logger import *  # NOQA
 from .request_query import *  # NOQA
 from .resource import *  # NOQA
+from .security import *  # NOQA
 from .service import *  # NOQA

+ 0 - 53
clean_python/fastapi/context.py

@@ -1,53 +0,0 @@
-# (c) Nelen & Schuurmans
-
-from contextvars import ContextVar
-
-from fastapi import Request
-
-from ..oauth2 import Claims
-
-__all__ = ["ctx", "RequestMiddleware"]
-
-
-class Context:
-    def __init__(self):
-        self._request_value: ContextVar[Request] = ContextVar("request_value")
-        self._claims_value: ContextVar[Claims] = ContextVar("claims_value")
-
-    @property
-    def request(self) -> Request:
-        return self._request_value.get()
-
-    @request.setter
-    def request(self, value: Request) -> None:
-        self._request_value.set(value)
-
-    @property
-    def claims(self) -> Claims:
-        return self._claims_value.get()
-
-    @claims.setter
-    def claims(self, value: Claims) -> None:
-        self._claims_value.set(value)
-
-
-ctx = Context()
-
-
-class RequestMiddleware:
-    """Save the current request in a context variable.
-
-    We were experiencing database connections piling up until PostgreSQL's
-    max_connections was hit, which has to do with BaseHTTPMiddleware not
-    interacting properly with context variables. For more details, see:
-    https://github.com/tiangolo/fastapi/issues/4719. Writing this
-    middleware as generic ASGI middleware fixes the problem.
-    """
-
-    def __init__(self, app):
-        self.app = app
-
-    async def __call__(self, scope, receive, send):
-        if scope["type"] == "http":
-            ctx.request = Request(scope, receive)
-        await self.app(scope, receive, send)

+ 1 - 1
clean_python/fastapi/error_responses.py

@@ -80,5 +80,5 @@ async def unauthorized_handler(request: Request, exc: Unauthorized):
 async def permission_denied_handler(request: Request, exc: PermissionDenied):
     return JSONResponse(
         status_code=status.HTTP_403_FORBIDDEN,
-        content={"message": "Permission denied"},
+        content={"message": "Permission denied", "detail": str(exc)},
     )

+ 10 - 3
clean_python/fastapi/resource.py

@@ -10,10 +10,13 @@ from typing import Optional
 from typing import Sequence
 from typing import Type
 
+from fastapi import Depends
 from fastapi.routing import APIRouter
 
 from clean_python import ValueObject
 
+from .security import RequiresScope
+
 __all__ = [
     "Resource",
     "get",
@@ -71,12 +74,12 @@ class APIVersion(ValueObject):
         return APIVersion(version=self.version, stability=self.stability.decrease())
 
 
-def http_method(path: str, **route_options):
+def http_method(path: str, scope: Optional[str] = None, **route_options):
     def wrapper(unbound_method: Callable[..., Any]):
         setattr(
             unbound_method,
             "http_method",
-            (path, route_options),
+            (path, scope, route_options),
         )
         return unbound_method
 
@@ -160,7 +163,7 @@ class Resource:
         router = APIRouter()
         operation_ids = set()
         for endpoint in self._endpoints():
-            path, route_options = endpoint.http_method
+            path, scope, route_options = endpoint.http_method
             operation_id = endpoint.__name__
             if operation_id in operation_ids:
                 raise RuntimeError(
@@ -170,6 +173,10 @@ class Resource:
             # The 'name' is used for reverse lookups (request.path_for): include the
             # version prefix so that we can uniquely refer to an operation.
             name = version.prefix + "/" + endpoint.__name__
+            # 'scope' is implemented using FastAPI's dependency injection system
+            if scope is not None:
+                route_options.setdefault("dependencies", [])
+                route_options["dependencies"].append(Depends(RequiresScope(scope)))
             router.add_api_route(
                 path,
                 endpoint,

+ 82 - 0
clean_python/fastapi/security.py

@@ -0,0 +1,82 @@
+from typing import Optional
+
+from fastapi import Depends
+from fastapi import Request
+from fastapi.security import HTTPBearer
+from fastapi.security import OAuth2AuthorizationCodeBearer
+
+from clean_python import PermissionDenied
+from clean_python.oauth2 import BaseTokenVerifier
+from clean_python.oauth2 import NoAuthTokenVerifier
+from clean_python.oauth2 import OAuth2SPAClientSettings
+from clean_python.oauth2 import Token
+from clean_python.oauth2 import TokenVerifier
+from clean_python.oauth2 import TokenVerifierSettings
+
+__all__ = ["get_token", "RequiresScope"]
+
+verifier: Optional[BaseTokenVerifier] = None
+
+
+def clear_verifier() -> None:
+    global verifier
+
+    verifier = None
+
+
+def set_verifier(settings: Optional[TokenVerifierSettings]) -> None:
+    global verifier
+
+    if settings is None:
+        verifier = NoAuthTokenVerifier()
+    else:
+        verifier = TokenVerifier(settings=settings)
+
+
+def get_token(request: Request) -> Token:
+    """A fastapi 'dependable' yielding the validated token"""
+    global verifier
+
+    assert verifier is not None
+    return verifier(request.headers.get("Authorization"))
+
+
+class RequiresScope:
+    def __init__(self, scope: str):
+        assert scope.replace(" ", "") == scope, "spaces are not allowed in a scope"
+        self.scope = scope
+
+    async def __call__(self, token: Token = Depends(get_token)) -> None:
+        if self.scope not in token.scope:
+            raise PermissionDenied(f"this operation requires '{self.scope}' scope")
+
+
+class OAuth2SPAClientSchema(OAuth2AuthorizationCodeBearer):
+    """A fastapi 'dependable' configuring the openapi schema for the
+    OAuth2 Authorization Code Flow with PKCE extension.
+
+    This includes the JWT Bearer token configuration.
+    """
+
+    def __init__(self, client: OAuth2SPAClientSettings):
+        super().__init__(
+            scheme_name="OAuth2 Authorization Code Flow with PKCE",
+            authorizationUrl=str(client.authorization_url),
+            tokenUrl=str(client.token_url),
+        )
+
+    async def __call__(self) -> None:
+        pass
+
+
+class JWTBearerTokenSchema(HTTPBearer):
+    """A fastapi 'dependable' configuring the openapi schema for JWT Bearer tokens.
+
+    Note: for the client-side OAuth2 flow, use OAuth2SPAClientSchema instead.
+    """
+
+    def __init__(self):
+        super().__init__(scheme_name="JWT Bearer token", bearerFormat="JWT")
+
+    async def __call__(self) -> None:
+        pass

+ 18 - 56
clean_python/fastapi/service.py

@@ -1,32 +1,27 @@
 # (c) Nelen & Schuurmans
 
-import logging
 from typing import Any
 from typing import Callable
 from typing import List
 from typing import Optional
 from typing import Set
 
-from asgiref.sync import sync_to_async
 from fastapi import Depends
 from fastapi import FastAPI
 from fastapi import Request
 from fastapi.exceptions import RequestValidationError
-from fastapi.security import HTTPBearer
-from fastapi.security import OAuth2AuthorizationCodeBearer
 from starlette.types import ASGIApp
 
 from clean_python import Conflict
+from clean_python import ctx
 from clean_python import DoesNotExist
 from clean_python import Gateway
 from clean_python import PermissionDenied
 from clean_python import Unauthorized
 from clean_python.oauth2 import OAuth2SPAClientSettings
-from clean_python.oauth2 import TokenVerifier
+from clean_python.oauth2 import Token
 from clean_python.oauth2 import TokenVerifierSettings
 
-from .context import ctx
-from .context import RequestMiddleware
 from .error_responses import BadRequest
 from .error_responses import conflict_handler
 from .error_responses import DefaultErrorResponse
@@ -40,63 +35,24 @@ from .fastapi_access_logger import FastAPIAccessLogger
 from .resource import APIVersion
 from .resource import clean_resources
 from .resource import Resource
-
-logger = logging.getLogger(__name__)
+from .security import get_token
+from .security import JWTBearerTokenSchema
+from .security import OAuth2SPAClientSchema
+from .security import set_verifier
 
 __all__ = ["Service"]
 
 
-class OAuth2WithClientDependable(OAuth2AuthorizationCodeBearer):
-    """A fastapi 'dependable' configuring OAuth2.
-
-    This does two things:
-    - Verify the token in each request
-    - (through FastAPI magic) add the scheme to the OpenAPI spec
-    """
-
-    def __init__(
-        self, settings: TokenVerifierSettings, client: OAuth2SPAClientSettings
-    ):
-        self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
-        super().__init__(
-            scheme_name="OAuth2 Authorization Code Flow with PKCE",
-            authorizationUrl=str(client.authorization_url),
-            tokenUrl=str(client.token_url),
-        )
-
-    async def __call__(self, request: Request) -> None:
-        ctx.claims = await self.verifier(request.headers.get("Authorization"))
-
-
-class OAuth2WithoutClientDependable(HTTPBearer):
-    """A fastapi 'dependable' configuring OAuth2.
-
-    This does one thing:
-    - Verify the token in each request
-    """
-
-    def __init__(self, settings: TokenVerifierSettings):
-        self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
-        super().__init__(scheme_name="JWT Bearer token", bearerFormat="JWT")
-
-    async def __call__(self, request: Request) -> None:
-        ctx.claims = await self.verifier(request.headers.get("Authorization"))
-
-
-def get_auth_kwargs(
-    auth: Optional[TokenVerifierSettings],
-    auth_client: Optional[OAuth2SPAClientSettings],
-) -> None:
-    if auth is None:
-        return {}
+def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> None:
     if auth_client is None:
         return {
-            "dependencies": [Depends(OAuth2WithoutClientDependable(settings=auth))],
+            "dependencies": [Depends(JWTBearerTokenSchema()), Depends(set_context)],
         }
     else:
         return {
             "dependencies": [
-                Depends(OAuth2WithClientDependable(settings=auth, client=auth_client))
+                Depends(OAuth2SPAClientSchema(client=auth_client)),
+                Depends(set_context),
             ],
             "swagger_ui_init_oauth": {
                 "clientId": auth_client.client_id,
@@ -105,6 +61,12 @@ def get_auth_kwargs(
         }
 
 
+async def set_context(request: Request, token: Token = Depends(get_token)) -> None:
+    ctx.path = request.url
+    ctx.user = token.user
+    ctx.tenant = token.tenant
+
+
 async def health_check():
     """Simple health check route"""
     return {"health": "OK"}
@@ -143,7 +105,6 @@ class Service:
                 hostname=hostname, gateway_override=access_logger_gateway
             )
         )
-        app.add_middleware(RequestMiddleware)
         app.get("/health", include_in_schema=False)(health_check)
         return app
 
@@ -186,6 +147,7 @@ class Service:
         on_startup: Optional[List[Callable[[], Any]]] = None,
         access_logger_gateway: Optional[Gateway] = None,
     ) -> ASGIApp:
+        set_verifier(auth)
         app = self._create_root_app(
             title=title,
             description=description,
@@ -196,7 +158,7 @@ class Service:
         kwargs = {
             "title": title,
             "description": description,
-            **get_auth_kwargs(auth, auth_client),
+            **get_auth_kwargs(auth_client),
         }
         versioned_apps = {
             v: self._create_versioned_app(v, **kwargs) for v in self.versions

+ 2 - 2
clean_python/oauth2/__init__.py

@@ -1,2 +1,2 @@
-from .claims import *  # NOQA
-from .oauth2 import *  # NOQA
+from .token import *  # NOQA
+from .token_verifier import *  # NOQA

+ 0 - 22
clean_python/oauth2/claims.py

@@ -1,22 +0,0 @@
-from typing import FrozenSet
-from typing import Optional
-
-from clean_python import ValueObject
-
-__all__ = ["Claims"]
-
-
-class Tenant(ValueObject):
-    id: int
-    name: str
-
-
-class User(ValueObject):
-    id: str
-    name: Optional[str]
-
-
-class Claims(ValueObject):
-    user: User
-    tenant: Optional[Tenant]
-    scope: FrozenSet[str]

+ 41 - 0
clean_python/oauth2/token.py

@@ -0,0 +1,41 @@
+from typing import Optional
+
+from pydantic import validator
+
+from clean_python import Json
+from clean_python import Scope
+from clean_python import Tenant
+from clean_python import User
+from clean_python import ValueObject
+
+__all__ = ["Token"]
+
+
+class Token(ValueObject):
+    claims: Json
+
+    @validator("claims")
+    def validate_claims(cls, v):
+        if not isinstance(v, dict):
+            return v
+        assert v.get("sub"), "missing 'sub' claim"
+        assert v.get("scope"), "missing 'scope' claim"
+        assert v.get("username"), "missing 'username' claim"
+        if v.get("tenant"):
+            assert v.get("tenant_name"), "missing 'tenant_name' claim"
+        return v
+
+    @property
+    def user(self) -> User:
+        return User(id=self.claims["sub"], name=self.claims.get("username"))
+
+    @property
+    def scope(self) -> Scope:
+        return frozenset(self.claims["scope"].split(" "))
+
+    @property
+    def tenant(self) -> Optional[Tenant]:
+        if self.claims.get("tenant"):
+            return Tenant(id=self.claims["tenant"], name=self.claims.get("tenant_name"))
+        else:
+            return None

+ 44 - 33
clean_python/oauth2/oauth2.py → clean_python/oauth2/token_verifier.py

@@ -1,5 +1,6 @@
 # (c) Nelen & Schuurmans
 
+import logging
 from typing import Dict
 from typing import FrozenSet
 from typing import List
@@ -10,15 +11,23 @@ from jwt import PyJWKClient
 from jwt.exceptions import PyJWTError
 from pydantic import AnyHttpUrl
 from pydantic import BaseModel
+from pydantic import ValidationError
 
 from clean_python import PermissionDenied
 from clean_python import Unauthorized
+from clean_python import User
 
-from .claims import Claims
-from .claims import Tenant
-from .claims import User
+from .token import Token
 
-__all__ = ["TokenVerifier", "TokenVerifierSettings", "OAuth2SPAClientSettings"]
+__all__ = [
+    "BaseTokenVerifier",
+    "TokenVerifier",
+    "NoAuthTokenVerifier",
+    "TokenVerifierSettings",
+    "OAuth2SPAClientSettings",
+]
+
+logger = logging.getLogger(__name__)
 
 
 class TokenVerifierSettings(BaseModel):
@@ -35,7 +44,17 @@ class OAuth2SPAClientSettings(BaseModel):
     authorization_url: AnyHttpUrl
 
 
-class TokenVerifier:
+class BaseTokenVerifier:
+    def __call__(self, authorization: Optional[str]) -> Token:
+        raise NotImplementedError()
+
+
+class NoAuthTokenVerifier(BaseTokenVerifier):
+    def __call__(self, authorization: Optional[str]) -> Token:
+        return Token(claims={"sub": "DEV", "username": "dev", "scope": "superuser"})
+
+
+class TokenVerifier(BaseTokenVerifier):
     """A class for verifying OAuth2 Access Tokens from AWS Cognito
 
     The verification steps followed are documented here:
@@ -47,26 +66,30 @@ class TokenVerifier:
     # allow 2 minutes leeway for verifying token expiry:
     LEEWAY = 120
 
-    def __init__(self, settings: TokenVerifierSettings):
+    def __init__(
+        self, settings: TokenVerifierSettings, logger: Optional[logging.Logger] = None
+    ):
         self.settings = settings
         self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
 
-    def __call__(self, authorization: Optional[str]) -> Claims:
+    def __call__(self, authorization: Optional[str]) -> Token:
         # Step 0: retrieve the token from the Authorization header
         # See https://tools.ietf.org/html/rfc6750#section-2.1,
         # Bearer is case-sensitive and there is exactly 1 separator after.
         if authorization is None:
+            logger.info("Missing Authorization header")
             raise Unauthorized()
         token = authorization[7:] if authorization.startswith("Bearer") else None
         if token is None:
+            logger.info("Authorization does not start with 'Bearer '")
             raise Unauthorized()
 
         # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
         # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
         try:
             key = self.get_key(token)  # JSON Web Key
-        except PyJWTError:
-            # logger.info("Token is invalid: %s", e)
+        except PyJWTError as e:
+            logger.info("Token is invalid: %s", e)
             raise Unauthorized()
         # Step 2: Validate the JWT signature and standard claims
         try:
@@ -80,19 +103,21 @@ class TokenVerifier:
                     "require": ["exp", "iss", "sub", "scope", "token_use"],
                 },
             )
-        except PyJWTError:
-            # logger.info("Token is invalid: %s", e)
+        except PyJWTError as e:
+            logger.info("Token is invalid: %s", e)
             raise Unauthorized()
         # Step 3: Verify additional claims. At this point, we have passed
         # verification, so unverified claims may be used safely.
-        user = self.parse_user(claims)
-        tenant = self.parse_tenant(claims)
-        scope = self.parse_scope(claims)
         self.verify_token_use(claims)
-        self.verify_scope(scope)
+        try:
+            token = Token(claims=claims)
+        except ValidationError as e:
+            logger.info("Token is invalid: %s", e)
+            raise Unauthorized()
+        self.verify_scope(token.scope)
         # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
-        self.authorize_user(user)
-        return Claims(user=user, scope=scope, tenant=tenant)
+        self.authorize_user(token.user)
+        return token
 
     def get_key(self, token) -> jwt.PyJWK:
         """Return the JSON Web KEY (JWK) corresponding to kid."""
@@ -101,7 +126,7 @@ class TokenVerifier:
     def verify_token_use(self, claims: Dict) -> None:
         """Check the token_use claim."""
         if claims["token_use"] != "access":
-            # logger.info("Token has invalid token_use claim: %s", claims["token_use"])
+            logger.info("Token has invalid token_use claim: %s", claims["token_use"])
             raise Unauthorized()
 
     def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
@@ -109,25 +134,11 @@ class TokenVerifier:
         if self.settings.scope is None:
             return
         if self.settings.scope not in claims_scope:
-            # logger.info("Token has invalid scope claim: %s", claims["scope"])
+            logger.info("Token is missing '%s' scope", self.settings.scope)
             raise Unauthorized()
 
     def authorize_user(self, user: User) -> None:
         if self.settings.admin_users is None:
             return
         if user.id not in self.settings.admin_users:
-            # logger.info("User with sub %s is not authorized", claims.get("sub"))
             raise PermissionDenied()
-
-    def parse_user(self, claims: Dict) -> User:
-        return User(id=claims["sub"], name=claims.get("username"))
-
-    def parse_scope(self, claims: Dict) -> FrozenSet[str]:
-        return frozenset(claims["scope"].split(" "))
-
-    def parse_tenant(self, claims: Dict) -> Optional[Tenant]:
-        if claims.get("tenant"):
-            tenant = Tenant(id=claims["tenant"], name=claims.get("tenant_name", ""))
-        else:
-            tenant = None
-        return tenant

+ 44 - 13
clean_python/sql/sql_gateway.py

@@ -8,6 +8,7 @@ from typing import Optional
 from typing import TypeVar
 
 import inject
+from sqlalchemy import and_
 from sqlalchemy import asc
 from sqlalchemy import delete
 from sqlalchemy import desc
@@ -24,6 +25,7 @@ from sqlalchemy.sql.expression import false
 
 from clean_python import AlreadyExists
 from clean_python import Conflict
+from clean_python import ctx
 from clean_python import DoesNotExist
 from clean_python import Filter
 from clean_python import Gateway
@@ -50,9 +52,12 @@ T = TypeVar("T", bound="SQLGateway")
 class SQLGateway(Gateway):
     table: Table
     nested: bool
+    multitenant: bool
 
     def __init__(
-        self, provider_override: Optional[SQLProvider] = None, nested: bool = False
+        self,
+        provider_override: Optional[SQLProvider] = None,
+        nested: bool = False,
     ):
         self.provider_override = provider_override
         self.nested = nested
@@ -61,8 +66,11 @@ class SQLGateway(Gateway):
     def provider(self):
         return self.provider_override or inject.instance(SQLDatabase)
 
-    def __init_subclass__(cls, table: Table) -> None:
+    def __init_subclass__(cls, table: Table, multitenant: bool = False) -> None:
         cls.table = table
+        if multitenant and not hasattr(table.c, "tenant"):
+            raise ValueError("Can't use a multitenant SQLGateway without tenant column")
+        cls.multitenant = multitenant
         super().__init_subclass__()
 
     def rows_to_dict(self, rows: List[Json]) -> List[Json]:
@@ -73,6 +81,8 @@ class SQLGateway(Gateway):
         result = {k: obj[k] for k in obj.keys() if k in known}
         if "id" in result and result["id"] is None:
             del result["id"]
+        if self.multitenant:
+            result["tenant"] = self.current_tenant
         return result
 
     @asynccontextmanager
@@ -83,6 +93,14 @@ class SQLGateway(Gateway):
             async with self.provider.transaction() as provider:
                 yield self.__class__(provider, nested=True)
 
+    @property
+    def current_tenant(self) -> Optional[int]:
+        if not self.multitenant:
+            return None
+        if ctx.tenant is None:
+            raise RuntimeError(f"{self.__class__} requires a tenant in the context")
+        return ctx.tenant.id
+
     async def get_related(self, items: List[Json]) -> None:
         pass
 
@@ -114,7 +132,7 @@ class SQLGateway(Gateway):
         id_ = item.get("id")
         if id_ is None:
             raise DoesNotExist("record", id_)
-        q = self.table.c.id == id_
+        q = self._id_filter_to_sql(id_)
         if if_unmodified_since is not None:
             q &= self.table.c.updated_at == if_unmodified_since
         query = (
@@ -137,7 +155,7 @@ class SQLGateway(Gateway):
     async def _select_for_update(self, id: int) -> Json:
         async with self.transaction() as transaction:
             result = await transaction.execute(
-                select(self.table).with_for_update().where(self.table.c.id == id),
+                select(self.table).with_for_update().where(self._id_filter_to_sql(id)),
             )
             if not result:
                 raise DoesNotExist("record", id)
@@ -157,7 +175,10 @@ class SQLGateway(Gateway):
         query = (
             insert(self.table)
             .values(**values)
-            .on_conflict_do_update(index_elements=["id"], set_=values)
+            .on_conflict_do_update(
+                index_elements=["id", "tenant"] if self.multitenant else ["id"],
+                set_=values,
+            )
             .returning(self.table)
         )
         async with self.transaction() as transaction:
@@ -167,13 +188,15 @@ class SQLGateway(Gateway):
 
     async def remove(self, id) -> bool:
         query = (
-            delete(self.table).where(self.table.c.id == id).returning(self.table.c.id)
+            delete(self.table)
+            .where(self._id_filter_to_sql(id))
+            .returning(self.table.c.id)
         )
         async with self.transaction() as transaction:
             result = await transaction.execute(query)
         return bool(result)
 
-    def _to_sqlalchemy_expression(self, filter: Filter) -> ColumnElement:
+    def _filter_to_sql(self, filter: Filter) -> ColumnElement:
         try:
             column = getattr(self.table.c, filter.field)
         except AttributeError:
@@ -185,12 +208,19 @@ class SQLGateway(Gateway):
         else:
             return column.in_(filter.values)
 
+    def _filters_to_sql(self, filters: List[Filter]) -> ColumnElement:
+        qs = [self._filter_to_sql(x) for x in filters]
+        if self.multitenant:
+            qs.append(self.table.c.tenant == self.current_tenant)
+        return and_(*qs)
+
+    def _id_filter_to_sql(self, id: int) -> ColumnElement:
+        return self._filters_to_sql([Filter(field="id", values=[id])])
+
     async def filter(
         self, filters: List[Filter], params: Optional[PageOptions] = None
     ) -> List[Json]:
-        query = select(self.table).where(
-            *[self._to_sqlalchemy_expression(x) for x in filters]
-        )
+        query = select(self.table).where(self._filters_to_sql(filters))
         if params is not None:
             sort = asc(params.order_by) if params.ascending else desc(params.order_by)
             query = query.order_by(sort).limit(params.limit).offset(params.offset)
@@ -203,7 +233,7 @@ class SQLGateway(Gateway):
         query = (
             select(func.count().label("count"))
             .select_from(self.table)
-            .where(*[self._to_sqlalchemy_expression(x) for x in filters])
+            .where(self._filters_to_sql(filters))
         )
         async with self.transaction() as transaction:
             return (await transaction.execute(query))[0]["count"]
@@ -212,7 +242,7 @@ class SQLGateway(Gateway):
         query = (
             select(true().label("exists"))
             .select_from(self.table)
-            .where(*[self._to_sqlalchemy_expression(x) for x in filters])
+            .where(self._filters_to_sql(filters))
             .limit(1)
         )
         async with self.transaction() as transaction:
@@ -257,6 +287,7 @@ class SQLGateway(Gateway):
                 ]
             }
         """
+        assert not self.multitenant
         for x in items:
             x[field_name] = []
         item_lut = {x["id"]: x for x in items}
@@ -308,7 +339,7 @@ class SQLGateway(Gateway):
                 ]
             }
         """
-
+        assert not self.multitenant
         # list existing related objects
         existing_lut = {
             x["id"]: x

+ 0 - 0
tests/test_fastapi_access_logger.py → tests/fastapi/test_fastapi_access_logger.py


+ 0 - 0
tests/test_request_query.py → tests/fastapi/test_request_query.py


+ 20 - 0
tests/test_resource.py → tests/fastapi/test_resource.py

@@ -3,6 +3,7 @@ from fastapi.routing import APIRouter
 
 from clean_python.fastapi import APIVersion
 from clean_python.fastapi import get
+from clean_python.fastapi import RequiresScope
 from clean_python.fastapi import Resource
 from clean_python.fastapi import Stability
 from clean_python.fastapi import v
@@ -54,6 +55,7 @@ def test_get_router():
     assert route.name == "v1/get_test"
     assert route.tags == ["testing"]
     assert route.methods == {"GET"}
+    assert len(route.dependencies) == 0
     # 'self' is missing from the parameters
     assert list(route.param_convertors.keys()) == ["id"]
 
@@ -134,3 +136,21 @@ def test_get_less_stable_no_subclass():
 
     with pytest.raises(RuntimeError):
         resources[v(1)].get_less_stable(resources)
+
+
+def test_get_router_with_scope():
+    class TestResource(Resource, version=v(1), name="testing"):
+        @get("/foo/{id}", scope="foo")
+        def get_test(self, id: int):
+            return "ok"
+
+    resource = TestResource()
+
+    router = resource.get_router(v(1))
+
+    assert len(router.routes) == 1
+
+    route = router.routes[0]
+    (dep,) = route.dependencies
+    assert isinstance(dep.dependency, RequiresScope)
+    assert dep.dependency.scope == "foo"

+ 38 - 0
tests/fastapi/test_security.py

@@ -0,0 +1,38 @@
+import pytest
+
+from clean_python import PermissionDenied
+from clean_python.fastapi import RequiresScope
+from clean_python.oauth2 import Token
+
+
+@pytest.fixture
+def token() -> Token:
+    return Token(claims={"sub": "abc123", "scope": "a b", "username": "foo"})
+
+
+@pytest.fixture
+def token_multitenant() -> Token:
+    return Token(
+        claims={
+            "sub": "abc123",
+            "scope": "a b",
+            "username": "foo",
+            "tenant": 1,
+            "tenant_name": "bar",
+        }
+    )
+
+
+@pytest.mark.parametrize("scope", ["a", "b"])
+async def test_requires_scope(token, scope):
+    await RequiresScope(scope)(token)
+
+
+async def test_requires_scope_err(token):
+    with pytest.raises(PermissionDenied):
+        await RequiresScope("c")(token)
+
+
+def test_requries_scope_no_spaces_allowed():
+    with pytest.raises(AssertionError):
+        RequiresScope("c ")

+ 0 - 0
tests/test_service.py → tests/fastapi/test_service.py


+ 1 - 0
tests/oauth2/conftest.py

@@ -44,6 +44,7 @@ def jwk_patched(public_key):
 def token_generator(private_key):
     default_claims = {
         "sub": "foo",
+        "username": "piet",
         "iss": "https://some/auth/server",
         "scope": "user",
         "token_use": "access",

+ 49 - 26
tests/oauth2/test_service_auth.py

@@ -3,6 +3,7 @@ from http import HTTPStatus
 import pytest
 from fastapi.testclient import TestClient
 
+from clean_python import ctx
 from clean_python import InMemoryGateway
 from clean_python.fastapi import get
 from clean_python.fastapi import Resource
@@ -17,14 +18,35 @@ class FooResource(Resource, version=v(1), name="testing"):
     def testing(self):
         return "ok"
 
+    @get("/bar", scope="admin")
+    def scoped(self):
+        return "ok"
 
-@pytest.fixture
-def app(settings: TokenVerifierSettings):
+    @get("/context")
+    def context(self):
+        return {
+            "path": str(ctx.path),
+            "user": ctx.user,
+            "tenant": ctx.tenant,
+        }
+
+
+@pytest.fixture(params=["noclient", "client"])
+def app(request, settings: TokenVerifierSettings):
+    if request.param == "noclient":
+        auth_client = None
+    elif request.param == "client":
+        auth_client = OAuth2SPAClientSettings(
+            client_id="123",
+            token_url="https://server/token",
+            authorization_url="https://server/token",
+        )
     return Service(FooResource()).create_app(
         title="test",
         description="testing",
         hostname="testserver",
         auth=settings,
+        auth_client=auth_client,
         access_logger_gateway=InMemoryGateway([]),
     )
 
@@ -51,39 +73,40 @@ def test_ok(app, client: TestClient, token_generator):
     assert response.status_code == HTTPStatus.OK
 
 
-@pytest.fixture
-def app2(settings: TokenVerifierSettings):
-    return Service(FooResource()).create_app(
-        title="test",
-        description="testing",
-        hostname="testserver",
-        auth=settings,
-        auth_client=OAuth2SPAClientSettings(
-            client_id="123",
-            token_url="https://server/token",
-            authorization_url="https://server/token",
-        ),
-        access_logger_gateway=InMemoryGateway([]),
+@pytest.mark.usefixtures("jwk_patched")
+def test_scoped_ok(app, client: TestClient, token_generator):
+    response = client.get(
+        app.url_path_for("v1/scoped"),
+        headers={"Authorization": "Bearer " + token_generator(scope="user admin")},
     )
 
-
-@pytest.fixture
-def client2(app):
-    return TestClient(app)
+    assert response.status_code == HTTPStatus.OK
 
 
 @pytest.mark.usefixtures("jwk_patched")
-def test_no_header2(app2, client2: TestClient):
-    response = client2.get(app2.url_path_for("v1/testing"))
+def test_scoped_forbidden(app, client: TestClient, token_generator):
+    response = client.get(
+        app.url_path_for("v1/scoped"),
+        headers={"Authorization": "Bearer " + token_generator(scope="user")},
+    )
 
-    assert response.status_code == HTTPStatus.UNAUTHORIZED
+    assert response.status_code == HTTPStatus.FORBIDDEN
 
 
 @pytest.mark.usefixtures("jwk_patched")
-def test_ok2(app2, client2: TestClient, token_generator):
-    response = client2.get(
-        app2.url_path_for("v1/testing"),
-        headers={"Authorization": "Bearer " + token_generator()},
+def test_context(app, client: TestClient, token_generator):
+    response = client.get(
+        app.url_path_for("v1/context"),
+        headers={
+            "Authorization": "Bearer " + token_generator(tenant=2, tenant_name="bar")
+        },
     )
 
     assert response.status_code == HTTPStatus.OK
+    assert response.json() == {
+        "path": "http://testserver/v1/context",
+        "user": {"id": "foo", "name": "piet"},
+        "tenant": {"id": 2, "name": "bar"},
+    }
+    assert ctx.user.id != "foo"
+    assert ctx.tenant is None

+ 64 - 0
tests/oauth2/test_token.py

@@ -0,0 +1,64 @@
+import pytest
+from pydantic import ValidationError
+
+from clean_python import Tenant
+from clean_python import User
+from clean_python.oauth2 import Token
+
+
+@pytest.fixture
+def claims():
+    return {"sub": "abc123", "scope": "a b", "username": "foo"}
+
+
+@pytest.fixture
+def claims_multitenant():
+    return {
+        "sub": "abc123",
+        "scope": "a b",
+        "username": "foo",
+        "tenant": 1,
+        "tenant_name": "bar",
+    }
+
+
+def test_init(claims):
+    Token(claims=claims)
+
+
+def test_init_multitenant(claims_multitenant):
+    Token(claims=claims_multitenant)
+
+
+@pytest.mark.parametrize(
+    "claims",
+    [
+        {"scope": "", "username": "foo"},
+        {"sub": "abc123", "username": "foo"},
+        {"sub": "abc123", "scope": ""},
+        {"sub": "abc123", "scope": "", "username": "foo", "tenant": 1},
+    ],
+)
+def test_init_err(claims):
+    with pytest.raises(ValidationError):
+        Token(claims=claims)
+
+
+def test_user(claims):
+    actual = Token(claims=claims).user
+    assert actual == User(id="abc123", name="foo")
+
+
+def test_scope(claims):
+    actual = Token(claims=claims).scope
+    assert actual == frozenset({"a", "b"})
+
+
+def test_tenant(claims_multitenant):
+    actual = Token(claims=claims_multitenant).tenant
+    assert actual == Tenant(id=1, name="bar")
+
+
+def test_no_tenant(claims):
+    actual = Token(claims=claims).tenant
+    assert actual is None

+ 8 - 24
tests/oauth2/test_oauth2.py → tests/oauth2/test_verifier.py

@@ -6,6 +6,7 @@ import pytest
 
 from clean_python import PermissionDenied
 from clean_python import Unauthorized
+from clean_python.oauth2 import Token
 from clean_python.oauth2 import TokenVerifier
 
 
@@ -16,32 +17,14 @@ def patched_verifier(jwk_patched, settings):
 
 def test_verifier_ok(patched_verifier, token_generator):
     token = token_generator()
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.user.id == "foo"
-    assert verified_claims.tenant is None
-    assert verified_claims.scope == {"user"}
-
-    patched_verifier.get_key.assert_called_once_with(token)
-
-
-def test_verifier_ok_with_username(patched_verifier, token_generator):
-    token = token_generator(username="sinterklaas")
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.user.name == "sinterklaas"
+    verified_token = patched_verifier("Bearer " + token)
 
+    assert isinstance(verified_token, Token)
+    assert verified_token.user.id == "foo"
+    assert verified_token.tenant is None
+    assert verified_token.scope == {"user"}
 
-def test_verifier_ok_with_tenant(patched_verifier, token_generator):
-    token = token_generator(tenant="15")
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.tenant.id == 15
-    assert verified_claims.tenant.name == ""
-
-
-def test_verifier_ok_with_tenant_and_name(patched_verifier, token_generator):
-    token = token_generator(tenant=15, tenant_name="foo")
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.tenant.id == 15
-    assert verified_claims.tenant.name == "foo"
+    patched_verifier.get_key.assert_called_once_with(token)
 
 
 def test_verifier_exp_leeway(patched_verifier, token_generator):
@@ -67,6 +50,7 @@ def test_verifier_multiple_scopes(patched_verifier, token_generator, settings):
         {"token_use": "id"},
         {"token_use": None},
         {"sub": None},
+        {"username": None},
     ],
 )
 def test_verifier_bad(patched_verifier, token_generator, claim_overrides):

+ 0 - 0
tests/test_sql_gateway.py → tests/sql/test_sql_gateway.py


+ 179 - 0
tests/sql/test_sql_gateway_multitenant.py

@@ -0,0 +1,179 @@
+import pytest
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+from sqlalchemy import Text
+
+from clean_python import ctx
+from clean_python import Filter
+from clean_python import Tenant
+from clean_python.sql import SQLGateway
+from clean_python.sql.testing import assert_query_equal
+from clean_python.sql.testing import FakeSQLDatabase
+
+writer = Table(
+    "writer",
+    MetaData(),
+    Column("id", Integer, primary_key=True, autoincrement=True),
+    Column("value", Text, nullable=False),
+    Column("updated_at", DateTime(timezone=True), nullable=False),
+    Column("tenant", Integer, nullable=False),
+)
+
+
+ALL_FIELDS = "writer.id, writer.value, writer.updated_at, writer.tenant"
+
+
+class TstSQLGateway(SQLGateway, table=writer, multitenant=True):
+    pass
+
+
+@pytest.fixture
+def sql_gateway():
+    return TstSQLGateway(FakeSQLDatabase())
+
+
+@pytest.fixture
+def tenant():
+    ctx.tenant = Tenant(id=2, name="foo")
+    return ctx.tenant
+
+
+async def test_no_tenant(sql_gateway):
+    with pytest.raises(RuntimeError):
+        await sql_gateway.filter([])
+    assert len(sql_gateway.provider.queries) == 0
+
+
+async def test_missing_tenant_column():
+    table = Table(
+        "notenant",
+        MetaData(),
+        Column("id", Integer, primary_key=True, autoincrement=True),
+    )
+
+    with pytest.raises(ValueError):
+
+        class Foo(SQLGateway, table=table, multitenant=True):
+            pass
+
+
+async def test_filter(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
+    assert await sql_gateway.filter([Filter(field="id", values=[1])]) == [
+        {"id": 2, "value": "foo"}
+    ]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
+    )
+
+
+async def test_count(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"count": 4}]
+    assert await sql_gateway.count([Filter(field="id", values=[1])]) == 4
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT count(*) AS count FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
+    )
+
+
+async def test_add(sql_gateway, tenant):
+    records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
+    sql_gateway.provider.result.return_value = records
+    assert await sql_gateway.add({"value": "foo"}) == records[0]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"INSERT INTO writer (value, tenant) VALUES ('foo', {tenant.id}) RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_update(sql_gateway, tenant):
+    records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
+    sql_gateway.provider.result.return_value = records
+    assert await sql_gateway.update({"id": 2, "value": "foo"}) == records[0]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"UPDATE writer SET id=2, value='foo', tenant={tenant.id} "
+            f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
+            f"RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_remove(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"id": 2}]
+    assert (await sql_gateway.remove(2)) is True
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"DELETE FROM writer WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
+            f"RETURNING writer.id"
+        ),
+    )
+
+
+async def test_upsert(sql_gateway, tenant):
+    record = {"id": 2, "value": "foo", "tenant": tenant.id}
+    sql_gateway.provider.result.return_value = [record]
+    assert await sql_gateway.upsert({"id": 2, "value": "foo"}) == record
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"INSERT INTO writer (id, value, tenant) VALUES (2, 'foo', {tenant.id}) "
+            f"ON CONFLICT (id, tenant) DO UPDATE SET "
+            f"id = %(param_1)s, value = %(param_2)s, tenant = %(param_3)s "
+            f"RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_update_transactional(sql_gateway, tenant):
+    existing = {"id": 2, "value": "foo", "tenant": tenant.id}
+    expected = {"id": 2, "value": "bar", "tenant": tenant.id}
+    sql_gateway.provider.result.side_effect = ([existing], [expected])
+    actual = await sql_gateway.update_transactional(
+        2, lambda x: {"id": x["id"], "value": "bar"}
+    )
+    assert actual == expected
+
+    (queries,) = sql_gateway.provider.queries
+    assert len(queries) == 2
+    assert_query_equal(
+        queries[0],
+        (
+            f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 "
+            f"AND writer.tenant = {tenant.id} FOR UPDATE"
+        ),
+    )
+    assert_query_equal(
+        queries[1],
+        (
+            f"UPDATE writer SET id=2, value='bar', tenant={tenant.id} "
+            f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_exists(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"exists": True}]
+    assert await sql_gateway.exists([Filter(field="id", values=[1])]) is True
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"SELECT true AS exists FROM writer "
+            f"WHERE writer.id = 1 AND writer.tenant = {tenant.id} LIMIT 1"
+        ),
+    )

+ 40 - 0
tests/test_context.py

@@ -0,0 +1,40 @@
+import asyncio
+import os
+
+from pydantic import HttpUrl
+
+from clean_python import ctx
+from clean_python import Tenant
+from clean_python import User
+
+
+def test_default_context():
+    assert str(ctx.path) == "file://" + os.getcwd()
+    assert ctx.user.id == "ANONYMOUS"
+    assert ctx.user.name == "anonymous"
+    assert ctx.tenant is None
+
+
+async def test_task_isolation():
+    async def get_set(user):
+        ctx.user = user
+        asyncio.sleep(0.01)
+        assert ctx.user == user
+
+    await asyncio.gather(*[get_set(User(id=str(i), name="piet")) for i in range(10)])
+    assert ctx.user.id == "ANONYMOUS"
+
+
+async def test_tenant():
+    tenant = Tenant(id=2, name="foo")
+    ctx.tenant = tenant
+    assert ctx.tenant == tenant
+
+    ctx.tenant = None
+    assert ctx.tenant is None
+
+
+async def test_path():
+    url = HttpUrl("http://testserver/foo?a=b")
+    ctx.path = url
+    assert ctx.path == url