resource.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. from enum import Enum
  2. from functools import partial
  3. from typing import Any
  4. from typing import Callable
  5. from typing import Dict
  6. from typing import List
  7. from typing import Optional
  8. from typing import Sequence
  9. from typing import Type
  10. from fastapi.routing import APIRouter
  11. from clean_python.base.domain.value_object import ValueObject
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "Resource",
    "get",
    "post",
    "put",
    "patch",
    "delete",
    "APIVersion",
    "Stability",
    "v",
    "clean_resources",
]
  24. class Stability(str, Enum):
  25. STABLE = "stable"
  26. BETA = "beta"
  27. ALPHA = "alpha"
  28. @property
  29. def description(self) -> str:
  30. return DESCRIPTIONS[self]
  31. def decrease(self) -> "Stability":
  32. index = STABILITY_ORDER.index(self)
  33. if index == 0:
  34. raise ValueError(f"Cannot decrease stability of {self}")
  35. return STABILITY_ORDER[index - 1]
  36. STABILITY_ORDER = [Stability.ALPHA, Stability.BETA, Stability.STABLE]
  37. DESCRIPTIONS = {
  38. Stability.STABLE: "The stable API version.",
  39. Stability.BETA: "Backwards incompatible changes will be announced beforehand.",
  40. Stability.ALPHA: "May get backwards incompatible changes without warning.",
  41. }
  42. class APIVersion(ValueObject):
  43. version: int
  44. stability: Stability
  45. @property
  46. def prefix(self) -> str:
  47. result = f"v{self.version}"
  48. if self.stability is not Stability.STABLE:
  49. result += f"-{self.stability.value}"
  50. return result
  51. @property
  52. def description(self) -> str:
  53. return self.stability.description
  54. def decrease_stability(self) -> "APIVersion":
  55. return APIVersion(version=self.version, stability=self.stability.decrease())
  56. def http_method(path: str, **route_options):
  57. def wrapper(unbound_method: Callable[..., Any]):
  58. setattr(
  59. unbound_method,
  60. "http_method",
  61. (path, route_options),
  62. )
  63. return unbound_method
  64. return wrapper
  65. def v(version: int, stability: str = "stable") -> APIVersion:
  66. return APIVersion(version=version, stability=Stability(stability))
  67. get = partial(http_method, methods=["GET"])
  68. post = partial(http_method, methods=["POST"])
  69. put = partial(http_method, methods=["PUT"])
  70. patch = partial(http_method, methods=["PATCH"])
  71. delete = partial(http_method, methods=["DELETE"])
# Name/description pair for one OpenAPI tag; Resource.get_openapi_tag builds
# it from the resource name and the resource class docstring.
class OpenApiTag(ValueObject):
    name: str
    # None when the resource class has no docstring.
    description: Optional[str]
  75. class Resource:
  76. version: APIVersion
  77. name: str
  78. def __init_subclass__(cls, version: APIVersion, name: str = ""):
  79. cls.version = version
  80. cls.name = name
  81. super().__init_subclass__()
  82. @classmethod
  83. def with_version(cls, version: APIVersion) -> Type["Resource"]:
  84. class DynamicResource(cls, version=version, name=cls.name): # type: ignore
  85. pass
  86. DynamicResource.__doc__ = cls.__doc__
  87. return DynamicResource
  88. def get_less_stable(self, resources: Dict[APIVersion, "Resource"]) -> "Resource":
  89. """Fetch a less stable version of this resource from 'resources'
  90. If it doesn't exist, create it dynamically.
  91. """
  92. less_stable_version = self.version.decrease_stability()
  93. # Fetch the less stable resource; generate it if it does not exist
  94. try:
  95. less_stable_resource = resources[less_stable_version]
  96. except KeyError:
  97. less_stable_resource = self.__class__.with_version(less_stable_version)()
  98. # Validate the less stable version
  99. if less_stable_resource.__class__.__bases__ != (self.__class__,):
  100. raise RuntimeError(
  101. f"{less_stable_resource} should be a direct subclass of {self}"
  102. )
  103. return less_stable_resource
  104. def _endpoints(self):
  105. for attr_name in dir(self):
  106. if attr_name.startswith("_"):
  107. continue
  108. endpoint = getattr(self, attr_name)
  109. if not hasattr(endpoint, "http_method"):
  110. continue
  111. yield endpoint
  112. def get_openapi_tag(self) -> OpenApiTag:
  113. return OpenApiTag(
  114. name=self.name,
  115. description=self.__class__.__doc__,
  116. )
  117. def get_router(
  118. self, version: APIVersion, responses: Optional[Dict[str, Dict[str, Any]]] = None
  119. ) -> APIRouter:
  120. assert version == self.version
  121. router = APIRouter()
  122. operation_ids = set()
  123. for endpoint in self._endpoints():
  124. path, route_options = endpoint.http_method
  125. operation_id = endpoint.__name__
  126. if operation_id in operation_ids:
  127. raise RuntimeError(
  128. "Multiple operations {operation_id} configured in {self}"
  129. )
  130. operation_ids.add(operation_id)
  131. # The 'name' is used for reverse lookups (request.path_for): include the
  132. # version prefix so that we can uniquely refer to an operation.
  133. name = version.prefix + "/" + endpoint.__name__
  134. router.add_api_route(
  135. path,
  136. endpoint,
  137. tags=[self.name],
  138. operation_id=endpoint.__name__,
  139. name=name,
  140. responses=responses,
  141. **route_options,
  142. )
  143. return router
  144. def clean_resources_same_name(resources: List[Resource]) -> List[Resource]:
  145. dct = {x.version: x for x in resources}
  146. if len(dct) != len(resources):
  147. raise RuntimeError(
  148. f"Resource with name {resources[0].name} "
  149. f"is defined multiple times with the same version."
  150. )
  151. for stability in [Stability.STABLE, Stability.BETA]:
  152. tmp_resources = {k: v for (k, v) in dct.items() if k.stability is stability}
  153. for version, resource in tmp_resources.items():
  154. dct[version.decrease_stability()] = resource.get_less_stable(dct)
  155. return list(dct.values())
  156. def clean_resources(resources: Sequence[Resource]) -> List[Resource]:
  157. """Ensure that resources are consistent:
  158. - ordered by name
  159. - (tag, version) combinations should be unique
  160. - for stable resources, beta & alpha are autocreated if needed
  161. - for beta resources, alpha is autocreated if needed
  162. """
  163. result = []
  164. names = {x.name for x in resources}
  165. for name in sorted(names):
  166. result.extend(
  167. clean_resources_same_name([x for x in resources if x.name == name])
  168. )
  169. return result