resource.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # (c) Nelen & Schuurmans
  2. from enum import Enum
  3. from functools import partial
  4. from typing import Any
  5. from typing import Callable
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from typing import Sequence
  10. from typing import Type
  11. from fastapi import Depends
  12. from fastapi.routing import APIRouter
  13. from clean_python import ValueObject
  14. from .security import RequiresScope
  15. __all__ = [
  16. "Resource",
  17. "get",
  18. "post",
  19. "put",
  20. "patch",
  21. "delete",
  22. "APIVersion",
  23. "Stability",
  24. "v",
  25. "clean_resources",
  26. ]
  27. class Stability(str, Enum):
  28. STABLE = "stable"
  29. BETA = "beta"
  30. ALPHA = "alpha"
  31. @property
  32. def description(self) -> str:
  33. return DESCRIPTIONS[self]
  34. def decrease(self) -> "Stability":
  35. index = STABILITY_ORDER.index(self)
  36. if index == 0:
  37. raise ValueError(f"Cannot decrease stability of {self}")
  38. return STABILITY_ORDER[index - 1]
  39. STABILITY_ORDER = [Stability.ALPHA, Stability.BETA, Stability.STABLE]
  40. DESCRIPTIONS = {
  41. Stability.STABLE: "The stable API version.",
  42. Stability.BETA: "Backwards incompatible changes will be announced beforehand.",
  43. Stability.ALPHA: "May get backwards incompatible changes without warning.",
  44. }
  45. class APIVersion(ValueObject):
  46. version: int
  47. stability: Stability
  48. @property
  49. def prefix(self) -> str:
  50. result = f"v{self.version}"
  51. if self.stability is not Stability.STABLE:
  52. result += f"-{self.stability.value}"
  53. return result
  54. @property
  55. def description(self) -> str:
  56. return self.stability.description
  57. def decrease_stability(self) -> "APIVersion":
  58. return APIVersion(version=self.version, stability=self.stability.decrease())
  59. def http_method(path: str, scope: Optional[str] = None, **route_options):
  60. def wrapper(unbound_method: Callable[..., Any]):
  61. setattr(
  62. unbound_method,
  63. "http_method",
  64. (path, scope, route_options),
  65. )
  66. return unbound_method
  67. return wrapper
  68. def v(version: int, stability: str = "stable") -> APIVersion:
  69. return APIVersion(version=version, stability=Stability(stability))
  70. get = partial(http_method, methods=["GET"])
  71. post = partial(http_method, methods=["POST"])
  72. put = partial(http_method, methods=["PUT"])
  73. patch = partial(http_method, methods=["PATCH"])
  74. delete = partial(http_method, methods=["DELETE"])
  75. class OpenApiTag(ValueObject):
  76. name: str
  77. description: Optional[str]
  78. class Resource:
  79. version: APIVersion
  80. name: str
  81. def __init_subclass__(cls, version: APIVersion, name: str = ""):
  82. cls.version = version
  83. cls.name = name
  84. super().__init_subclass__()
  85. @classmethod
  86. def with_version(cls, version: APIVersion) -> Type["Resource"]:
  87. class DynamicResource(cls, version=version, name=cls.name): # type: ignore
  88. pass
  89. DynamicResource.__doc__ = cls.__doc__
  90. return DynamicResource
  91. def get_less_stable(self, resources: Dict[APIVersion, "Resource"]) -> "Resource":
  92. """Fetch a less stable version of this resource from 'resources'
  93. If it doesn't exist, create it dynamically.
  94. """
  95. less_stable_version = self.version.decrease_stability()
  96. # Fetch the less stable resource; generate it if it does not exist
  97. try:
  98. less_stable_resource = resources[less_stable_version]
  99. except KeyError:
  100. less_stable_resource = self.__class__.with_version(less_stable_version)()
  101. # Validate the less stable version
  102. if less_stable_resource.__class__.__bases__ != (self.__class__,):
  103. raise RuntimeError(
  104. f"{less_stable_resource} should be a direct subclass of {self}"
  105. )
  106. return less_stable_resource
  107. def _endpoints(self):
  108. for attr_name in dir(self):
  109. if attr_name.startswith("_"):
  110. continue
  111. endpoint = getattr(self, attr_name)
  112. if not hasattr(endpoint, "http_method"):
  113. continue
  114. yield endpoint
  115. def get_openapi_tag(self) -> OpenApiTag:
  116. return OpenApiTag(
  117. name=self.name,
  118. description=self.__class__.__doc__,
  119. )
  120. def get_router(
  121. self, version: APIVersion, responses: Optional[Dict[str, Dict[str, Any]]] = None
  122. ) -> APIRouter:
  123. assert version == self.version
  124. router = APIRouter()
  125. operation_ids = set()
  126. for endpoint in self._endpoints():
  127. path, scope, route_options = endpoint.http_method
  128. operation_id = endpoint.__name__
  129. if operation_id in operation_ids:
  130. raise RuntimeError(
  131. "Multiple operations {operation_id} configured in {self}"
  132. )
  133. operation_ids.add(operation_id)
  134. # The 'name' is used for reverse lookups (request.path_for): include the
  135. # version prefix so that we can uniquely refer to an operation.
  136. name = version.prefix + "/" + endpoint.__name__
  137. # 'scope' is implemented using FastAPI's dependency injection system
  138. if scope is not None:
  139. route_options.setdefault("dependencies", [])
  140. route_options["dependencies"].append(Depends(RequiresScope(scope)))
  141. # Update responses with route_options responses or use latter if not set
  142. if "responses" in route_options:
  143. responses = {**(responses or {}), **route_options.pop("responses")}
  144. router.add_api_route(
  145. path,
  146. endpoint,
  147. tags=[self.name],
  148. operation_id=endpoint.__name__,
  149. name=name,
  150. responses=responses,
  151. **route_options,
  152. )
  153. return router
  154. def clean_resources_same_name(resources: List[Resource]) -> List[Resource]:
  155. dct = {x.version: x for x in resources}
  156. if len(dct) != len(resources):
  157. raise RuntimeError(
  158. f"Resource with name {resources[0].name} "
  159. f"is defined multiple times with the same version."
  160. )
  161. for stability in [Stability.STABLE, Stability.BETA]:
  162. tmp_resources = {k: v for (k, v) in dct.items() if k.stability is stability}
  163. for version, resource in tmp_resources.items():
  164. dct[version.decrease_stability()] = resource.get_less_stable(dct)
  165. return list(dct.values())
  166. def clean_resources(resources: Sequence[Resource]) -> List[Resource]:
  167. """Ensure that resources are consistent:
  168. - ordered by name
  169. - (tag, version) combinations should be unique
  170. - for stable resources, beta & alpha are autocreated if needed
  171. - for beta resources, alpha is autocreated if needed
  172. """
  173. result = []
  174. names = {x.name for x in resources}
  175. for name in sorted(names):
  176. result.extend(
  177. clean_resources_same_name([x for x in resources if x.name == name])
  178. )
  179. return result