service.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # (c) Nelen & Schuurmans
  2. import logging
  3. from typing import Any
  4. from typing import Callable
  5. from typing import List
  6. from typing import Optional
  7. from typing import Set
  8. from asgiref.sync import sync_to_async
  9. from fastapi import Depends
  10. from fastapi import FastAPI
  11. from fastapi import Request
  12. from fastapi.exceptions import RequestValidationError
  13. from fastapi.security import HTTPBearer
  14. from fastapi.security import OAuth2AuthorizationCodeBearer
  15. from starlette.types import ASGIApp
  16. from clean_python import Conflict
  17. from clean_python import DoesNotExist
  18. from clean_python import Gateway
  19. from clean_python import PermissionDenied
  20. from clean_python import Unauthorized
  21. from clean_python.oauth2 import OAuth2SPAClientSettings
  22. from clean_python.oauth2 import TokenVerifier
  23. from clean_python.oauth2 import TokenVerifierSettings
  24. from .context import ctx
  25. from .context import RequestMiddleware
  26. from .error_responses import BadRequest
  27. from .error_responses import conflict_handler
  28. from .error_responses import DefaultErrorResponse
  29. from .error_responses import not_found_handler
  30. from .error_responses import not_implemented_handler
  31. from .error_responses import permission_denied_handler
  32. from .error_responses import unauthorized_handler
  33. from .error_responses import validation_error_handler
  34. from .error_responses import ValidationErrorResponse
  35. from .fastapi_access_logger import FastAPIAccessLogger
  36. from .resource import APIVersion
  37. from .resource import clean_resources
  38. from .resource import Resource
  39. logger = logging.getLogger(__name__)
  40. __all__ = ["Service"]
  41. class OAuth2WithClientDependable(OAuth2AuthorizationCodeBearer):
  42. """A fastapi 'dependable' configuring OAuth2.
  43. This does two things:
  44. - Verify the token in each request
  45. - (through FastAPI magic) add the scheme to the OpenAPI spec
  46. """
  47. def __init__(
  48. self, settings: TokenVerifierSettings, client: OAuth2SPAClientSettings
  49. ):
  50. self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
  51. super().__init__(
  52. scheme_name="OAuth2 Authorization Code Flow with PKCE",
  53. authorizationUrl=str(client.authorization_url),
  54. tokenUrl=str(client.token_url),
  55. )
  56. async def __call__(self, request: Request) -> None:
  57. ctx.claims = await self.verifier(request.headers.get("Authorization"))
  58. class OAuth2WithoutClientDependable(HTTPBearer):
  59. """A fastapi 'dependable' configuring OAuth2.
  60. This does one thing:
  61. - Verify the token in each request
  62. """
  63. def __init__(self, settings: TokenVerifierSettings):
  64. self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
  65. super().__init__(scheme_name="JWT Bearer token", bearerFormat="JWT")
  66. async def __call__(self, request: Request) -> None:
  67. ctx.claims = await self.verifier(request.headers.get("Authorization"))
  68. def get_auth_kwargs(
  69. auth: Optional[TokenVerifierSettings],
  70. auth_client: Optional[OAuth2SPAClientSettings],
  71. ) -> None:
  72. if auth is None:
  73. return {}
  74. if auth_client is None:
  75. return {
  76. "dependencies": [Depends(OAuth2WithoutClientDependable(settings=auth))],
  77. }
  78. else:
  79. return {
  80. "dependencies": [
  81. Depends(OAuth2WithClientDependable(settings=auth, client=auth_client))
  82. ],
  83. "swagger_ui_init_oauth": {
  84. "clientId": auth_client.client_id,
  85. "usePkceWithAuthorizationCodeGrant": True,
  86. },
  87. }
  88. async def health_check():
  89. """Simple health check route"""
  90. return {"health": "OK"}
  91. class Service:
  92. resources: List[Resource]
  93. def __init__(self, *args: Resource):
  94. self.resources = clean_resources(args)
  95. @property
  96. def versions(self) -> Set[APIVersion]:
  97. return set([x.version for x in self.resources])
  98. def _create_root_app(
  99. self,
  100. title: str,
  101. description: str,
  102. hostname: str,
  103. on_startup: Optional[List[Callable[[], Any]]] = None,
  104. access_logger_gateway: Optional[Gateway] = None,
  105. ) -> FastAPI:
  106. app = FastAPI(
  107. title=title,
  108. description=description,
  109. on_startup=on_startup,
  110. servers=[
  111. {"url": f"{x.prefix}", "description": x.description}
  112. for x in self.versions
  113. ],
  114. root_path_in_servers=False,
  115. )
  116. app.middleware("http")(
  117. FastAPIAccessLogger(
  118. hostname=hostname, gateway_override=access_logger_gateway
  119. )
  120. )
  121. app.add_middleware(RequestMiddleware)
  122. app.get("/health", include_in_schema=False)(health_check)
  123. return app
  124. def _create_versioned_app(self, version: APIVersion, **kwargs) -> FastAPI:
  125. resources = [x for x in self.resources if x.version == version]
  126. app = FastAPI(
  127. version=version.prefix,
  128. tags=sorted(
  129. [x.get_openapi_tag().model_dump() for x in resources],
  130. key=lambda x: x["name"],
  131. ),
  132. **kwargs,
  133. )
  134. for resource in resources:
  135. app.include_router(
  136. resource.get_router(
  137. version,
  138. responses={
  139. "400": {"model": ValidationErrorResponse},
  140. "default": {"model": DefaultErrorResponse},
  141. },
  142. )
  143. )
  144. app.add_exception_handler(DoesNotExist, not_found_handler)
  145. app.add_exception_handler(Conflict, conflict_handler)
  146. app.add_exception_handler(RequestValidationError, validation_error_handler)
  147. app.add_exception_handler(BadRequest, validation_error_handler)
  148. app.add_exception_handler(NotImplementedError, not_implemented_handler)
  149. app.add_exception_handler(PermissionDenied, permission_denied_handler)
  150. app.add_exception_handler(Unauthorized, unauthorized_handler)
  151. return app
  152. def create_app(
  153. self,
  154. title: str,
  155. description: str,
  156. hostname: str,
  157. auth: Optional[TokenVerifierSettings] = None,
  158. auth_client: Optional[OAuth2SPAClientSettings] = None,
  159. on_startup: Optional[List[Callable[[], Any]]] = None,
  160. access_logger_gateway: Optional[Gateway] = None,
  161. ) -> ASGIApp:
  162. app = self._create_root_app(
  163. title=title,
  164. description=description,
  165. hostname=hostname,
  166. on_startup=on_startup,
  167. access_logger_gateway=access_logger_gateway,
  168. )
  169. kwargs = {
  170. "title": title,
  171. "description": description,
  172. **get_auth_kwargs(auth, auth_client),
  173. }
  174. versioned_apps = {
  175. v: self._create_versioned_app(v, **kwargs) for v in self.versions
  176. }
  177. for v, versioned_app in versioned_apps.items():
  178. app.mount("/" + v.prefix, versioned_app)
  179. return app