123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- # (c) Nelen & Schuurmans
- from enum import Enum
- from functools import partial
- from typing import Any
- from typing import Callable
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Sequence
- from typing import Type
- from fastapi import Depends
- from fastapi.routing import APIRouter
- from clean_python import ValueObject
- from .security import RequiresScope
- __all__ = [
- "Resource",
- "get",
- "post",
- "put",
- "patch",
- "delete",
- "APIVersion",
- "Stability",
- "v",
- "clean_resources",
- ]
- class Stability(str, Enum):
- STABLE = "stable"
- BETA = "beta"
- ALPHA = "alpha"
- @property
- def description(self) -> str:
- return DESCRIPTIONS[self]
- def decrease(self) -> "Stability":
- index = STABILITY_ORDER.index(self)
- if index == 0:
- raise ValueError(f"Cannot decrease stability of {self}")
- return STABILITY_ORDER[index - 1]
- STABILITY_ORDER = [Stability.ALPHA, Stability.BETA, Stability.STABLE]
- DESCRIPTIONS = {
- Stability.STABLE: "The stable API version.",
- Stability.BETA: "Backwards incompatible changes will be announced beforehand.",
- Stability.ALPHA: "May get backwards incompatible changes without warning.",
- }
- class APIVersion(ValueObject):
- version: int
- stability: Stability
- @property
- def prefix(self) -> str:
- result = f"v{self.version}"
- if self.stability is not Stability.STABLE:
- result += f"-{self.stability.value}"
- return result
- @property
- def description(self) -> str:
- return self.stability.description
- def decrease_stability(self) -> "APIVersion":
- return APIVersion(version=self.version, stability=self.stability.decrease())
- def http_method(path: str, scope: Optional[str] = None, **route_options):
- def wrapper(unbound_method: Callable[..., Any]):
- setattr(
- unbound_method,
- "http_method",
- (path, scope, route_options),
- )
- return unbound_method
- return wrapper
- def v(version: int, stability: str = "stable") -> APIVersion:
- return APIVersion(version=version, stability=Stability(stability))
- get = partial(http_method, methods=["GET"])
- post = partial(http_method, methods=["POST"])
- put = partial(http_method, methods=["PUT"])
- patch = partial(http_method, methods=["PATCH"])
- delete = partial(http_method, methods=["DELETE"])
- class OpenApiTag(ValueObject):
- name: str
- description: Optional[str]
- class Resource:
- version: APIVersion
- name: str
- def __init_subclass__(cls, version: APIVersion, name: str = ""):
- cls.version = version
- cls.name = name
- super().__init_subclass__()
- @classmethod
- def with_version(cls, version: APIVersion) -> Type["Resource"]:
- class DynamicResource(cls, version=version, name=cls.name): # type: ignore
- pass
- DynamicResource.__doc__ = cls.__doc__
- return DynamicResource
- def get_less_stable(self, resources: Dict[APIVersion, "Resource"]) -> "Resource":
- """Fetch a less stable version of this resource from 'resources'
- If it doesn't exist, create it dynamically.
- """
- less_stable_version = self.version.decrease_stability()
- # Fetch the less stable resource; generate it if it does not exist
- try:
- less_stable_resource = resources[less_stable_version]
- except KeyError:
- less_stable_resource = self.__class__.with_version(less_stable_version)()
- # Validate the less stable version
- if less_stable_resource.__class__.__bases__ != (self.__class__,):
- raise RuntimeError(
- f"{less_stable_resource} should be a direct subclass of {self}"
- )
- return less_stable_resource
- def _endpoints(self):
- for attr_name in dir(self):
- if attr_name.startswith("_"):
- continue
- endpoint = getattr(self, attr_name)
- if not hasattr(endpoint, "http_method"):
- continue
- yield endpoint
- def get_openapi_tag(self) -> OpenApiTag:
- return OpenApiTag(
- name=self.name,
- description=self.__class__.__doc__,
- )
- def get_router(
- self, version: APIVersion, responses: Optional[Dict[str, Dict[str, Any]]] = None
- ) -> APIRouter:
- assert version == self.version
- router = APIRouter()
- operation_ids = set()
- for endpoint in self._endpoints():
- path, scope, route_options = endpoint.http_method
- operation_id = endpoint.__name__
- if operation_id in operation_ids:
- raise RuntimeError(
- "Multiple operations {operation_id} configured in {self}"
- )
- operation_ids.add(operation_id)
- # The 'name' is used for reverse lookups (request.path_for): include the
- # version prefix so that we can uniquely refer to an operation.
- name = version.prefix + "/" + endpoint.__name__
- # 'scope' is implemented using FastAPI's dependency injection system
- if scope is not None:
- route_options.setdefault("dependencies", [])
- route_options["dependencies"].append(Depends(RequiresScope(scope)))
- # Update responses with route_options responses or use latter if not set
- if "responses" in route_options:
- responses = {**(responses or {}), **route_options.pop("responses")}
- router.add_api_route(
- path,
- endpoint,
- tags=[self.name],
- operation_id=endpoint.__name__,
- name=name,
- responses=responses,
- **route_options,
- )
- return router
- def clean_resources_same_name(resources: List[Resource]) -> List[Resource]:
- dct = {x.version: x for x in resources}
- if len(dct) != len(resources):
- raise RuntimeError(
- f"Resource with name {resources[0].name} "
- f"is defined multiple times with the same version."
- )
- for stability in [Stability.STABLE, Stability.BETA]:
- tmp_resources = {k: v for (k, v) in dct.items() if k.stability is stability}
- for version, resource in tmp_resources.items():
- dct[version.decrease_stability()] = resource.get_less_stable(dct)
- return list(dct.values())
- def clean_resources(resources: Sequence[Resource]) -> List[Resource]:
- """Ensure that resources are consistent:
- - ordered by name
- - (tag, version) combinations should be unique
- - for stable resources, beta & alpha are autocreated if needed
- - for beta resources, alpha is autocreated if needed
- """
- result = []
- names = {x.name for x in resources}
- for name in sorted(names):
- result.extend(
- clean_resources_same_name([x for x in resources if x.name == name])
- )
- return result
|