123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- # (c) Nelen & Schuurmans
- from typing import Any
- from typing import Callable
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Set
- from fastapi import Depends
- from fastapi import FastAPI
- from fastapi import Request
- from fastapi.exceptions import RequestValidationError
- from starlette.types import ASGIApp
- from clean_python import BadRequest
- from clean_python import Conflict
- from clean_python import ctx
- from clean_python import DoesNotExist
- from clean_python import Gateway
- from clean_python import PermissionDenied
- from clean_python import Unauthorized
- from clean_python.oauth2 import OAuth2SPAClientSettings
- from clean_python.oauth2 import Token
- from clean_python.oauth2 import TokenVerifierSettings
- from .error_responses import conflict_handler
- from .error_responses import DefaultErrorResponse
- from .error_responses import not_found_handler
- from .error_responses import permission_denied_handler
- from .error_responses import unauthorized_handler
- from .error_responses import validation_error_handler
- from .error_responses import ValidationErrorResponse
- from .fastapi_access_logger import FastAPIAccessLogger
- from .fastapi_access_logger import get_correlation_id
- from .resource import APIVersion
- from .resource import clean_resources
- from .resource import Resource
- from .security import get_token
- from .security import JWTBearerTokenSchema
- from .security import OAuth2SPAClientSchema
- from .security import set_verifier
- __all__ = ["Service"]
def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> Dict[str, Any]:
    """Build the authentication-related keyword arguments for an app.

    The returned dict always has a ``dependencies`` list holding a security
    scheme dependency followed by the ``set_context`` dependency:

    * without ``auth_client`` the scheme is a ``JWTBearerTokenSchema``;
    * with ``auth_client`` the scheme is an ``OAuth2SPAClientSchema`` built
      from that client, and a ``swagger_ui_init_oauth`` entry is added that
      carries the client id and switches on PKCE for the authorization code
      grant.
    """
    if auth_client is None:
        scheme = JWTBearerTokenSchema()
    else:
        scheme = OAuth2SPAClientSchema(client=auth_client)

    auth_kwargs: Dict[str, Any] = {
        "dependencies": [Depends(scheme), Depends(set_context)],
    }
    if auth_client is not None:
        # Only an OAuth2 client provides the settings Swagger UI needs.
        auth_kwargs["swagger_ui_init_oauth"] = {
            "clientId": auth_client.client_id,
            "usePkceWithAuthorizationCodeGrant": True,
        }
    return auth_kwargs
async def set_context(
    request: Request,
    token: Token = Depends(get_token),
) -> None:
    """Populate ``ctx`` from the incoming request and its token.

    Sets, in this order: ``ctx.path`` (the request URL), ``ctx.user`` and
    ``ctx.tenant`` (both taken from the token) and ``ctx.correlation_id``
    (obtained through ``get_correlation_id``).
    """
    # Request-derived value first, then the token-derived ones.
    ctx.path = request.url
    ctx.user = token.user
    ctx.tenant = token.tenant
    # The correlation id is looked up from the request by the access logger helper.
    ctx.correlation_id = get_correlation_id(request)
async def health_check():
    """Health check route: always answers with ``{"health": "OK"}``."""
    return dict(health="OK")
class Service:
    """Assemble a FastAPI application out of a collection of resources.

    The resources handed to the constructor are passed through
    ``clean_resources``. ``create_app`` then builds one root application plus
    one sub-application per distinct resource version, and mounts every
    sub-application on the root app under ``"/" + version.prefix``.
    """

    resources: List[Resource]

    def __init__(self, *args: Resource):
        """Store the given resources after running them through ``clean_resources``."""
        self.resources = clean_resources(args)

    @property
    def versions(self) -> Set[APIVersion]:
        """Return the distinct versions found on ``self.resources``."""
        return {resource.version for resource in self.resources}

    def _create_root_app(
        self,
        title: str,
        description: str,
        hostname: str,
        on_startup: Optional[List[Callable[[], Any]]] = None,
        access_logger_gateway: Optional[Gateway] = None,
    ) -> FastAPI:
        """Create the root app that the versioned apps get mounted on.

        The root app lists one OpenAPI server entry per version, has the
        access logger installed as HTTP middleware and exposes ``/health``
        (kept out of the OpenAPI schema).
        """
        servers = [
            {"url": f"{version.prefix}", "description": version.description}
            for version in self.versions
        ]
        root = FastAPI(
            title=title,
            description=description,
            on_startup=on_startup,
            servers=servers,
            root_path_in_servers=False,
        )

        access_logger = FastAPIAccessLogger(
            hostname=hostname, gateway_override=access_logger_gateway
        )
        root.middleware("http")(access_logger)

        root.get("/health", include_in_schema=False)(health_check)
        return root

    def _create_versioned_app(self, version: APIVersion, **kwargs) -> FastAPI:
        """Create the app serving all resources that belong to ``version``.

        Extra keyword arguments are forwarded to the ``FastAPI`` constructor.
        The OpenAPI tags are ordered by tag name, every matching resource
        contributes a router, and the error handlers for the ``clean_python``
        exceptions and request validation errors are registered.
        """
        selected = [x for x in self.resources if x.version == version]

        openapi_tags = [x.get_openapi_tag().model_dump() for x in selected]
        openapi_tags.sort(key=lambda tag: tag["name"])

        versioned = FastAPI(version=version.prefix, tags=openapi_tags, **kwargs)

        for resource in selected:
            # A fresh ``responses`` mapping is built for every router.
            router = resource.get_router(
                version,
                responses={
                    "400": {"model": ValidationErrorResponse},
                    "default": {"model": DefaultErrorResponse},
                },
            )
            versioned.include_router(router)

        # Registration order is kept as listed here.
        error_handlers = (
            (DoesNotExist, not_found_handler),
            (Conflict, conflict_handler),
            (RequestValidationError, validation_error_handler),
            (BadRequest, validation_error_handler),
            (PermissionDenied, permission_denied_handler),
            (Unauthorized, unauthorized_handler),
        )
        for exception_class, handler in error_handlers:
            versioned.add_exception_handler(exception_class, handler)
        return versioned

    def create_app(
        self,
        title: str,
        description: str,
        hostname: str,
        auth: Optional[TokenVerifierSettings] = None,
        auth_client: Optional[OAuth2SPAClientSettings] = None,
        on_startup: Optional[List[Callable[[], Any]]] = None,
        access_logger_gateway: Optional[Gateway] = None,
    ) -> ASGIApp:
        """Build the complete application: root app plus one app per version.

        ``auth`` is handed to ``set_verifier`` before anything else is built.
        ``auth_client`` determines the authentication keyword arguments (see
        ``get_auth_kwargs``) that every versioned app receives, next to
        ``title`` and ``description``. All versioned apps are created first
        and only then mounted under ``"/" + version.prefix``.
        """
        set_verifier(auth)

        root = self._create_root_app(
            title=title,
            description=description,
            hostname=hostname,
            on_startup=on_startup,
            access_logger_gateway=access_logger_gateway,
        )

        shared_kwargs: Dict[str, Any] = {"title": title, "description": description}
        shared_kwargs.update(get_auth_kwargs(auth_client))

        # Phase 1: build every versioned app.
        apps_by_version = {}
        for version in self.versions:
            apps_by_version[version] = self._create_versioned_app(
                version, **shared_kwargs
            )

        # Phase 2: mount them on the root app.
        for version, sub_app in apps_by_version.items():
            root.mount("/" + version.prefix, sub_app)
        return root
|