Browse Source

Add optional scope check to resource methods (#6)

Casper van der Wel 1 year ago
parent
commit
be95f47eb8

+ 7 - 1
CHANGES.md

@@ -4,7 +4,13 @@
 0.2.3 (unreleased)
 0.2.3 (unreleased)
 ------------------
 ------------------
 
 
-- Nothing changed yet.
+- Add `scope` kwarg to http_method decorators (get, post, etc.)
+
+- Moved the `Context` (`ctx`) to `clean_python.base` and changed its attributes to
+  `path`, `user` and `tenant`.
+
+- The `SQLGateway` can now be constructed with `multitenant=True` which makes it
+  automatically filter the `tenant` column with the current `ctx.tenant`.
 
 
 
 
 0.2.2 (2023-08-03)
 0.2.2 (2023-08-03)

+ 1 - 0
clean_python/base/domain/__init__.py

@@ -1,3 +1,4 @@
+from .context import *  # NOQA
 from .domain_event import *  # NOQA
 from .domain_event import *  # NOQA
 from .domain_service import *  # NOQA
 from .domain_service import *  # NOQA
 from .exceptions import *  # NOQA
 from .exceptions import *  # NOQA

+ 79 - 0
clean_python/base/domain/context.py

@@ -0,0 +1,79 @@
+# (c) Nelen & Schuurmans
+
+import os
+from contextvars import ContextVar
+from typing import FrozenSet
+from typing import Optional
+
+from pydantic import AnyUrl
+from pydantic import FileUrl
+
+from .value_object import ValueObject
+
+__all__ = ["ctx", "User", "Tenant", "Scope"]
+
+
+class User(ValueObject):
+    id: str
+    name: str
+
+
+Scope = FrozenSet[str]
+
+
+class Tenant(ValueObject):
+    id: int
+    name: str
+
+
+class Context:
+    """Provide global access to some contextual properties.
+
+    The implementation makes use of python's contextvars, which automatically integrates
+    with asyncio tasks (so that each task runs in its own context). This makes sure that
+    every request-response cycle is isolated.
+    """
+
+    def __init__(self):
+        self._path_value: ContextVar[AnyUrl] = ContextVar(
+            "path_value",
+            default=FileUrl.build(scheme="file", host="/", path=os.getcwd()),
+        )
+        self._user_value: ContextVar[User] = ContextVar(
+            "user_value", default=User(id="ANONYMOUS", name="anonymous")
+        )
+        self._tenant_value: ContextVar[Optional[Tenant]] = ContextVar(
+            "tenant_value", default=None
+        )
+
+    def reset(self):
+        self._path_value.reset()
+        self._user_value.reset()
+        self._tenant_value.reset()
+
+    @property
+    def path(self) -> AnyUrl:
+        return self._path_value.get()
+
+    @path.setter
+    def path(self, value: AnyUrl) -> None:
+        self._path_value.set(value)
+
+    @property
+    def user(self) -> User:
+        return self._user_value.get()
+
+    @user.setter
+    def user(self, value: User) -> None:
+        self._user_value.set(value)
+
+    @property
+    def tenant(self) -> Optional[Tenant]:
+        return self._tenant_value.get()
+
+    @tenant.setter
+    def tenant(self, value: Optional[Tenant]) -> None:
+        self._tenant_value.set(value)
+
+
+ctx = Context()

+ 1 - 1
clean_python/fastapi/__init__.py

@@ -1,6 +1,6 @@
-from .context import *  # NOQA
 from .error_responses import *  # NOQA
 from .error_responses import *  # NOQA
 from .fastapi_access_logger import *  # NOQA
 from .fastapi_access_logger import *  # NOQA
 from .request_query import *  # NOQA
 from .request_query import *  # NOQA
 from .resource import *  # NOQA
 from .resource import *  # NOQA
+from .security import *  # NOQA
 from .service import *  # NOQA
 from .service import *  # NOQA

+ 0 - 53
clean_python/fastapi/context.py

@@ -1,53 +0,0 @@
-# (c) Nelen & Schuurmans
-
-from contextvars import ContextVar
-
-from fastapi import Request
-
-from ..oauth2 import Claims
-
-__all__ = ["ctx", "RequestMiddleware"]
-
-
-class Context:
-    def __init__(self):
-        self._request_value: ContextVar[Request] = ContextVar("request_value")
-        self._claims_value: ContextVar[Claims] = ContextVar("claims_value")
-
-    @property
-    def request(self) -> Request:
-        return self._request_value.get()
-
-    @request.setter
-    def request(self, value: Request) -> None:
-        self._request_value.set(value)
-
-    @property
-    def claims(self) -> Claims:
-        return self._claims_value.get()
-
-    @claims.setter
-    def claims(self, value: Claims) -> None:
-        self._claims_value.set(value)
-
-
-ctx = Context()
-
-
-class RequestMiddleware:
-    """Save the current request in a context variable.
-
-    We were experiencing database connections piling up until PostgreSQL's
-    max_connections was hit, which has to do with BaseHTTPMiddleware not
-    interacting properly with context variables. For more details, see:
-    https://github.com/tiangolo/fastapi/issues/4719. Writing this
-    middleware as generic ASGI middleware fixes the problem.
-    """
-
-    def __init__(self, app):
-        self.app = app
-
-    async def __call__(self, scope, receive, send):
-        if scope["type"] == "http":
-            ctx.request = Request(scope, receive)
-        await self.app(scope, receive, send)

+ 1 - 1
clean_python/fastapi/error_responses.py

@@ -80,5 +80,5 @@ async def unauthorized_handler(request: Request, exc: Unauthorized):
 async def permission_denied_handler(request: Request, exc: PermissionDenied):
 async def permission_denied_handler(request: Request, exc: PermissionDenied):
     return JSONResponse(
     return JSONResponse(
         status_code=status.HTTP_403_FORBIDDEN,
         status_code=status.HTTP_403_FORBIDDEN,
-        content={"message": "Permission denied"},
+        content={"message": "Permission denied", "detail": str(exc)},
     )
     )

+ 10 - 3
clean_python/fastapi/resource.py

@@ -10,10 +10,13 @@ from typing import Optional
 from typing import Sequence
 from typing import Sequence
 from typing import Type
 from typing import Type
 
 
+from fastapi import Depends
 from fastapi.routing import APIRouter
 from fastapi.routing import APIRouter
 
 
 from clean_python import ValueObject
 from clean_python import ValueObject
 
 
+from .security import RequiresScope
+
 __all__ = [
 __all__ = [
     "Resource",
     "Resource",
     "get",
     "get",
@@ -71,12 +74,12 @@ class APIVersion(ValueObject):
         return APIVersion(version=self.version, stability=self.stability.decrease())
         return APIVersion(version=self.version, stability=self.stability.decrease())
 
 
 
 
-def http_method(path: str, **route_options):
+def http_method(path: str, scope: Optional[str] = None, **route_options):
     def wrapper(unbound_method: Callable[..., Any]):
     def wrapper(unbound_method: Callable[..., Any]):
         setattr(
         setattr(
             unbound_method,
             unbound_method,
             "http_method",
             "http_method",
-            (path, route_options),
+            (path, scope, route_options),
         )
         )
         return unbound_method
         return unbound_method
 
 
@@ -160,7 +163,7 @@ class Resource:
         router = APIRouter()
         router = APIRouter()
         operation_ids = set()
         operation_ids = set()
         for endpoint in self._endpoints():
         for endpoint in self._endpoints():
-            path, route_options = endpoint.http_method
+            path, scope, route_options = endpoint.http_method
             operation_id = endpoint.__name__
             operation_id = endpoint.__name__
             if operation_id in operation_ids:
             if operation_id in operation_ids:
                 raise RuntimeError(
                 raise RuntimeError(
@@ -170,6 +173,10 @@ class Resource:
             # The 'name' is used for reverse lookups (request.path_for): include the
             # The 'name' is used for reverse lookups (request.path_for): include the
             # version prefix so that we can uniquely refer to an operation.
             # version prefix so that we can uniquely refer to an operation.
             name = version.prefix + "/" + endpoint.__name__
             name = version.prefix + "/" + endpoint.__name__
+            # 'scope' is implemented using FastAPI's dependency injection system
+            if scope is not None:
+                route_options.setdefault("dependencies", [])
+                route_options["dependencies"].append(Depends(RequiresScope(scope)))
             router.add_api_route(
             router.add_api_route(
                 path,
                 path,
                 endpoint,
                 endpoint,

+ 82 - 0
clean_python/fastapi/security.py

@@ -0,0 +1,82 @@
+from typing import Optional
+
+from fastapi import Depends
+from fastapi import Request
+from fastapi.security import HTTPBearer
+from fastapi.security import OAuth2AuthorizationCodeBearer
+
+from clean_python import PermissionDenied
+from clean_python.oauth2 import BaseTokenVerifier
+from clean_python.oauth2 import NoAuthTokenVerifier
+from clean_python.oauth2 import OAuth2SPAClientSettings
+from clean_python.oauth2 import Token
+from clean_python.oauth2 import TokenVerifier
+from clean_python.oauth2 import TokenVerifierSettings
+
+__all__ = ["get_token", "RequiresScope"]
+
+verifier: Optional[BaseTokenVerifier] = None
+
+
+def clear_verifier() -> None:
+    global verifier
+
+    verifier = None
+
+
+def set_verifier(settings: Optional[TokenVerifierSettings]) -> None:
+    global verifier
+
+    if settings is None:
+        verifier = NoAuthTokenVerifier()
+    else:
+        verifier = TokenVerifier(settings=settings)
+
+
+def get_token(request: Request) -> Token:
+    """A fastapi 'dependable' yielding the validated token"""
+    global verifier
+
+    assert verifier is not None
+    return verifier(request.headers.get("Authorization"))
+
+
+class RequiresScope:
+    def __init__(self, scope: str):
+        assert scope.replace(" ", "") == scope, "spaces are not allowed in a scope"
+        self.scope = scope
+
+    async def __call__(self, token: Token = Depends(get_token)) -> None:
+        if self.scope not in token.scope:
+            raise PermissionDenied(f"this operation requires '{self.scope}' scope")
+
+
+class OAuth2SPAClientSchema(OAuth2AuthorizationCodeBearer):
+    """A fastapi 'dependable' configuring the openapi schema for the
+    OAuth2 Authorization Code Flow with PKCE extension.
+
+    This includes the JWT Bearer token configuration.
+    """
+
+    def __init__(self, client: OAuth2SPAClientSettings):
+        super().__init__(
+            scheme_name="OAuth2 Authorization Code Flow with PKCE",
+            authorizationUrl=str(client.authorization_url),
+            tokenUrl=str(client.token_url),
+        )
+
+    async def __call__(self) -> None:
+        pass
+
+
+class JWTBearerTokenSchema(HTTPBearer):
+    """A fastapi 'dependable' configuring the openapi schema for JWT Bearer tokens.
+
+    Note: for the client-side OAuth2 flow, use OAuth2SPAClientSchema instead.
+    """
+
+    def __init__(self):
+        super().__init__(scheme_name="JWT Bearer token", bearerFormat="JWT")
+
+    async def __call__(self) -> None:
+        pass

+ 18 - 56
clean_python/fastapi/service.py

@@ -1,32 +1,27 @@
 # (c) Nelen & Schuurmans
 # (c) Nelen & Schuurmans
 
 
-import logging
 from typing import Any
 from typing import Any
 from typing import Callable
 from typing import Callable
 from typing import List
 from typing import List
 from typing import Optional
 from typing import Optional
 from typing import Set
 from typing import Set
 
 
-from asgiref.sync import sync_to_async
 from fastapi import Depends
 from fastapi import Depends
 from fastapi import FastAPI
 from fastapi import FastAPI
 from fastapi import Request
 from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.exceptions import RequestValidationError
-from fastapi.security import HTTPBearer
-from fastapi.security import OAuth2AuthorizationCodeBearer
 from starlette.types import ASGIApp
 from starlette.types import ASGIApp
 
 
 from clean_python import Conflict
 from clean_python import Conflict
+from clean_python import ctx
 from clean_python import DoesNotExist
 from clean_python import DoesNotExist
 from clean_python import Gateway
 from clean_python import Gateway
 from clean_python import PermissionDenied
 from clean_python import PermissionDenied
 from clean_python import Unauthorized
 from clean_python import Unauthorized
 from clean_python.oauth2 import OAuth2SPAClientSettings
 from clean_python.oauth2 import OAuth2SPAClientSettings
-from clean_python.oauth2 import TokenVerifier
+from clean_python.oauth2 import Token
 from clean_python.oauth2 import TokenVerifierSettings
 from clean_python.oauth2 import TokenVerifierSettings
 
 
-from .context import ctx
-from .context import RequestMiddleware
 from .error_responses import BadRequest
 from .error_responses import BadRequest
 from .error_responses import conflict_handler
 from .error_responses import conflict_handler
 from .error_responses import DefaultErrorResponse
 from .error_responses import DefaultErrorResponse
@@ -40,63 +35,24 @@ from .fastapi_access_logger import FastAPIAccessLogger
 from .resource import APIVersion
 from .resource import APIVersion
 from .resource import clean_resources
 from .resource import clean_resources
 from .resource import Resource
 from .resource import Resource
-
-logger = logging.getLogger(__name__)
+from .security import get_token
+from .security import JWTBearerTokenSchema
+from .security import OAuth2SPAClientSchema
+from .security import set_verifier
 
 
 __all__ = ["Service"]
 __all__ = ["Service"]
 
 
 
 
-class OAuth2WithClientDependable(OAuth2AuthorizationCodeBearer):
-    """A fastapi 'dependable' configuring OAuth2.
-
-    This does two things:
-    - Verify the token in each request
-    - (through FastAPI magic) add the scheme to the OpenAPI spec
-    """
-
-    def __init__(
-        self, settings: TokenVerifierSettings, client: OAuth2SPAClientSettings
-    ):
-        self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
-        super().__init__(
-            scheme_name="OAuth2 Authorization Code Flow with PKCE",
-            authorizationUrl=str(client.authorization_url),
-            tokenUrl=str(client.token_url),
-        )
-
-    async def __call__(self, request: Request) -> None:
-        ctx.claims = await self.verifier(request.headers.get("Authorization"))
-
-
-class OAuth2WithoutClientDependable(HTTPBearer):
-    """A fastapi 'dependable' configuring OAuth2.
-
-    This does one thing:
-    - Verify the token in each request
-    """
-
-    def __init__(self, settings: TokenVerifierSettings):
-        self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
-        super().__init__(scheme_name="JWT Bearer token", bearerFormat="JWT")
-
-    async def __call__(self, request: Request) -> None:
-        ctx.claims = await self.verifier(request.headers.get("Authorization"))
-
-
-def get_auth_kwargs(
-    auth: Optional[TokenVerifierSettings],
-    auth_client: Optional[OAuth2SPAClientSettings],
-) -> None:
-    if auth is None:
-        return {}
+def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> None:
     if auth_client is None:
     if auth_client is None:
         return {
         return {
-            "dependencies": [Depends(OAuth2WithoutClientDependable(settings=auth))],
+            "dependencies": [Depends(JWTBearerTokenSchema()), Depends(set_context)],
         }
         }
     else:
     else:
         return {
         return {
             "dependencies": [
             "dependencies": [
-                Depends(OAuth2WithClientDependable(settings=auth, client=auth_client))
+                Depends(OAuth2SPAClientSchema(client=auth_client)),
+                Depends(set_context),
             ],
             ],
             "swagger_ui_init_oauth": {
             "swagger_ui_init_oauth": {
                 "clientId": auth_client.client_id,
                 "clientId": auth_client.client_id,
@@ -105,6 +61,12 @@ def get_auth_kwargs(
         }
         }
 
 
 
 
+async def set_context(request: Request, token: Token = Depends(get_token)) -> None:
+    ctx.path = request.url
+    ctx.user = token.user
+    ctx.tenant = token.tenant
+
+
 async def health_check():
 async def health_check():
     """Simple health check route"""
     """Simple health check route"""
     return {"health": "OK"}
     return {"health": "OK"}
@@ -143,7 +105,6 @@ class Service:
                 hostname=hostname, gateway_override=access_logger_gateway
                 hostname=hostname, gateway_override=access_logger_gateway
             )
             )
         )
         )
-        app.add_middleware(RequestMiddleware)
         app.get("/health", include_in_schema=False)(health_check)
         app.get("/health", include_in_schema=False)(health_check)
         return app
         return app
 
 
@@ -186,6 +147,7 @@ class Service:
         on_startup: Optional[List[Callable[[], Any]]] = None,
         on_startup: Optional[List[Callable[[], Any]]] = None,
         access_logger_gateway: Optional[Gateway] = None,
         access_logger_gateway: Optional[Gateway] = None,
     ) -> ASGIApp:
     ) -> ASGIApp:
+        set_verifier(auth)
         app = self._create_root_app(
         app = self._create_root_app(
             title=title,
             title=title,
             description=description,
             description=description,
@@ -196,7 +158,7 @@ class Service:
         kwargs = {
         kwargs = {
             "title": title,
             "title": title,
             "description": description,
             "description": description,
-            **get_auth_kwargs(auth, auth_client),
+            **get_auth_kwargs(auth_client),
         }
         }
         versioned_apps = {
         versioned_apps = {
             v: self._create_versioned_app(v, **kwargs) for v in self.versions
             v: self._create_versioned_app(v, **kwargs) for v in self.versions

+ 2 - 2
clean_python/oauth2/__init__.py

@@ -1,2 +1,2 @@
-from .claims import *  # NOQA
-from .oauth2 import *  # NOQA
+from .token import *  # NOQA
+from .token_verifier import *  # NOQA

+ 0 - 22
clean_python/oauth2/claims.py

@@ -1,22 +0,0 @@
-from typing import FrozenSet
-from typing import Optional
-
-from clean_python import ValueObject
-
-__all__ = ["Claims"]
-
-
-class Tenant(ValueObject):
-    id: int
-    name: str
-
-
-class User(ValueObject):
-    id: str
-    name: Optional[str]
-
-
-class Claims(ValueObject):
-    user: User
-    tenant: Optional[Tenant]
-    scope: FrozenSet[str]

+ 41 - 0
clean_python/oauth2/token.py

@@ -0,0 +1,41 @@
+from typing import Optional
+
+from pydantic import validator
+
+from clean_python import Json
+from clean_python import Scope
+from clean_python import Tenant
+from clean_python import User
+from clean_python import ValueObject
+
+__all__ = ["Token"]
+
+
+class Token(ValueObject):
+    claims: Json
+
+    @validator("claims")
+    def validate_claims(cls, v):
+        if not isinstance(v, dict):
+            return v
+        assert v.get("sub"), "missing 'sub' claim"
+        assert v.get("scope"), "missing 'scope' claim"
+        assert v.get("username"), "missing 'username' claim"
+        if v.get("tenant"):
+            assert v.get("tenant_name"), "missing 'tenant_name' claim"
+        return v
+
+    @property
+    def user(self) -> User:
+        return User(id=self.claims["sub"], name=self.claims.get("username"))
+
+    @property
+    def scope(self) -> Scope:
+        return frozenset(self.claims["scope"].split(" "))
+
+    @property
+    def tenant(self) -> Optional[Tenant]:
+        if self.claims.get("tenant"):
+            return Tenant(id=self.claims["tenant"], name=self.claims.get("tenant_name"))
+        else:
+            return None

+ 44 - 33
clean_python/oauth2/oauth2.py → clean_python/oauth2/token_verifier.py

@@ -1,5 +1,6 @@
 # (c) Nelen & Schuurmans
 # (c) Nelen & Schuurmans
 
 
+import logging
 from typing import Dict
 from typing import Dict
 from typing import FrozenSet
 from typing import FrozenSet
 from typing import List
 from typing import List
@@ -10,15 +11,23 @@ from jwt import PyJWKClient
 from jwt.exceptions import PyJWTError
 from jwt.exceptions import PyJWTError
 from pydantic import AnyHttpUrl
 from pydantic import AnyHttpUrl
 from pydantic import BaseModel
 from pydantic import BaseModel
+from pydantic import ValidationError
 
 
 from clean_python import PermissionDenied
 from clean_python import PermissionDenied
 from clean_python import Unauthorized
 from clean_python import Unauthorized
+from clean_python import User
 
 
-from .claims import Claims
-from .claims import Tenant
-from .claims import User
+from .token import Token
 
 
-__all__ = ["TokenVerifier", "TokenVerifierSettings", "OAuth2SPAClientSettings"]
+__all__ = [
+    "BaseTokenVerifier",
+    "TokenVerifier",
+    "NoAuthTokenVerifier",
+    "TokenVerifierSettings",
+    "OAuth2SPAClientSettings",
+]
+
+logger = logging.getLogger(__name__)
 
 
 
 
 class TokenVerifierSettings(BaseModel):
 class TokenVerifierSettings(BaseModel):
@@ -35,7 +44,17 @@ class OAuth2SPAClientSettings(BaseModel):
     authorization_url: AnyHttpUrl
     authorization_url: AnyHttpUrl
 
 
 
 
-class TokenVerifier:
+class BaseTokenVerifier:
+    def __call__(self, authorization: Optional[str]) -> Token:
+        raise NotImplementedError()
+
+
+class NoAuthTokenVerifier(BaseTokenVerifier):
+    def __call__(self, authorization: Optional[str]) -> Token:
+        return Token(claims={"sub": "DEV", "username": "dev", "scope": "superuser"})
+
+
+class TokenVerifier(BaseTokenVerifier):
     """A class for verifying OAuth2 Access Tokens from AWS Cognito
     """A class for verifying OAuth2 Access Tokens from AWS Cognito
 
 
     The verification steps followed are documented here:
     The verification steps followed are documented here:
@@ -47,26 +66,30 @@ class TokenVerifier:
     # allow 2 minutes leeway for verifying token expiry:
     # allow 2 minutes leeway for verifying token expiry:
     LEEWAY = 120
     LEEWAY = 120
 
 
-    def __init__(self, settings: TokenVerifierSettings):
+    def __init__(
+        self, settings: TokenVerifierSettings, logger: Optional[logging.Logger] = None
+    ):
         self.settings = settings
         self.settings = settings
         self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
         self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
 
 
-    def __call__(self, authorization: Optional[str]) -> Claims:
+    def __call__(self, authorization: Optional[str]) -> Token:
         # Step 0: retrieve the token from the Authorization header
         # Step 0: retrieve the token from the Authorization header
         # See https://tools.ietf.org/html/rfc6750#section-2.1,
         # See https://tools.ietf.org/html/rfc6750#section-2.1,
         # Bearer is case-sensitive and there is exactly 1 separator after.
         # Bearer is case-sensitive and there is exactly 1 separator after.
         if authorization is None:
         if authorization is None:
+            logger.info("Missing Authorization header")
             raise Unauthorized()
             raise Unauthorized()
         token = authorization[7:] if authorization.startswith("Bearer") else None
         token = authorization[7:] if authorization.startswith("Bearer") else None
         if token is None:
         if token is None:
+            logger.info("Authorization does not start with 'Bearer '")
             raise Unauthorized()
             raise Unauthorized()
 
 
         # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
         # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
         # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
         # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
         try:
         try:
             key = self.get_key(token)  # JSON Web Key
             key = self.get_key(token)  # JSON Web Key
-        except PyJWTError:
-            # logger.info("Token is invalid: %s", e)
+        except PyJWTError as e:
+            logger.info("Token is invalid: %s", e)
             raise Unauthorized()
             raise Unauthorized()
         # Step 2: Validate the JWT signature and standard claims
         # Step 2: Validate the JWT signature and standard claims
         try:
         try:
@@ -80,19 +103,21 @@ class TokenVerifier:
                     "require": ["exp", "iss", "sub", "scope", "token_use"],
                     "require": ["exp", "iss", "sub", "scope", "token_use"],
                 },
                 },
             )
             )
-        except PyJWTError:
-            # logger.info("Token is invalid: %s", e)
+        except PyJWTError as e:
+            logger.info("Token is invalid: %s", e)
             raise Unauthorized()
             raise Unauthorized()
         # Step 3: Verify additional claims. At this point, we have passed
         # Step 3: Verify additional claims. At this point, we have passed
         # verification, so unverified claims may be used safely.
         # verification, so unverified claims may be used safely.
-        user = self.parse_user(claims)
-        tenant = self.parse_tenant(claims)
-        scope = self.parse_scope(claims)
         self.verify_token_use(claims)
         self.verify_token_use(claims)
-        self.verify_scope(scope)
+        try:
+            token = Token(claims=claims)
+        except ValidationError as e:
+            logger.info("Token is invalid: %s", e)
+            raise Unauthorized()
+        self.verify_scope(token.scope)
         # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
         # Step 4: Authorization: verify user id ('sub' claim) against 'admin_users'
-        self.authorize_user(user)
-        return Claims(user=user, scope=scope, tenant=tenant)
+        self.authorize_user(token.user)
+        return token
 
 
     def get_key(self, token) -> jwt.PyJWK:
     def get_key(self, token) -> jwt.PyJWK:
         """Return the JSON Web KEY (JWK) corresponding to kid."""
         """Return the JSON Web KEY (JWK) corresponding to kid."""
@@ -101,7 +126,7 @@ class TokenVerifier:
     def verify_token_use(self, claims: Dict) -> None:
     def verify_token_use(self, claims: Dict) -> None:
         """Check the token_use claim."""
         """Check the token_use claim."""
         if claims["token_use"] != "access":
         if claims["token_use"] != "access":
-            # logger.info("Token has invalid token_use claim: %s", claims["token_use"])
+            logger.info("Token has invalid token_use claim: %s", claims["token_use"])
             raise Unauthorized()
             raise Unauthorized()
 
 
     def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
     def verify_scope(self, claims_scope: FrozenSet[str]) -> None:
@@ -109,25 +134,11 @@ class TokenVerifier:
         if self.settings.scope is None:
         if self.settings.scope is None:
             return
             return
         if self.settings.scope not in claims_scope:
         if self.settings.scope not in claims_scope:
-            # logger.info("Token has invalid scope claim: %s", claims["scope"])
+            logger.info("Token is missing '%s' scope", self.settings.scope)
             raise Unauthorized()
             raise Unauthorized()
 
 
     def authorize_user(self, user: User) -> None:
     def authorize_user(self, user: User) -> None:
         if self.settings.admin_users is None:
         if self.settings.admin_users is None:
             return
             return
         if user.id not in self.settings.admin_users:
         if user.id not in self.settings.admin_users:
-            # logger.info("User with sub %s is not authorized", claims.get("sub"))
             raise PermissionDenied()
             raise PermissionDenied()
-
-    def parse_user(self, claims: Dict) -> User:
-        return User(id=claims["sub"], name=claims.get("username"))
-
-    def parse_scope(self, claims: Dict) -> FrozenSet[str]:
-        return frozenset(claims["scope"].split(" "))
-
-    def parse_tenant(self, claims: Dict) -> Optional[Tenant]:
-        if claims.get("tenant"):
-            tenant = Tenant(id=claims["tenant"], name=claims.get("tenant_name", ""))
-        else:
-            tenant = None
-        return tenant

+ 44 - 13
clean_python/sql/sql_gateway.py

@@ -8,6 +8,7 @@ from typing import Optional
 from typing import TypeVar
 from typing import TypeVar
 
 
 import inject
 import inject
+from sqlalchemy import and_
 from sqlalchemy import asc
 from sqlalchemy import asc
 from sqlalchemy import delete
 from sqlalchemy import delete
 from sqlalchemy import desc
 from sqlalchemy import desc
@@ -24,6 +25,7 @@ from sqlalchemy.sql.expression import false
 
 
 from clean_python import AlreadyExists
 from clean_python import AlreadyExists
 from clean_python import Conflict
 from clean_python import Conflict
+from clean_python import ctx
 from clean_python import DoesNotExist
 from clean_python import DoesNotExist
 from clean_python import Filter
 from clean_python import Filter
 from clean_python import Gateway
 from clean_python import Gateway
@@ -50,9 +52,12 @@ T = TypeVar("T", bound="SQLGateway")
 class SQLGateway(Gateway):
 class SQLGateway(Gateway):
     table: Table
     table: Table
     nested: bool
     nested: bool
+    multitenant: bool
 
 
     def __init__(
     def __init__(
-        self, provider_override: Optional[SQLProvider] = None, nested: bool = False
+        self,
+        provider_override: Optional[SQLProvider] = None,
+        nested: bool = False,
     ):
     ):
         self.provider_override = provider_override
         self.provider_override = provider_override
         self.nested = nested
         self.nested = nested
@@ -61,8 +66,11 @@ class SQLGateway(Gateway):
     def provider(self):
     def provider(self):
         return self.provider_override or inject.instance(SQLDatabase)
         return self.provider_override or inject.instance(SQLDatabase)
 
 
-    def __init_subclass__(cls, table: Table) -> None:
+    def __init_subclass__(cls, table: Table, multitenant: bool = False) -> None:
         cls.table = table
         cls.table = table
+        if multitenant and not hasattr(table.c, "tenant"):
+            raise ValueError("Can't use a multitenant SQLGateway without tenant column")
+        cls.multitenant = multitenant
         super().__init_subclass__()
         super().__init_subclass__()
 
 
     def rows_to_dict(self, rows: List[Json]) -> List[Json]:
     def rows_to_dict(self, rows: List[Json]) -> List[Json]:
@@ -73,6 +81,8 @@ class SQLGateway(Gateway):
         result = {k: obj[k] for k in obj.keys() if k in known}
         result = {k: obj[k] for k in obj.keys() if k in known}
         if "id" in result and result["id"] is None:
         if "id" in result and result["id"] is None:
             del result["id"]
             del result["id"]
+        if self.multitenant:
+            result["tenant"] = self.current_tenant
         return result
         return result
 
 
     @asynccontextmanager
     @asynccontextmanager
@@ -83,6 +93,14 @@ class SQLGateway(Gateway):
             async with self.provider.transaction() as provider:
             async with self.provider.transaction() as provider:
                 yield self.__class__(provider, nested=True)
                 yield self.__class__(provider, nested=True)
 
 
+    @property
+    def current_tenant(self) -> Optional[int]:
+        if not self.multitenant:
+            return None
+        if ctx.tenant is None:
+            raise RuntimeError(f"{self.__class__} requires a tenant in the context")
+        return ctx.tenant.id
+
     async def get_related(self, items: List[Json]) -> None:
     async def get_related(self, items: List[Json]) -> None:
         pass
         pass
 
 
@@ -114,7 +132,7 @@ class SQLGateway(Gateway):
         id_ = item.get("id")
         id_ = item.get("id")
         if id_ is None:
         if id_ is None:
             raise DoesNotExist("record", id_)
             raise DoesNotExist("record", id_)
-        q = self.table.c.id == id_
+        q = self._id_filter_to_sql(id_)
         if if_unmodified_since is not None:
         if if_unmodified_since is not None:
             q &= self.table.c.updated_at == if_unmodified_since
             q &= self.table.c.updated_at == if_unmodified_since
         query = (
         query = (
@@ -137,7 +155,7 @@ class SQLGateway(Gateway):
     async def _select_for_update(self, id: int) -> Json:
     async def _select_for_update(self, id: int) -> Json:
         async with self.transaction() as transaction:
         async with self.transaction() as transaction:
             result = await transaction.execute(
             result = await transaction.execute(
-                select(self.table).with_for_update().where(self.table.c.id == id),
+                select(self.table).with_for_update().where(self._id_filter_to_sql(id)),
             )
             )
             if not result:
             if not result:
                 raise DoesNotExist("record", id)
                 raise DoesNotExist("record", id)
@@ -157,7 +175,10 @@ class SQLGateway(Gateway):
         query = (
         query = (
             insert(self.table)
             insert(self.table)
             .values(**values)
             .values(**values)
-            .on_conflict_do_update(index_elements=["id"], set_=values)
+            .on_conflict_do_update(
+                index_elements=["id", "tenant"] if self.multitenant else ["id"],
+                set_=values,
+            )
             .returning(self.table)
             .returning(self.table)
         )
         )
         async with self.transaction() as transaction:
         async with self.transaction() as transaction:
@@ -167,13 +188,15 @@ class SQLGateway(Gateway):
 
 
     async def remove(self, id) -> bool:
     async def remove(self, id) -> bool:
         query = (
         query = (
-            delete(self.table).where(self.table.c.id == id).returning(self.table.c.id)
+            delete(self.table)
+            .where(self._id_filter_to_sql(id))
+            .returning(self.table.c.id)
         )
         )
         async with self.transaction() as transaction:
         async with self.transaction() as transaction:
             result = await transaction.execute(query)
             result = await transaction.execute(query)
         return bool(result)
         return bool(result)
 
 
-    def _to_sqlalchemy_expression(self, filter: Filter) -> ColumnElement:
+    def _filter_to_sql(self, filter: Filter) -> ColumnElement:
         try:
         try:
             column = getattr(self.table.c, filter.field)
             column = getattr(self.table.c, filter.field)
         except AttributeError:
         except AttributeError:
@@ -185,12 +208,19 @@ class SQLGateway(Gateway):
         else:
         else:
             return column.in_(filter.values)
             return column.in_(filter.values)
 
 
+    def _filters_to_sql(self, filters: List[Filter]) -> ColumnElement:
+        qs = [self._filter_to_sql(x) for x in filters]
+        if self.multitenant:
+            qs.append(self.table.c.tenant == self.current_tenant)
+        return and_(*qs)
+
+    def _id_filter_to_sql(self, id: int) -> ColumnElement:
+        return self._filters_to_sql([Filter(field="id", values=[id])])
+
     async def filter(
     async def filter(
         self, filters: List[Filter], params: Optional[PageOptions] = None
         self, filters: List[Filter], params: Optional[PageOptions] = None
     ) -> List[Json]:
     ) -> List[Json]:
-        query = select(self.table).where(
-            *[self._to_sqlalchemy_expression(x) for x in filters]
-        )
+        query = select(self.table).where(self._filters_to_sql(filters))
         if params is not None:
         if params is not None:
             sort = asc(params.order_by) if params.ascending else desc(params.order_by)
             sort = asc(params.order_by) if params.ascending else desc(params.order_by)
             query = query.order_by(sort).limit(params.limit).offset(params.offset)
             query = query.order_by(sort).limit(params.limit).offset(params.offset)
@@ -203,7 +233,7 @@ class SQLGateway(Gateway):
         query = (
         query = (
             select(func.count().label("count"))
             select(func.count().label("count"))
             .select_from(self.table)
             .select_from(self.table)
-            .where(*[self._to_sqlalchemy_expression(x) for x in filters])
+            .where(self._filters_to_sql(filters))
         )
         )
         async with self.transaction() as transaction:
         async with self.transaction() as transaction:
             return (await transaction.execute(query))[0]["count"]
             return (await transaction.execute(query))[0]["count"]
@@ -212,7 +242,7 @@ class SQLGateway(Gateway):
         query = (
         query = (
             select(true().label("exists"))
             select(true().label("exists"))
             .select_from(self.table)
             .select_from(self.table)
-            .where(*[self._to_sqlalchemy_expression(x) for x in filters])
+            .where(self._filters_to_sql(filters))
             .limit(1)
             .limit(1)
         )
         )
         async with self.transaction() as transaction:
         async with self.transaction() as transaction:
@@ -257,6 +287,7 @@ class SQLGateway(Gateway):
                 ]
                 ]
             }
             }
         """
         """
+        assert not self.multitenant
         for x in items:
         for x in items:
             x[field_name] = []
             x[field_name] = []
         item_lut = {x["id"]: x for x in items}
         item_lut = {x["id"]: x for x in items}
@@ -308,7 +339,7 @@ class SQLGateway(Gateway):
                 ]
                 ]
             }
             }
         """
         """
-
+        assert not self.multitenant
         # list existing related objects
         # list existing related objects
         existing_lut = {
         existing_lut = {
             x["id"]: x
             x["id"]: x

+ 0 - 0
tests/test_fastapi_access_logger.py → tests/fastapi/test_fastapi_access_logger.py


+ 0 - 0
tests/test_request_query.py → tests/fastapi/test_request_query.py


+ 20 - 0
tests/test_resource.py → tests/fastapi/test_resource.py

@@ -3,6 +3,7 @@ from fastapi.routing import APIRouter
 
 
 from clean_python.fastapi import APIVersion
 from clean_python.fastapi import APIVersion
 from clean_python.fastapi import get
 from clean_python.fastapi import get
+from clean_python.fastapi import RequiresScope
 from clean_python.fastapi import Resource
 from clean_python.fastapi import Resource
 from clean_python.fastapi import Stability
 from clean_python.fastapi import Stability
 from clean_python.fastapi import v
 from clean_python.fastapi import v
@@ -54,6 +55,7 @@ def test_get_router():
     assert route.name == "v1/get_test"
     assert route.name == "v1/get_test"
     assert route.tags == ["testing"]
     assert route.tags == ["testing"]
     assert route.methods == {"GET"}
     assert route.methods == {"GET"}
+    assert len(route.dependencies) == 0
     # 'self' is missing from the parameters
     # 'self' is missing from the parameters
     assert list(route.param_convertors.keys()) == ["id"]
     assert list(route.param_convertors.keys()) == ["id"]
 
 
@@ -134,3 +136,21 @@ def test_get_less_stable_no_subclass():
 
 
     with pytest.raises(RuntimeError):
     with pytest.raises(RuntimeError):
         resources[v(1)].get_less_stable(resources)
         resources[v(1)].get_less_stable(resources)
+
+
+def test_get_router_with_scope():
+    class TestResource(Resource, version=v(1), name="testing"):
+        @get("/foo/{id}", scope="foo")
+        def get_test(self, id: int):
+            return "ok"
+
+    resource = TestResource()
+
+    router = resource.get_router(v(1))
+
+    assert len(router.routes) == 1
+
+    route = router.routes[0]
+    (dep,) = route.dependencies
+    assert isinstance(dep.dependency, RequiresScope)
+    assert dep.dependency.scope == "foo"

+ 38 - 0
tests/fastapi/test_security.py

@@ -0,0 +1,38 @@
+import pytest
+
+from clean_python import PermissionDenied
+from clean_python.fastapi import RequiresScope
+from clean_python.oauth2 import Token
+
+
+@pytest.fixture
+def token() -> Token:
+    return Token(claims={"sub": "abc123", "scope": "a b", "username": "foo"})
+
+
+@pytest.fixture
+def token_multitenant() -> Token:
+    return Token(
+        claims={
+            "sub": "abc123",
+            "scope": "a b",
+            "username": "foo",
+            "tenant": 1,
+            "tenant_name": "bar",
+        }
+    )
+
+
+@pytest.mark.parametrize("scope", ["a", "b"])
+async def test_requires_scope(token, scope):
+    await RequiresScope(scope)(token)
+
+
+async def test_requires_scope_err(token):
+    with pytest.raises(PermissionDenied):
+        await RequiresScope("c")(token)
+
+
+def test_requries_scope_no_spaces_allowed():
+    with pytest.raises(AssertionError):
+        RequiresScope("c ")

+ 0 - 0
tests/test_service.py → tests/fastapi/test_service.py


+ 1 - 0
tests/oauth2/conftest.py

@@ -44,6 +44,7 @@ def jwk_patched(public_key):
 def token_generator(private_key):
 def token_generator(private_key):
     default_claims = {
     default_claims = {
         "sub": "foo",
         "sub": "foo",
+        "username": "piet",
         "iss": "https://some/auth/server",
         "iss": "https://some/auth/server",
         "scope": "user",
         "scope": "user",
         "token_use": "access",
         "token_use": "access",

+ 49 - 26
tests/oauth2/test_service_auth.py

@@ -3,6 +3,7 @@ from http import HTTPStatus
 import pytest
 import pytest
 from fastapi.testclient import TestClient
 from fastapi.testclient import TestClient
 
 
+from clean_python import ctx
 from clean_python import InMemoryGateway
 from clean_python import InMemoryGateway
 from clean_python.fastapi import get
 from clean_python.fastapi import get
 from clean_python.fastapi import Resource
 from clean_python.fastapi import Resource
@@ -17,14 +18,35 @@ class FooResource(Resource, version=v(1), name="testing"):
     def testing(self):
     def testing(self):
         return "ok"
         return "ok"
 
 
+    @get("/bar", scope="admin")
+    def scoped(self):
+        return "ok"
 
 
-@pytest.fixture
-def app(settings: TokenVerifierSettings):
+    @get("/context")
+    def context(self):
+        return {
+            "path": str(ctx.path),
+            "user": ctx.user,
+            "tenant": ctx.tenant,
+        }
+
+
+@pytest.fixture(params=["noclient", "client"])
+def app(request, settings: TokenVerifierSettings):
+    if request.param == "noclient":
+        auth_client = None
+    elif request.param == "client":
+        auth_client = OAuth2SPAClientSettings(
+            client_id="123",
+            token_url="https://server/token",
+            authorization_url="https://server/token",
+        )
     return Service(FooResource()).create_app(
     return Service(FooResource()).create_app(
         title="test",
         title="test",
         description="testing",
         description="testing",
         hostname="testserver",
         hostname="testserver",
         auth=settings,
         auth=settings,
+        auth_client=auth_client,
         access_logger_gateway=InMemoryGateway([]),
         access_logger_gateway=InMemoryGateway([]),
     )
     )
 
 
@@ -51,39 +73,40 @@ def test_ok(app, client: TestClient, token_generator):
     assert response.status_code == HTTPStatus.OK
     assert response.status_code == HTTPStatus.OK
 
 
 
 
-@pytest.fixture
-def app2(settings: TokenVerifierSettings):
-    return Service(FooResource()).create_app(
-        title="test",
-        description="testing",
-        hostname="testserver",
-        auth=settings,
-        auth_client=OAuth2SPAClientSettings(
-            client_id="123",
-            token_url="https://server/token",
-            authorization_url="https://server/token",
-        ),
-        access_logger_gateway=InMemoryGateway([]),
+@pytest.mark.usefixtures("jwk_patched")
+def test_scoped_ok(app, client: TestClient, token_generator):
+    response = client.get(
+        app.url_path_for("v1/scoped"),
+        headers={"Authorization": "Bearer " + token_generator(scope="user admin")},
     )
     )
 
 
-
-@pytest.fixture
-def client2(app):
-    return TestClient(app)
+    assert response.status_code == HTTPStatus.OK
 
 
 
 
 @pytest.mark.usefixtures("jwk_patched")
 @pytest.mark.usefixtures("jwk_patched")
-def test_no_header2(app2, client2: TestClient):
-    response = client2.get(app2.url_path_for("v1/testing"))
+def test_scoped_forbidden(app, client: TestClient, token_generator):
+    response = client.get(
+        app.url_path_for("v1/scoped"),
+        headers={"Authorization": "Bearer " + token_generator(scope="user")},
+    )
 
 
-    assert response.status_code == HTTPStatus.UNAUTHORIZED
+    assert response.status_code == HTTPStatus.FORBIDDEN
 
 
 
 
 @pytest.mark.usefixtures("jwk_patched")
 @pytest.mark.usefixtures("jwk_patched")
-def test_ok2(app2, client2: TestClient, token_generator):
-    response = client2.get(
-        app2.url_path_for("v1/testing"),
-        headers={"Authorization": "Bearer " + token_generator()},
+def test_context(app, client: TestClient, token_generator):
+    response = client.get(
+        app.url_path_for("v1/context"),
+        headers={
+            "Authorization": "Bearer " + token_generator(tenant=2, tenant_name="bar")
+        },
     )
     )
 
 
     assert response.status_code == HTTPStatus.OK
     assert response.status_code == HTTPStatus.OK
+    assert response.json() == {
+        "path": "http://testserver/v1/context",
+        "user": {"id": "foo", "name": "piet"},
+        "tenant": {"id": 2, "name": "bar"},
+    }
+    assert ctx.user.id != "foo"
+    assert ctx.tenant is None

+ 64 - 0
tests/oauth2/test_token.py

@@ -0,0 +1,64 @@
+import pytest
+from pydantic import ValidationError
+
+from clean_python import Tenant
+from clean_python import User
+from clean_python.oauth2 import Token
+
+
+@pytest.fixture
+def claims():
+    return {"sub": "abc123", "scope": "a b", "username": "foo"}
+
+
+@pytest.fixture
+def claims_multitenant():
+    return {
+        "sub": "abc123",
+        "scope": "a b",
+        "username": "foo",
+        "tenant": 1,
+        "tenant_name": "bar",
+    }
+
+
+def test_init(claims):
+    Token(claims=claims)
+
+
+def test_init_multitenant(claims_multitenant):
+    Token(claims=claims_multitenant)
+
+
+@pytest.mark.parametrize(
+    "claims",
+    [
+        {"scope": "", "username": "foo"},
+        {"sub": "abc123", "username": "foo"},
+        {"sub": "abc123", "scope": ""},
+        {"sub": "abc123", "scope": "", "username": "foo", "tenant": 1},
+    ],
+)
+def test_init_err(claims):
+    with pytest.raises(ValidationError):
+        Token(claims=claims)
+
+
+def test_user(claims):
+    actual = Token(claims=claims).user
+    assert actual == User(id="abc123", name="foo")
+
+
+def test_scope(claims):
+    actual = Token(claims=claims).scope
+    assert actual == frozenset({"a", "b"})
+
+
+def test_tenant(claims_multitenant):
+    actual = Token(claims=claims_multitenant).tenant
+    assert actual == Tenant(id=1, name="bar")
+
+
+def test_no_tenant(claims):
+    actual = Token(claims=claims).tenant
+    assert actual is None

+ 8 - 24
tests/oauth2/test_oauth2.py → tests/oauth2/test_verifier.py

@@ -6,6 +6,7 @@ import pytest
 
 
 from clean_python import PermissionDenied
 from clean_python import PermissionDenied
 from clean_python import Unauthorized
 from clean_python import Unauthorized
+from clean_python.oauth2 import Token
 from clean_python.oauth2 import TokenVerifier
 from clean_python.oauth2 import TokenVerifier
 
 
 
 
@@ -16,32 +17,14 @@ def patched_verifier(jwk_patched, settings):
 
 
 def test_verifier_ok(patched_verifier, token_generator):
 def test_verifier_ok(patched_verifier, token_generator):
     token = token_generator()
     token = token_generator()
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.user.id == "foo"
-    assert verified_claims.tenant is None
-    assert verified_claims.scope == {"user"}
-
-    patched_verifier.get_key.assert_called_once_with(token)
-
-
-def test_verifier_ok_with_username(patched_verifier, token_generator):
-    token = token_generator(username="sinterklaas")
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.user.name == "sinterklaas"
+    verified_token = patched_verifier("Bearer " + token)
 
 
+    assert isinstance(verified_token, Token)
+    assert verified_token.user.id == "foo"
+    assert verified_token.tenant is None
+    assert verified_token.scope == {"user"}
 
 
-def test_verifier_ok_with_tenant(patched_verifier, token_generator):
-    token = token_generator(tenant="15")
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.tenant.id == 15
-    assert verified_claims.tenant.name == ""
-
-
-def test_verifier_ok_with_tenant_and_name(patched_verifier, token_generator):
-    token = token_generator(tenant=15, tenant_name="foo")
-    verified_claims = patched_verifier("Bearer " + token)
-    assert verified_claims.tenant.id == 15
-    assert verified_claims.tenant.name == "foo"
+    patched_verifier.get_key.assert_called_once_with(token)
 
 
 
 
 def test_verifier_exp_leeway(patched_verifier, token_generator):
 def test_verifier_exp_leeway(patched_verifier, token_generator):
@@ -67,6 +50,7 @@ def test_verifier_multiple_scopes(patched_verifier, token_generator, settings):
         {"token_use": "id"},
         {"token_use": "id"},
         {"token_use": None},
         {"token_use": None},
         {"sub": None},
         {"sub": None},
+        {"username": None},
     ],
     ],
 )
 )
 def test_verifier_bad(patched_verifier, token_generator, claim_overrides):
 def test_verifier_bad(patched_verifier, token_generator, claim_overrides):

+ 0 - 0
tests/test_sql_gateway.py → tests/sql/test_sql_gateway.py


+ 179 - 0
tests/sql/test_sql_gateway_multitenant.py

@@ -0,0 +1,179 @@
+import pytest
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+from sqlalchemy import Text
+
+from clean_python import ctx
+from clean_python import Filter
+from clean_python import Tenant
+from clean_python.sql import SQLGateway
+from clean_python.sql.testing import assert_query_equal
+from clean_python.sql.testing import FakeSQLDatabase
+
+writer = Table(
+    "writer",
+    MetaData(),
+    Column("id", Integer, primary_key=True, autoincrement=True),
+    Column("value", Text, nullable=False),
+    Column("updated_at", DateTime(timezone=True), nullable=False),
+    Column("tenant", Integer, nullable=False),
+)
+
+
+ALL_FIELDS = "writer.id, writer.value, writer.updated_at, writer.tenant"
+
+
+class TstSQLGateway(SQLGateway, table=writer, multitenant=True):
+    pass
+
+
+@pytest.fixture
+def sql_gateway():
+    return TstSQLGateway(FakeSQLDatabase())
+
+
+@pytest.fixture
+def tenant():
+    ctx.tenant = Tenant(id=2, name="foo")
+    return ctx.tenant
+
+
+async def test_no_tenant(sql_gateway):
+    with pytest.raises(RuntimeError):
+        await sql_gateway.filter([])
+    assert len(sql_gateway.provider.queries) == 0
+
+
+async def test_missing_tenant_column():
+    table = Table(
+        "notenant",
+        MetaData(),
+        Column("id", Integer, primary_key=True, autoincrement=True),
+    )
+
+    with pytest.raises(ValueError):
+
+        class Foo(SQLGateway, table=table, multitenant=True):
+            pass
+
+
+async def test_filter(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
+    assert await sql_gateway.filter([Filter(field="id", values=[1])]) == [
+        {"id": 2, "value": "foo"}
+    ]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
+    )
+
+
+async def test_count(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"count": 4}]
+    assert await sql_gateway.count([Filter(field="id", values=[1])]) == 4
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT count(*) AS count FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
+    )
+
+
+async def test_add(sql_gateway, tenant):
+    records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
+    sql_gateway.provider.result.return_value = records
+    assert await sql_gateway.add({"value": "foo"}) == records[0]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"INSERT INTO writer (value, tenant) VALUES ('foo', {tenant.id}) RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_update(sql_gateway, tenant):
+    records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
+    sql_gateway.provider.result.return_value = records
+    assert await sql_gateway.update({"id": 2, "value": "foo"}) == records[0]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"UPDATE writer SET id=2, value='foo', tenant={tenant.id} "
+            f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
+            f"RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_remove(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"id": 2}]
+    assert (await sql_gateway.remove(2)) is True
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"DELETE FROM writer WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
+            f"RETURNING writer.id"
+        ),
+    )
+
+
+async def test_upsert(sql_gateway, tenant):
+    record = {"id": 2, "value": "foo", "tenant": tenant.id}
+    sql_gateway.provider.result.return_value = [record]
+    assert await sql_gateway.upsert({"id": 2, "value": "foo"}) == record
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"INSERT INTO writer (id, value, tenant) VALUES (2, 'foo', {tenant.id}) "
+            f"ON CONFLICT (id, tenant) DO UPDATE SET "
+            f"id = %(param_1)s, value = %(param_2)s, tenant = %(param_3)s "
+            f"RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_update_transactional(sql_gateway, tenant):
+    existing = {"id": 2, "value": "foo", "tenant": tenant.id}
+    expected = {"id": 2, "value": "bar", "tenant": tenant.id}
+    sql_gateway.provider.result.side_effect = ([existing], [expected])
+    actual = await sql_gateway.update_transactional(
+        2, lambda x: {"id": x["id"], "value": "bar"}
+    )
+    assert actual == expected
+
+    (queries,) = sql_gateway.provider.queries
+    assert len(queries) == 2
+    assert_query_equal(
+        queries[0],
+        (
+            f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 "
+            f"AND writer.tenant = {tenant.id} FOR UPDATE"
+        ),
+    )
+    assert_query_equal(
+        queries[1],
+        (
+            f"UPDATE writer SET id=2, value='bar', tenant={tenant.id} "
+            f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+async def test_exists(sql_gateway, tenant):
+    sql_gateway.provider.result.return_value = [{"exists": True}]
+    assert await sql_gateway.exists([Filter(field="id", values=[1])]) is True
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"SELECT true AS exists FROM writer "
+            f"WHERE writer.id = 1 AND writer.tenant = {tenant.id} LIMIT 1"
+        ),
+    )

+ 40 - 0
tests/test_context.py

@@ -0,0 +1,40 @@
+import asyncio
+import os
+
+from pydantic import HttpUrl
+
+from clean_python import ctx
+from clean_python import Tenant
+from clean_python import User
+
+
+def test_default_context():
+    assert str(ctx.path) == "file://" + os.getcwd()
+    assert ctx.user.id == "ANONYMOUS"
+    assert ctx.user.name == "anonymous"
+    assert ctx.tenant is None
+
+
+async def test_task_isolation():
+    async def get_set(user):
+        ctx.user = user
+        asyncio.sleep(0.01)
+        assert ctx.user == user
+
+    await asyncio.gather(*[get_set(User(id=str(i), name="piet")) for i in range(10)])
+    assert ctx.user.id == "ANONYMOUS"
+
+
+async def test_tenant():
+    tenant = Tenant(id=2, name="foo")
+    ctx.tenant = tenant
+    assert ctx.tenant == tenant
+
+    ctx.tenant = None
+    assert ctx.tenant is None
+
+
+async def test_path():
+    url = HttpUrl("http://testserver/foo?a=b")
+    ctx.path = url
+    assert ctx.path == url