# (c) Nelen & Schuurmans import logging from typing import Any from typing import Callable from typing import List from typing import Optional from typing import Set from asgiref.sync import sync_to_async from fastapi import Depends from fastapi import FastAPI from fastapi import Request from fastapi.exceptions import RequestValidationError from fastapi.security import HTTPBearer from fastapi.security import OAuth2AuthorizationCodeBearer from starlette.types import ASGIApp from clean_python import Conflict from clean_python import DoesNotExist from clean_python import Gateway from clean_python import PermissionDenied from clean_python import Unauthorized from clean_python.oauth2 import OAuth2SPAClientSettings from clean_python.oauth2 import TokenVerifier from clean_python.oauth2 import TokenVerifierSettings from .context import ctx from .context import RequestMiddleware from .error_responses import BadRequest from .error_responses import conflict_handler from .error_responses import DefaultErrorResponse from .error_responses import not_found_handler from .error_responses import not_implemented_handler from .error_responses import permission_denied_handler from .error_responses import unauthorized_handler from .error_responses import validation_error_handler from .error_responses import ValidationErrorResponse from .fastapi_access_logger import FastAPIAccessLogger from .resource import APIVersion from .resource import clean_resources from .resource import Resource logger = logging.getLogger(__name__) __all__ = ["Service"] class OAuth2WithClientDependable(OAuth2AuthorizationCodeBearer): """A fastapi 'dependable' configuring OAuth2. This does two things: - Verify the token in each request - (through FastAPI magic) add the scheme to the OpenAPI spec """ def __init__( self, settings: TokenVerifierSettings, client: OAuth2SPAClientSettings ): self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False) super().__init__( scheme_name="OAuth2 Authorization Code Flow with PKCE", authorizationUrl=str(client.authorization_url), tokenUrl=str(client.token_url), ) async def __call__(self, request: Request) -> None: ctx.claims = await self.verifier(request.headers.get("Authorization")) class OAuth2WithoutClientDependable(HTTPBearer): """A fastapi 'dependable' configuring OAuth2. This does one thing: - Verify the token in each request """ def __init__(self, settings: TokenVerifierSettings): self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False) super().__init__(scheme_name="JWT Bearer token", bearerFormat="JWT") async def __call__(self, request: Request) -> None: ctx.claims = await self.verifier(request.headers.get("Authorization")) def get_auth_kwargs( auth: Optional[TokenVerifierSettings], auth_client: Optional[OAuth2SPAClientSettings], ) -> None: if auth is None: return {} if auth_client is None: return { "dependencies": [Depends(OAuth2WithoutClientDependable(settings=auth))], } else: return { "dependencies": [ Depends(OAuth2WithClientDependable(settings=auth, client=auth_client)) ], "swagger_ui_init_oauth": { "clientId": auth_client.client_id, "usePkceWithAuthorizationCodeGrant": True, }, } async def health_check(): """Simple health check route""" return {"health": "OK"} class Service: resources: List[Resource] def __init__(self, *args: Resource): self.resources = clean_resources(args) @property def versions(self) -> Set[APIVersion]: return set([x.version for x in self.resources]) def _create_root_app( self, title: str, description: str, hostname: str, on_startup: Optional[List[Callable[[], Any]]] = None, access_logger_gateway: Optional[Gateway] = None, ) -> FastAPI: app = FastAPI( title=title, description=description, on_startup=on_startup, servers=[ {"url": f"{x.prefix}", "description": x.description} for x in self.versions ], root_path_in_servers=False, ) app.middleware("http")( FastAPIAccessLogger( hostname=hostname, gateway_override=access_logger_gateway ) ) app.add_middleware(RequestMiddleware) app.get("/health", include_in_schema=False)(health_check) return app def _create_versioned_app(self, version: APIVersion, **kwargs) -> FastAPI: resources = [x for x in self.resources if x.version == version] app = FastAPI( version=version.prefix, tags=sorted( [x.get_openapi_tag().model_dump() for x in resources], key=lambda x: x["name"], ), **kwargs, ) for resource in resources: app.include_router( resource.get_router( version, responses={ "400": {"model": ValidationErrorResponse}, "default": {"model": DefaultErrorResponse}, }, ) ) app.add_exception_handler(DoesNotExist, not_found_handler) app.add_exception_handler(Conflict, conflict_handler) app.add_exception_handler(RequestValidationError, validation_error_handler) app.add_exception_handler(BadRequest, validation_error_handler) app.add_exception_handler(NotImplementedError, not_implemented_handler) app.add_exception_handler(PermissionDenied, permission_denied_handler) app.add_exception_handler(Unauthorized, unauthorized_handler) return app def create_app( self, title: str, description: str, hostname: str, auth: Optional[TokenVerifierSettings] = None, auth_client: Optional[OAuth2SPAClientSettings] = None, on_startup: Optional[List[Callable[[], Any]]] = None, access_logger_gateway: Optional[Gateway] = None, ) -> ASGIApp: app = self._create_root_app( title=title, description=description, hostname=hostname, on_startup=on_startup, access_logger_gateway=access_logger_gateway, ) kwargs = { "title": title, "description": description, **get_auth_kwargs(auth, auth_client), } versioned_apps = { v: self._create_versioned_app(v, **kwargs) for v in self.versions } for v, versioned_app in versioned_apps.items(): app.mount("/" + v.prefix, versioned_app) return app