from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Type

from fastapi.routing import APIRouter

from clean_python.base.domain.value_object import ValueObject

__all__ = [
    "Resource",
    "get",
    "post",
    "put",
    "patch",
    "delete",
    "APIVersion",
    "Stability",
    "v",
    "clean_resources",
]


class Stability(str, Enum):
    """Stability level of an API version."""

    STABLE = "stable"
    BETA = "beta"
    ALPHA = "alpha"

    @property
    def description(self) -> str:
        """Human-readable description of this stability level."""
        return DESCRIPTIONS[self]

    def decrease(self) -> "Stability":
        """Return the next less stable level (STABLE -> BETA -> ALPHA).

        Raises:
            ValueError: if this is already the least stable level (ALPHA).
        """
        index = STABILITY_ORDER.index(self)
        if index == 0:
            raise ValueError(f"Cannot decrease stability of {self}")
        return STABILITY_ORDER[index - 1]


# Ordered from least to most stable: Stability.decrease() steps towards index 0.
STABILITY_ORDER = [Stability.ALPHA, Stability.BETA, Stability.STABLE]

DESCRIPTIONS = {
    Stability.STABLE: "The stable API version.",
    Stability.BETA: "Backwards incompatible changes will be announced beforehand.",
    Stability.ALPHA: "May get backwards incompatible changes without warning.",
}


class APIVersion(ValueObject):
    """An API version number combined with a stability level."""

    version: int
    stability: Stability

    @property
    def prefix(self) -> str:
        """Version prefix: "v<N>" when stable, otherwise "v<N>-<stability>"."""
        result = f"v{self.version}"
        if self.stability is not Stability.STABLE:
            result += f"-{self.stability.value}"
        return result

    @property
    def description(self) -> str:
        """Description of this version's stability level."""
        return self.stability.description

    def decrease_stability(self) -> "APIVersion":
        """Return the same version number at the next less stable level.

        Raises:
            ValueError: if the stability is already ALPHA.
        """
        return APIVersion(version=self.version, stability=self.stability.decrease())


def http_method(path: str, **route_options):
    """Build a decorator that marks a function as an HTTP endpoint.

    The decorated function is returned unchanged except for a new
    ``http_method`` attribute holding ``(path, route_options)``.
    ``Resource.get_router`` later passes these to ``APIRouter.add_api_route``.
    """

    def wrapper(unbound_method: Callable[..., Any]):
        setattr(unbound_method, "http_method", (path, route_options))
        return unbound_method

    return wrapper


def v(version: int, stability: str = "stable") -> APIVersion:
    """Shorthand constructor, e.g. ``v(1)`` or ``v(2, "beta")``.

    Raises:
        ValueError: if ``stability`` is not a valid ``Stability`` value.
    """
    return APIVersion(version=version, stability=Stability(stability))


# Endpoint decorators: each fixes the ``methods`` route option of http_method.
get = partial(http_method, methods=["GET"])
post = partial(http_method, methods=["POST"])
put = partial(http_method, methods=["PUT"])
patch = partial(http_method, methods=["PATCH"])
delete = partial(http_method, methods=["DELETE"])


class OpenApiTag(ValueObject):
    """Name and optional description of an OpenAPI tag."""

    name: str
    description: Optional[str]


class Resource:
    """Base class for a named, versioned group of API endpoints.

    Subclasses must pass the version as a class keyword argument and may pass
    a name, e.g. ``class Books(Resource, version=v(1), name="books")``.
    Public methods decorated with ``get``/``post``/``put``/``patch``/``delete``
    become routes of the router returned by ``get_router``. The subclass
    docstring is used as the OpenAPI tag description.
    """

    version: APIVersion
    name: str

    def __init_subclass__(cls, version: APIVersion, name: str = ""):
        cls.version = version
        cls.name = name
        super().__init_subclass__()

    @classmethod
    def with_version(cls, version: APIVersion) -> Type["Resource"]:
        """Create a direct subclass of ``cls`` with another version.

        The new class keeps the same name and docstring as ``cls``.
        """

        class DynamicResource(cls, version=version, name=cls.name):  # type: ignore
            pass

        DynamicResource.__doc__ = cls.__doc__
        return DynamicResource

    def get_less_stable(self, resources: Dict[APIVersion, "Resource"]) -> "Resource":
        """Fetch a less stable version of this resource from 'resources'.

        If it doesn't exist, create it dynamically via ``with_version``.

        Raises:
            ValueError: if this resource's stability is already ALPHA.
            RuntimeError: if the less stable resource found in ``resources``
                is not a direct subclass of this resource's class.
        """
        less_stable_version = self.version.decrease_stability()

        # Fetch the less stable resource; generate it if it does not exist
        try:
            less_stable_resource = resources[less_stable_version]
        except KeyError:
            less_stable_resource = self.__class__.with_version(less_stable_version)()

        # Validate the less stable version: it must extend exactly this class so
        # that it inherits (and may override) this resource's endpoints.
        if less_stable_resource.__class__.__bases__ != (self.__class__,):
            raise RuntimeError(
                f"{less_stable_resource} should be a direct subclass of {self}"
            )

        return less_stable_resource

    def _endpoints(self):
        """Yield public attributes (in ``dir()`` order) marked by ``http_method``."""
        for attr_name in dir(self):
            if attr_name.startswith("_"):
                continue
            endpoint = getattr(self, attr_name)
            if not hasattr(endpoint, "http_method"):
                continue
            yield endpoint

    def get_openapi_tag(self) -> OpenApiTag:
        """Return the OpenAPI tag for this resource (name + class docstring)."""
        return OpenApiTag(
            name=self.name,
            description=self.__class__.__doc__,
        )

    def get_router(
        self, version: APIVersion, responses: Optional[Dict[str, Dict[str, Any]]] = None
    ) -> APIRouter:
        """Build an APIRouter with one route per decorated endpoint.

        Args:
            version: must equal this resource's version (checked with assert).
            responses: passed unchanged as ``responses`` to every route.

        Raises:
            RuntimeError: if two endpoints have the same function name, which
                would give them the same operation id.
        """
        assert version == self.version
        router = APIRouter()
        operation_ids = set()
        for endpoint in self._endpoints():
            path, route_options = endpoint.http_method
            operation_id = endpoint.__name__
            if operation_id in operation_ids:
                raise RuntimeError(
                    f"Multiple operations {operation_id} configured in {self}"
                )
            operation_ids.add(operation_id)
            # The 'name' is used for reverse lookups (request.path_for): include the
            # version prefix so that we can uniquely refer to an operation.
            name = version.prefix + "/" + operation_id
            router.add_api_route(
                path,
                endpoint,
                tags=[self.name],
                operation_id=operation_id,
                name=name,
                responses=responses,
                **route_options,
            )
        return router


def clean_resources_same_name(resources: List[Resource]) -> List[Resource]:
    """Complete the versions of resources that all share one name.

    Every STABLE resource gets a BETA counterpart and every BETA resource
    (including ones created here) gets an ALPHA counterpart. The counterparts
    are taken from ``resources`` when present, otherwise generated.

    Raises:
        RuntimeError: if two resources have the same version, or an existing
            less stable resource does not directly subclass the more stable one.
    """
    by_version = {resource.version: resource for resource in resources}
    if len(by_version) != len(resources):
        raise RuntimeError(
            f"Resource with name {resources[0].name} "
            f"is defined multiple times with the same version."
        )

    # STABLE first, so that BETA resources generated from it are also processed
    # in the BETA pass (each pass iterates over a snapshot of the dict).
    for stability in [Stability.STABLE, Stability.BETA]:
        tmp_resources = {
            version: resource
            for (version, resource) in by_version.items()
            if version.stability is stability
        }
        for version, resource in tmp_resources.items():
            by_version[version.decrease_stability()] = resource.get_less_stable(
                by_version
            )

    return list(by_version.values())


def clean_resources(resources: Sequence[Resource]) -> List[Resource]:
    """Ensure that resources are consistent:

    - ordered by name
    - (name, version) combinations must be unique (RuntimeError otherwise)
    - for stable resources, beta & alpha are autocreated if needed
    - for beta resources, alpha is autocreated if needed
    """
    result = []
    names = {x.name for x in resources}
    for name in sorted(names):
        result.extend(
            clean_resources_same_name([x for x in resources if x.name == name])
        )
    return result