service.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # (c) Nelen & Schuurmans
  2. from typing import Any
  3. from typing import Callable
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Set
  8. from uuid import UUID
  9. from uuid import uuid4
  10. from fastapi import Depends
  11. from fastapi import FastAPI
  12. from fastapi import Header
  13. from fastapi import Request
  14. from fastapi.exceptions import RequestValidationError
  15. from starlette.types import ASGIApp
  16. from clean_python import BadRequest
  17. from clean_python import Conflict
  18. from clean_python import ctx
  19. from clean_python import DoesNotExist
  20. from clean_python import Gateway
  21. from clean_python import PermissionDenied
  22. from clean_python import Unauthorized
  23. from clean_python.oauth2 import OAuth2SPAClientSettings
  24. from clean_python.oauth2 import Token
  25. from clean_python.oauth2 import TokenVerifierSettings
  26. from .error_responses import conflict_handler
  27. from .error_responses import DefaultErrorResponse
  28. from .error_responses import not_found_handler
  29. from .error_responses import not_implemented_handler
  30. from .error_responses import permission_denied_handler
  31. from .error_responses import unauthorized_handler
  32. from .error_responses import validation_error_handler
  33. from .error_responses import ValidationErrorResponse
  34. from .fastapi_access_logger import FastAPIAccessLogger
  35. from .resource import APIVersion
  36. from .resource import clean_resources
  37. from .resource import Resource
  38. from .security import get_token
  39. from .security import JWTBearerTokenSchema
  40. from .security import OAuth2SPAClientSchema
  41. from .security import set_verifier
  42. __all__ = ["Service"]
  43. def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> Dict[str, Any]:
  44. if auth_client is None:
  45. return {
  46. "dependencies": [Depends(JWTBearerTokenSchema()), Depends(set_context)],
  47. }
  48. else:
  49. return {
  50. "dependencies": [
  51. Depends(OAuth2SPAClientSchema(client=auth_client)),
  52. Depends(set_context),
  53. ],
  54. "swagger_ui_init_oauth": {
  55. "clientId": auth_client.client_id,
  56. "usePkceWithAuthorizationCodeGrant": True,
  57. },
  58. }
  59. async def set_context(
  60. request: Request,
  61. token: Token = Depends(get_token),
  62. x_correlation_id: UUID = Header(default_factory=uuid4),
  63. ) -> None:
  64. ctx.path = request.url
  65. ctx.user = token.user
  66. ctx.tenant = token.tenant
  67. ctx.correlation_id = x_correlation_id
  68. async def health_check():
  69. """Simple health check route"""
  70. return {"health": "OK"}
  71. class Service:
  72. resources: List[Resource]
  73. def __init__(self, *args: Resource):
  74. self.resources = clean_resources(args)
  75. @property
  76. def versions(self) -> Set[APIVersion]:
  77. return set([x.version for x in self.resources])
  78. def _create_root_app(
  79. self,
  80. title: str,
  81. description: str,
  82. hostname: str,
  83. on_startup: Optional[List[Callable[[], Any]]] = None,
  84. access_logger_gateway: Optional[Gateway] = None,
  85. ) -> FastAPI:
  86. app = FastAPI(
  87. title=title,
  88. description=description,
  89. on_startup=on_startup,
  90. servers=[
  91. {"url": f"{x.prefix}", "description": x.description}
  92. for x in self.versions
  93. ],
  94. root_path_in_servers=False,
  95. )
  96. app.middleware("http")(
  97. FastAPIAccessLogger(
  98. hostname=hostname, gateway_override=access_logger_gateway
  99. )
  100. )
  101. app.get("/health", include_in_schema=False)(health_check)
  102. return app
  103. def _create_versioned_app(self, version: APIVersion, **kwargs) -> FastAPI:
  104. resources = [x for x in self.resources if x.version == version]
  105. app = FastAPI(
  106. version=version.prefix,
  107. tags=sorted(
  108. [x.get_openapi_tag().model_dump() for x in resources],
  109. key=lambda x: x["name"],
  110. ),
  111. **kwargs,
  112. )
  113. for resource in resources:
  114. app.include_router(
  115. resource.get_router(
  116. version,
  117. responses={
  118. "400": {"model": ValidationErrorResponse},
  119. "default": {"model": DefaultErrorResponse},
  120. },
  121. )
  122. )
  123. app.add_exception_handler(DoesNotExist, not_found_handler)
  124. app.add_exception_handler(Conflict, conflict_handler)
  125. app.add_exception_handler(RequestValidationError, validation_error_handler)
  126. app.add_exception_handler(BadRequest, validation_error_handler)
  127. app.add_exception_handler(NotImplementedError, not_implemented_handler)
  128. app.add_exception_handler(PermissionDenied, permission_denied_handler)
  129. app.add_exception_handler(Unauthorized, unauthorized_handler)
  130. return app
  131. def create_app(
  132. self,
  133. title: str,
  134. description: str,
  135. hostname: str,
  136. auth: Optional[TokenVerifierSettings] = None,
  137. auth_client: Optional[OAuth2SPAClientSettings] = None,
  138. on_startup: Optional[List[Callable[[], Any]]] = None,
  139. access_logger_gateway: Optional[Gateway] = None,
  140. ) -> ASGIApp:
  141. set_verifier(auth)
  142. app = self._create_root_app(
  143. title=title,
  144. description=description,
  145. hostname=hostname,
  146. on_startup=on_startup,
  147. access_logger_gateway=access_logger_gateway,
  148. )
  149. kwargs = {
  150. "title": title,
  151. "description": description,
  152. **get_auth_kwargs(auth_client),
  153. }
  154. versioned_apps = {
  155. v: self._create_versioned_app(v, **kwargs) for v in self.versions
  156. }
  157. for v, versioned_app in versioned_apps.items():
  158. app.mount("/" + v.prefix, versioned_app)
  159. return app