resource.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # (c) Nelen & Schuurmans
  2. from enum import Enum
  3. from functools import partial
  4. from typing import Any
  5. from typing import Callable
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from typing import Sequence
  10. from typing import Type
  11. from fastapi.routing import APIRouter
  12. from clean_python import ValueObject
  13. __all__ = [
  14. "Resource",
  15. "get",
  16. "post",
  17. "put",
  18. "patch",
  19. "delete",
  20. "APIVersion",
  21. "Stability",
  22. "v",
  23. "clean_resources",
  24. ]
  25. class Stability(str, Enum):
  26. STABLE = "stable"
  27. BETA = "beta"
  28. ALPHA = "alpha"
  29. @property
  30. def description(self) -> str:
  31. return DESCRIPTIONS[self]
  32. def decrease(self) -> "Stability":
  33. index = STABILITY_ORDER.index(self)
  34. if index == 0:
  35. raise ValueError(f"Cannot decrease stability of {self}")
  36. return STABILITY_ORDER[index - 1]
  37. STABILITY_ORDER = [Stability.ALPHA, Stability.BETA, Stability.STABLE]
  38. DESCRIPTIONS = {
  39. Stability.STABLE: "The stable API version.",
  40. Stability.BETA: "Backwards incompatible changes will be announced beforehand.",
  41. Stability.ALPHA: "May get backwards incompatible changes without warning.",
  42. }
  43. class APIVersion(ValueObject):
  44. version: int
  45. stability: Stability
  46. @property
  47. def prefix(self) -> str:
  48. result = f"v{self.version}"
  49. if self.stability is not Stability.STABLE:
  50. result += f"-{self.stability.value}"
  51. return result
  52. @property
  53. def description(self) -> str:
  54. return self.stability.description
  55. def decrease_stability(self) -> "APIVersion":
  56. return APIVersion(version=self.version, stability=self.stability.decrease())
  57. def http_method(path: str, **route_options):
  58. def wrapper(unbound_method: Callable[..., Any]):
  59. setattr(
  60. unbound_method,
  61. "http_method",
  62. (path, route_options),
  63. )
  64. return unbound_method
  65. return wrapper
  66. def v(version: int, stability: str = "stable") -> APIVersion:
  67. return APIVersion(version=version, stability=Stability(stability))
  68. get = partial(http_method, methods=["GET"])
  69. post = partial(http_method, methods=["POST"])
  70. put = partial(http_method, methods=["PUT"])
  71. patch = partial(http_method, methods=["PATCH"])
  72. delete = partial(http_method, methods=["DELETE"])
  73. class OpenApiTag(ValueObject):
  74. name: str
  75. description: Optional[str]
  76. class Resource:
  77. version: APIVersion
  78. name: str
  79. def __init_subclass__(cls, version: APIVersion, name: str = ""):
  80. cls.version = version
  81. cls.name = name
  82. super().__init_subclass__()
  83. @classmethod
  84. def with_version(cls, version: APIVersion) -> Type["Resource"]:
  85. class DynamicResource(cls, version=version, name=cls.name): # type: ignore
  86. pass
  87. DynamicResource.__doc__ = cls.__doc__
  88. return DynamicResource
  89. def get_less_stable(self, resources: Dict[APIVersion, "Resource"]) -> "Resource":
  90. """Fetch a less stable version of this resource from 'resources'
  91. If it doesn't exist, create it dynamically.
  92. """
  93. less_stable_version = self.version.decrease_stability()
  94. # Fetch the less stable resource; generate it if it does not exist
  95. try:
  96. less_stable_resource = resources[less_stable_version]
  97. except KeyError:
  98. less_stable_resource = self.__class__.with_version(less_stable_version)()
  99. # Validate the less stable version
  100. if less_stable_resource.__class__.__bases__ != (self.__class__,):
  101. raise RuntimeError(
  102. f"{less_stable_resource} should be a direct subclass of {self}"
  103. )
  104. return less_stable_resource
  105. def _endpoints(self):
  106. for attr_name in dir(self):
  107. if attr_name.startswith("_"):
  108. continue
  109. endpoint = getattr(self, attr_name)
  110. if not hasattr(endpoint, "http_method"):
  111. continue
  112. yield endpoint
  113. def get_openapi_tag(self) -> OpenApiTag:
  114. return OpenApiTag(
  115. name=self.name,
  116. description=self.__class__.__doc__,
  117. )
  118. def get_router(
  119. self, version: APIVersion, responses: Optional[Dict[str, Dict[str, Any]]] = None
  120. ) -> APIRouter:
  121. assert version == self.version
  122. router = APIRouter()
  123. operation_ids = set()
  124. for endpoint in self._endpoints():
  125. path, route_options = endpoint.http_method
  126. operation_id = endpoint.__name__
  127. if operation_id in operation_ids:
  128. raise RuntimeError(
  129. "Multiple operations {operation_id} configured in {self}"
  130. )
  131. operation_ids.add(operation_id)
  132. # The 'name' is used for reverse lookups (request.path_for): include the
  133. # version prefix so that we can uniquely refer to an operation.
  134. name = version.prefix + "/" + endpoint.__name__
  135. router.add_api_route(
  136. path,
  137. endpoint,
  138. tags=[self.name],
  139. operation_id=endpoint.__name__,
  140. name=name,
  141. responses=responses,
  142. **route_options,
  143. )
  144. return router
  145. def clean_resources_same_name(resources: List[Resource]) -> List[Resource]:
  146. dct = {x.version: x for x in resources}
  147. if len(dct) != len(resources):
  148. raise RuntimeError(
  149. f"Resource with name {resources[0].name} "
  150. f"is defined multiple times with the same version."
  151. )
  152. for stability in [Stability.STABLE, Stability.BETA]:
  153. tmp_resources = {k: v for (k, v) in dct.items() if k.stability is stability}
  154. for version, resource in tmp_resources.items():
  155. dct[version.decrease_stability()] = resource.get_less_stable(dct)
  156. return list(dct.values())
  157. def clean_resources(resources: Sequence[Resource]) -> List[Resource]:
  158. """Ensure that resources are consistent:
  159. - ordered by name
  160. - (tag, version) combinations should be unique
  161. - for stable resources, beta & alpha are autocreated if needed
  162. - for beta resources, alpha is autocreated if needed
  163. """
  164. result = []
  165. names = {x.name for x in resources}
  166. for name in sorted(names):
  167. result.extend(
  168. clean_resources_same_name([x for x in resources if x.name == name])
  169. )
  170. return result