|
@@ -3,7 +3,6 @@
|
|
|
import logging
|
|
|
from typing import Any
|
|
|
from typing import Callable
|
|
|
-from typing import Dict
|
|
|
from typing import List
|
|
|
from typing import Optional
|
|
|
from typing import Set
|
|
@@ -12,11 +11,8 @@ from asgiref.sync import sync_to_async
|
|
|
from fastapi import Depends
|
|
|
from fastapi import FastAPI
|
|
|
from fastapi import Request
|
|
|
-from fastapi.exceptions import HTTPException
|
|
|
from fastapi.exceptions import RequestValidationError
|
|
|
from fastapi.security import OAuth2AuthorizationCodeBearer
|
|
|
-from starlette.status import HTTP_401_UNAUTHORIZED
|
|
|
-from starlette.status import HTTP_403_FORBIDDEN
|
|
|
from starlette.types import ASGIApp
|
|
|
|
|
|
from clean_python import Conflict
|
|
@@ -24,9 +20,11 @@ from clean_python import DoesNotExist
|
|
|
from clean_python import Gateway
|
|
|
from clean_python import PermissionDenied
|
|
|
from clean_python import Unauthorized
|
|
|
-from clean_python.oauth2 import OAuth2AccessTokenVerifier
|
|
|
-from clean_python.oauth2 import OAuth2Settings
|
|
|
+from clean_python.oauth2 import OAuth2SPAClientSettings
|
|
|
+from clean_python.oauth2 import TokenVerifier
|
|
|
+from clean_python.oauth2 import TokenVerifierSettings
|
|
|
|
|
|
+from .context import ctx
|
|
|
from .context import RequestMiddleware
|
|
|
from .error_responses import BadRequest
|
|
|
from .error_responses import conflict_handler
|
|
@@ -47,7 +45,7 @@ logger = logging.getLogger(__name__)
|
|
|
__all__ = ["Service"]
|
|
|
|
|
|
|
|
|
-class OAuth2Dependable(OAuth2AuthorizationCodeBearer):
|
|
|
+class OAuth2WithClientDependable(OAuth2AuthorizationCodeBearer):
|
|
|
"""A fastapi 'dependable' configuring OAuth2.
|
|
|
|
|
|
This does two things:
|
|
@@ -55,45 +53,53 @@ class OAuth2Dependable(OAuth2AuthorizationCodeBearer):
|
|
|
- (through FastAPI magic) add the scheme to the OpenAPI spec
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, scope, settings: OAuth2Settings):
|
|
|
- self.verifier = sync_to_async(
|
|
|
- OAuth2AccessTokenVerifier(
|
|
|
- scope,
|
|
|
- issuer=settings.issuer,
|
|
|
- resource_server_id=settings.resource_server_id,
|
|
|
- algorithms=settings.algorithms,
|
|
|
- admin_users=settings.admin_users,
|
|
|
- ),
|
|
|
- thread_sensitive=False,
|
|
|
- )
|
|
|
+ def __init__(
|
|
|
+ self, settings: TokenVerifierSettings, client: OAuth2SPAClientSettings
|
|
|
+ ):
|
|
|
+ self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
|
|
|
super().__init__(
|
|
|
- authorizationUrl=settings.authorization_url,
|
|
|
- tokenUrl=settings.token_url,
|
|
|
- scopes={
|
|
|
- f"{settings.resource_server_id}*:readwrite": "Full read/write access"
|
|
|
- },
|
|
|
+ authorizationUrl=str(client.authorization_url),
|
|
|
+ tokenUrl=str(client.token_url),
|
|
|
)
|
|
|
|
|
|
async def __call__(self, request: Request) -> None:
|
|
|
- token = await super().__call__(request)
|
|
|
- try:
|
|
|
- await self.verifier(token)
|
|
|
- except Unauthorized:
|
|
|
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED)
|
|
|
- except PermissionDenied:
|
|
|
- raise HTTPException(status_code=HTTP_403_FORBIDDEN)
|
|
|
+ ctx.claims = await self.verifier(request.headers.get("Authorization"))
|
|
|
+
|
|
|
+
|
|
|
+class OAuth2WithoutClientDependable:
|
|
|
+ """A fastapi 'dependable' configuring OAuth2.
|
|
|
+
|
|
|
+ This does one thing:
|
|
|
+ - Verify the token in each request
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, settings: TokenVerifierSettings):
|
|
|
+ self.verifier = sync_to_async(TokenVerifier(settings), thread_sensitive=False)
|
|
|
+
|
|
|
+ async def __call__(self, request: Request) -> None:
|
|
|
+ ctx.claims = await self.verifier(request.headers.get("Authorization"))
|
|
|
|
|
|
|
|
|
-def fastapi_oauth_kwargs(auth: Optional[OAuth2Settings]) -> Dict:
|
|
|
+def get_auth_kwargs(
|
|
|
+ auth: Optional[TokenVerifierSettings],
|
|
|
+ auth_client: Optional[OAuth2SPAClientSettings],
|
|
|
+) -> None:
|
|
|
if auth is None:
|
|
|
return {}
|
|
|
- return {
|
|
|
- "dependencies": [Depends(OAuth2Dependable(scope="*:readwrite", settings=auth))],
|
|
|
- "swagger_ui_init_oauth": {
|
|
|
- "clientId": auth.client_id,
|
|
|
- "usePkceWithAuthorizationCodeGrant": True,
|
|
|
- },
|
|
|
- }
|
|
|
+ if auth_client is None:
|
|
|
+ return {
|
|
|
+ "dependencies": [Depends(OAuth2WithoutClientDependable(settings=auth))],
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ return {
|
|
|
+ "dependencies": [
|
|
|
+ Depends(OAuth2WithClientDependable(settings=auth, client=auth_client))
|
|
|
+ ],
|
|
|
+ "swagger_ui_init_oauth": {
|
|
|
+ "clientId": auth_client.client_id,
|
|
|
+ "usePkceWithAuthorizationCodeGrant": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
|
|
|
|
|
|
async def health_check():
|
|
@@ -143,7 +149,8 @@ class Service:
|
|
|
app = FastAPI(
|
|
|
version=version.prefix,
|
|
|
tags=sorted(
|
|
|
- [x.get_openapi_tag().dict() for x in resources], key=lambda x: x["name"]
|
|
|
+ [x.get_openapi_tag().model_dump() for x in resources],
|
|
|
+ key=lambda x: x["name"],
|
|
|
),
|
|
|
**kwargs,
|
|
|
)
|
|
@@ -171,7 +178,8 @@ class Service:
|
|
|
title: str,
|
|
|
description: str,
|
|
|
hostname: str,
|
|
|
- auth: Optional[OAuth2Settings] = None,
|
|
|
+ auth: Optional[TokenVerifierSettings] = None,
|
|
|
+ auth_client: Optional[OAuth2SPAClientSettings] = None,
|
|
|
on_startup: Optional[List[Callable[[], Any]]] = None,
|
|
|
access_logger_gateway: Optional[Gateway] = None,
|
|
|
) -> ASGIApp:
|
|
@@ -185,7 +193,7 @@ class Service:
|
|
|
kwargs = {
|
|
|
"title": title,
|
|
|
"description": description,
|
|
|
- **fastapi_oauth_kwargs(auth),
|
|
|
+ **get_auth_kwargs(auth, auth_client),
|
|
|
}
|
|
|
versioned_apps = {
|
|
|
v: self._create_versioned_app(v, **kwargs) for v in self.versions
|