service.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # (c) Nelen & Schuurmans
  2. from typing import Any
  3. from typing import Callable
  4. from typing import List
  5. from typing import Optional
  6. from typing import Set
  7. from fastapi import Depends
  8. from fastapi import FastAPI
  9. from fastapi import Request
  10. from fastapi.exceptions import RequestValidationError
  11. from starlette.types import ASGIApp
  12. from clean_python import Conflict
  13. from clean_python import ctx
  14. from clean_python import DoesNotExist
  15. from clean_python import Gateway
  16. from clean_python import PermissionDenied
  17. from clean_python import Unauthorized
  18. from clean_python.oauth2 import OAuth2SPAClientSettings
  19. from clean_python.oauth2 import Token
  20. from clean_python.oauth2 import TokenVerifierSettings
  21. from .error_responses import BadRequest
  22. from .error_responses import conflict_handler
  23. from .error_responses import DefaultErrorResponse
  24. from .error_responses import not_found_handler
  25. from .error_responses import not_implemented_handler
  26. from .error_responses import permission_denied_handler
  27. from .error_responses import unauthorized_handler
  28. from .error_responses import validation_error_handler
  29. from .error_responses import ValidationErrorResponse
  30. from .fastapi_access_logger import FastAPIAccessLogger
  31. from .resource import APIVersion
  32. from .resource import clean_resources
  33. from .resource import Resource
  34. from .security import get_token
  35. from .security import JWTBearerTokenSchema
  36. from .security import OAuth2SPAClientSchema
  37. from .security import set_verifier
  38. __all__ = ["Service"]
  39. def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> None:
  40. if auth_client is None:
  41. return {
  42. "dependencies": [Depends(JWTBearerTokenSchema()), Depends(set_context)],
  43. }
  44. else:
  45. return {
  46. "dependencies": [
  47. Depends(OAuth2SPAClientSchema(client=auth_client)),
  48. Depends(set_context),
  49. ],
  50. "swagger_ui_init_oauth": {
  51. "clientId": auth_client.client_id,
  52. "usePkceWithAuthorizationCodeGrant": True,
  53. },
  54. }
  55. async def set_context(request: Request, token: Token = Depends(get_token)) -> None:
  56. ctx.path = request.url
  57. ctx.user = token.user
  58. ctx.tenant = token.tenant
  59. async def health_check():
  60. """Simple health check route"""
  61. return {"health": "OK"}
  62. class Service:
  63. resources: List[Resource]
  64. def __init__(self, *args: Resource):
  65. self.resources = clean_resources(args)
  66. @property
  67. def versions(self) -> Set[APIVersion]:
  68. return set([x.version for x in self.resources])
  69. def _create_root_app(
  70. self,
  71. title: str,
  72. description: str,
  73. hostname: str,
  74. on_startup: Optional[List[Callable[[], Any]]] = None,
  75. access_logger_gateway: Optional[Gateway] = None,
  76. ) -> FastAPI:
  77. app = FastAPI(
  78. title=title,
  79. description=description,
  80. on_startup=on_startup,
  81. servers=[
  82. {"url": f"{x.prefix}", "description": x.description}
  83. for x in self.versions
  84. ],
  85. root_path_in_servers=False,
  86. )
  87. app.middleware("http")(
  88. FastAPIAccessLogger(
  89. hostname=hostname, gateway_override=access_logger_gateway
  90. )
  91. )
  92. app.get("/health", include_in_schema=False)(health_check)
  93. return app
  94. def _create_versioned_app(self, version: APIVersion, **kwargs) -> FastAPI:
  95. resources = [x for x in self.resources if x.version == version]
  96. app = FastAPI(
  97. version=version.prefix,
  98. tags=sorted(
  99. [x.get_openapi_tag().model_dump() for x in resources],
  100. key=lambda x: x["name"],
  101. ),
  102. **kwargs,
  103. )
  104. for resource in resources:
  105. app.include_router(
  106. resource.get_router(
  107. version,
  108. responses={
  109. "400": {"model": ValidationErrorResponse},
  110. "default": {"model": DefaultErrorResponse},
  111. },
  112. )
  113. )
  114. app.add_exception_handler(DoesNotExist, not_found_handler)
  115. app.add_exception_handler(Conflict, conflict_handler)
  116. app.add_exception_handler(RequestValidationError, validation_error_handler)
  117. app.add_exception_handler(BadRequest, validation_error_handler)
  118. app.add_exception_handler(NotImplementedError, not_implemented_handler)
  119. app.add_exception_handler(PermissionDenied, permission_denied_handler)
  120. app.add_exception_handler(Unauthorized, unauthorized_handler)
  121. return app
  122. def create_app(
  123. self,
  124. title: str,
  125. description: str,
  126. hostname: str,
  127. auth: Optional[TokenVerifierSettings] = None,
  128. auth_client: Optional[OAuth2SPAClientSettings] = None,
  129. on_startup: Optional[List[Callable[[], Any]]] = None,
  130. access_logger_gateway: Optional[Gateway] = None,
  131. ) -> ASGIApp:
  132. set_verifier(auth)
  133. app = self._create_root_app(
  134. title=title,
  135. description=description,
  136. hostname=hostname,
  137. on_startup=on_startup,
  138. access_logger_gateway=access_logger_gateway,
  139. )
  140. kwargs = {
  141. "title": title,
  142. "description": description,
  143. **get_auth_kwargs(auth_client),
  144. }
  145. versioned_apps = {
  146. v: self._create_versioned_app(v, **kwargs) for v in self.versions
  147. }
  148. for v, versioned_app in versioned_apps.items():
  149. app.mount("/" + v.prefix, versioned_app)
  150. return app