service.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import logging
  2. from typing import Any, Callable, Dict, List, Optional, Set
  3. from asgiref.sync import sync_to_async
  4. from fastapi import Depends, FastAPI, Request
  5. from fastapi.exceptions import HTTPException, RequestValidationError
  6. from fastapi.security import OAuth2AuthorizationCodeBearer
  7. from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
  8. from starlette.types import ASGIApp
  9. from .context import RequestMiddleware
  10. from .error_responses import (
  11. BadRequest,
  12. conflict_handler,
  13. DefaultErrorResponse,
  14. not_found_handler,
  15. not_implemented_handler,
  16. permission_denied_handler,
  17. unauthorized_handler,
  18. validation_error_handler,
  19. ValidationErrorResponse,
  20. )
  21. from clean_python.base.domain.exceptions import Conflict, DoesNotExist, PermissionDenied, Unauthorized
  22. from .fastapi_access_logger import FastAPIAccessLogger
  23. from clean_python.base.infrastructure.gateway import Gateway
  24. from clean_python.oauth2.oauth2 import OAuth2AccessTokenVerifier, OAuth2Settings
  25. from .resource import APIVersion, clean_resources, Resource
  # Module-level logger, registered under this module's dotted import name.
  logger = logging.getLogger(name=__name__)
  class OAuth2Dependable(OAuth2AuthorizationCodeBearer):
      """A FastAPI 'dependable' configuring OAuth2.

      This does two things:

      - verify the bearer token of each request, and
      - (through FastAPI's security-scheme machinery) add the OAuth2
        authorization-code scheme to the OpenAPI spec.

      Args:
          scope: The scope passed on to ``OAuth2AccessTokenVerifier``.
          settings: OAuth2 configuration: issuer, resource server id,
              algorithms, admin users, and authorization/token URLs.
      """

      def __init__(self, scope: str, settings: OAuth2Settings):
          # The verifier is a synchronous callable; wrapping it with
          # sync_to_async lets __call__ await it without running it on the
          # event loop. thread_sensitive=False means it is not pinned to the
          # single shared "sync" thread.
          self.verifier = sync_to_async(
              OAuth2AccessTokenVerifier(
                  scope,
                  issuer=settings.issuer,
                  resource_server_id=settings.resource_server_id,
                  algorithms=settings.algorithms,
                  admin_users=settings.admin_users,
              ),
              thread_sensitive=False,
          )
          super().__init__(
              authorizationUrl=settings.authorization_url,
              tokenUrl=settings.token_url,
              # Scopes advertised in the OpenAPI security scheme.
              scopes={
                  f"{settings.resource_server_id}*:readwrite": "Full read/write access"
              },
          )

      async def __call__(self, request: Request) -> None:
          """Verify the bearer token of ``request``.

          The parent class extracts the token from the ``Authorization``
          header (rejecting requests that lack one, as it is constructed with
          FastAPI's default ``auto_error=True``); the token is then passed to
          the verifier.

          Raises:
              HTTPException: 401 if the verifier raises ``Unauthorized``;
                  403 if it raises ``PermissionDenied``.
          """
          token = await super().__call__(request)
          try:
              await self.verifier(token)
          except Unauthorized as e:
              # Chain the cause explicitly, and advertise the Bearer scheme
              # like FastAPI's own 401 for a missing token (RFC 6750 §3).
              raise HTTPException(
                  status_code=HTTP_401_UNAUTHORIZED,
                  headers={"WWW-Authenticate": "Bearer"},
              ) from e
          except PermissionDenied as e:
              raise HTTPException(status_code=HTTP_403_FORBIDDEN) from e
  def fastapi_oauth_kwargs(auth: Optional[OAuth2Settings]) -> Dict:
      """Return the extra ``FastAPI(...)`` keyword arguments that enable OAuth2.

      Args:
          auth: OAuth2 settings, or ``None`` when authentication is disabled.

      Returns:
          ``{}`` when ``auth`` is ``None``. Otherwise, a dict with:

          - ``dependencies``: an app-wide ``OAuth2Dependable`` for the
            ``"*:readwrite"`` scope.
          - ``swagger_ui_init_oauth``: configures Swagger UI to use
            ``auth.client_id`` with PKCE.
      """
      if auth is not None:
          dependencies = [Depends(OAuth2Dependable(scope="*:readwrite", settings=auth))]
          swagger_ui_config = {
              "clientId": auth.client_id,
              "usePkceWithAuthorizationCodeGrant": True,
          }
          return {
              "dependencies": dependencies,
              "swagger_ui_init_oauth": swagger_ui_config,
          }
      return {}
  async def health_check():
      """Report that the service is up.

      Returns:
          The constant payload ``{"health": "OK"}``.
      """
      status = "OK"
      return dict(health=status)
  class Service:
      """A set of versioned resources served as a single ASGI application.

      Resources are grouped by ``APIVersion``. Each version becomes its own
      FastAPI sub-application, mounted at ``/<version prefix>`` on a root
      application that also serves ``/health``.
      """

      resources: List[Resource]

      def __init__(self, *args: Resource):
          self.resources = clean_resources(args)

      @property
      def versions(self) -> Set[APIVersion]:
          """The distinct API versions of the registered resources."""
          return {x.version for x in self.resources}

      def _ordered_versions(self) -> List[APIVersion]:
          """Distinct API versions, in order of first appearance in ``resources``.

          Iterating the ``versions`` set gives an arbitrary order, which may
          differ between processes. This list is stable, so the OpenAPI
          ``servers`` list and the mount order are reproducible.
          """
          return list(dict.fromkeys(x.version for x in self.resources))

      def _create_root_app(
          self,
          title: str,
          description: str,
          hostname: str,
          on_startup: Optional[List[Callable[[], Any]]] = None,
          access_logger_gateway: Optional[Gateway] = None,
      ) -> FastAPI:
          """Create the top-level app: access logging, request context and ``/health``.

          The versioned sub-apps are mounted onto it by ``create_app``.
          """
          app = FastAPI(
              title=title,
              description=description,
              on_startup=on_startup,
              servers=[
                  {"url": f"{x.prefix}", "description": x.description}
                  for x in self._ordered_versions()
              ],
              root_path_in_servers=False,
          )
          app.middleware("http")(
              FastAPIAccessLogger(
                  hostname=hostname, gateway_override=access_logger_gateway
              )
          )
          app.add_middleware(RequestMiddleware)
          app.get("/health", include_in_schema=False)(health_check)
          return app

      def _create_versioned_app(self, version: APIVersion, **kwargs) -> FastAPI:
          """Create the sub-app that serves every resource of ``version``.

          Args:
              version: The API version whose resources are included.
              **kwargs: Extra ``FastAPI(...)`` arguments (e.g. title,
                  description, OAuth2 dependencies).
          """
          resources = [x for x in self.resources if x.version == version]
          app = FastAPI(
              version=version.prefix,
              tags=sorted(
                  (x.get_openapi_tag().dict() for x in resources),
                  key=lambda tag: tag["name"],
              ),
              **kwargs,
          )
          for resource in resources:
              app.include_router(
                  resource.get_router(
                      version,
                      # A fresh dict per router, so routers never share it.
                      responses={
                          "400": {"model": ValidationErrorResponse},
                          "default": {"model": DefaultErrorResponse},
                      },
                  )
              )
          # Translate domain exceptions raised in route handlers into HTTP
          # error responses.
          app.add_exception_handler(DoesNotExist, not_found_handler)
          app.add_exception_handler(Conflict, conflict_handler)
          app.add_exception_handler(RequestValidationError, validation_error_handler)
          app.add_exception_handler(BadRequest, validation_error_handler)
          app.add_exception_handler(NotImplementedError, not_implemented_handler)
          app.add_exception_handler(PermissionDenied, permission_denied_handler)
          app.add_exception_handler(Unauthorized, unauthorized_handler)
          return app

      def create_app(
          self,
          title: str,
          description: str,
          hostname: str,
          auth: Optional[OAuth2Settings] = None,
          on_startup: Optional[List[Callable[[], Any]]] = None,
          access_logger_gateway: Optional[Gateway] = None,
      ) -> ASGIApp:
          """Build the complete ASGI application.

          Args:
              title: Title given to the root app and to every versioned app.
              description: Description given to every app.
              hostname: Passed to the access logger.
              auth: OAuth2 settings applied to the versioned apps; ``None``
                  disables authentication.
              on_startup: Startup callables for the root app.
              access_logger_gateway: Optional gateway override for the
                  access logger.

          Returns:
              The root app, with one sub-app per API version mounted at
              ``/<version prefix>``.
          """
          app = self._create_root_app(
              title=title,
              description=description,
              hostname=hostname,
              on_startup=on_startup,
              access_logger_gateway=access_logger_gateway,
          )
          kwargs = {
              "title": title,
              "description": description,
              **fastapi_oauth_kwargs(auth),
          }
          for version in self._ordered_versions():
              app.mount("/" + version.prefix, self._create_versioned_app(version, **kwargs))
          return app