resource.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. from enum import Enum
  2. from functools import partial
  3. from typing import Any, Callable, Dict, List, Optional, Sequence, Type
  4. from fastapi.routing import APIRouter
  5. from clean_python.base.domain.value_object import ValueObject
  # Public API of this module: resource base class, verb decorators,
  # versioning helpers and the resource-list normalizer.
  __all__ = [
      "Resource",
      "get", "post", "put", "patch", "delete",
      "APIVersion", "Stability", "v",
      "clean_resources",
  ]
  class Stability(str, Enum):
      # Maturity level of an API version. Non-stable values are appended to
      # the URL prefix (see APIVersion.prefix).
      STABLE = "stable"
      BETA = "beta"
      ALPHA = "alpha"

      @property
      def description(self) -> str:
          # Human-readable text from the module-level DESCRIPTIONS table.
          return DESCRIPTIONS[self]

      def decrease(self) -> "Stability":
          # Step one position down in STABILITY_ORDER (STABLE -> BETA -> ALPHA).
          # Raises ValueError when already at the first (least stable) level.
          position = STABILITY_ORDER.index(self)
          if position > 0:
              return STABILITY_ORDER[position - 1]
          raise ValueError(f"Cannot decrease stability of {self}")
  # Stability levels ordered from least to most stable; Stability.decrease()
  # moves one position to the left in this list.
  STABILITY_ORDER = [
      Stability.ALPHA,
      Stability.BETA,
      Stability.STABLE,
  ]

  # Human-readable explanation per level, exposed via Stability.description.
  DESCRIPTIONS = {
      Stability.STABLE: "The stable API version.",
      Stability.BETA: "Backwards incompatible changes will be announced beforehand.",
      Stability.ALPHA: "May get backwards incompatible changes without warning.",
  }
  class APIVersion(ValueObject):
      # An API version number paired with its stability level.
      version: int
      stability: Stability

      @property
      def prefix(self) -> str:
          # "v<N>" for stable versions, "v<N>-<stability>" (e.g. "v1-beta") otherwise.
          if self.stability is Stability.STABLE:
              return f"v{self.version}"
          return f"v{self.version}-{self.stability.value}"

      @property
      def description(self) -> str:
          # Delegates to the description of the stability level.
          return self.stability.description

      def decrease_stability(self) -> "APIVersion":
          # Same version number, one stability level lower; Stability.decrease()
          # raises ValueError if this version is already the least stable.
          lower = self.stability.decrease()
          return APIVersion(version=self.version, stability=lower)
  def http_method(path: str, **route_options):
      """Build a decorator that marks a function as an HTTP endpoint.

      The decorated function is returned unchanged apart from a new
      ``http_method`` attribute holding ``(path, route_options)``.
      ``Resource`` looks for that attribute to discover and register routes.
      """

      def decorator(unbound_method: Callable[..., Any]):
          setattr(unbound_method, "http_method", (path, route_options))
          return unbound_method

      return decorator
  def v(version: int, stability: str = "stable") -> APIVersion:
      """Shorthand constructor, e.g. ``v(1)`` or ``v(2, "beta")``.

      An unknown ``stability`` string raises ValueError from the enum lookup.
      """
      level = Stability(stability)
      return APIVersion(version=version, stability=level)
  # One shortcut decorator per HTTP verb, used as e.g. ``@get("/items")``;
  # each partial gets its own single-element ``methods`` list.
  get, post, put, patch, delete = (
      partial(http_method, methods=[verb])
      for verb in ("GET", "POST", "PUT", "PATCH", "DELETE")
  )
  class OpenApiTag(ValueObject):
      # OpenAPI tag metadata. Resource.get_openapi_tag fills these fields from
      # the resource name and the resource class docstring.
      name: str
      # None when the resource class has no docstring.
      description: Optional[str]
  class Resource:
      """Base class for a group of HTTP endpoints served under one API version.

      Subclasses are declared with class keywords, for example
      ``class Books(Resource, version=v(1), name="books")``. Public methods
      decorated with ``get``/``post``/``put``/``patch``/``delete`` become the
      endpoints. The subclass docstring is used as the OpenAPI tag
      description (see ``get_openapi_tag``).
      """

      version: APIVersion  # assigned per subclass in __init_subclass__
      name: str  # used as the OpenAPI tag; assigned in __init_subclass__

      def __init_subclass__(cls, version: APIVersion, name: str = ""):
          cls.version = version
          cls.name = name
          super().__init_subclass__()

      @classmethod
      def with_version(cls, version: APIVersion) -> Type["Resource"]:
          """Create a direct subclass of ``cls`` bound to ``version``.

          The new class inherits all endpoints of ``cls`` and copies its
          ``name`` and ``__doc__``. Class docstrings are not inherited
          automatically, so ``__doc__`` is copied explicitly.
          """

          class DynamicResource(cls, version=version, name=cls.name):  # type: ignore
              pass

          DynamicResource.__doc__ = cls.__doc__
          return DynamicResource

      def get_less_stable(self, resources: Dict[APIVersion, "Resource"]) -> "Resource":
          """Fetch a less stable version of this resource from ``resources``.

          If it doesn't exist, a subclass is created dynamically with
          ``with_version`` and instantiated. The result is not added to
          ``resources``; that is left to the caller.

          Raises:
              ValueError: if this resource's stability cannot be decreased.
              RuntimeError: if the less stable resource is not a direct
                  subclass of this resource's class.
          """
          less_stable_version = self.version.decrease_stability()
          # Fetch the less stable resource; generate it if it does not exist
          try:
              less_stable_resource = resources[less_stable_version]
          except KeyError:
              less_stable_resource = self.__class__.with_version(less_stable_version)()
          # Requiring a direct subclass means the less stable version inherits
          # this resource's endpoints.
          if less_stable_resource.__class__.__bases__ != (self.__class__,):
              raise RuntimeError(
                  f"{less_stable_resource} should be a direct subclass of {self}"
              )
          return less_stable_resource

      def _endpoints(self):
          """Yield public attributes carrying the ``http_method`` marker.

          These are normally bound methods decorated via ``http_method``.
          ``dir`` returns names sorted alphabetically, so endpoints are
          yielded in that order.
          """
          for attr_name in dir(self):
              if attr_name.startswith("_"):
                  continue
              endpoint = getattr(self, attr_name)
              if not hasattr(endpoint, "http_method"):
                  continue
              yield endpoint

      def get_openapi_tag(self) -> OpenApiTag:
          """Return the OpenAPI tag (resource name + class docstring)."""
          return OpenApiTag(
              name=self.name,
              description=self.__class__.__doc__,
          )

      def get_router(
          self, version: APIVersion, responses: Optional[Dict[str, Dict[str, Any]]] = None
      ) -> APIRouter:
          """Build an APIRouter with one route per endpoint of this resource.

          Args:
              version: must equal ``self.version`` (checked with ``assert``).
              responses: passed as ``responses`` to every registered route.

          Raises:
              RuntimeError: if two endpoints share the same function name,
                  which is used as the OpenAPI operation id.
          """
          assert version == self.version
          router = APIRouter()
          operation_ids = set()
          for endpoint in self._endpoints():
              path, route_options = endpoint.http_method
              operation_id = endpoint.__name__
              if operation_id in operation_ids:
                  raise RuntimeError(
                      f"Multiple operations {operation_id} configured in {self}"
                  )
              operation_ids.add(operation_id)
              # The 'name' is used for reverse lookups (request.path_for): include the
              # version prefix so that we can uniquely refer to an operation.
              name = version.prefix + "/" + operation_id
              router.add_api_route(
                  path,
                  endpoint,
                  tags=[self.name],
                  operation_id=operation_id,
                  name=name,
                  responses=responses,
                  **route_options,
              )
          return router
  def clean_resources_same_name(resources: List[Resource]) -> List[Resource]:
      """Validate and complete the versions of one named resource.

      ``resources`` are expected to share the same ``name``. Each version may
      appear only once, otherwise RuntimeError is raised. Every stable
      version gets a beta counterpart. Every beta version, including betas
      added here, gets an alpha counterpart. Each counterpart is taken from
      the input or generated through ``Resource.get_less_stable``.
      """
      by_version: Dict[APIVersion, Resource] = {}
      for resource in resources:
          by_version[resource.version] = resource

      if len(by_version) != len(resources):
          raise RuntimeError(
              f"Resource with name {resources[0].name} "
              "is defined multiple times with the same version."
          )

      # STABLE is processed before BETA so that betas created in the first
      # pass also receive an alpha counterpart in the second pass.
      for level in (Stability.STABLE, Stability.BETA):
          at_level = [
              (ver, res) for ver, res in by_version.items() if ver.stability is level
          ]
          for ver, res in at_level:
              by_version[ver.decrease_stability()] = res.get_less_stable(by_version)

      return list(by_version.values())
  def clean_resources(resources: Sequence[Resource]) -> List[Resource]:
      """Return a consistent list of resources.

      - resources are grouped by name, groups ordered by name
      - each (name, version) combination must be unique (RuntimeError otherwise)
      - for stable resources, beta & alpha are autocreated if needed
      - for beta resources, alpha is autocreated if needed
      """
      # Group by name, keeping the input order within each group.
      groups: Dict[str, List[Resource]] = {}
      for resource in resources:
          groups.setdefault(resource.name, []).append(resource)

      cleaned: List[Resource] = []
      for name in sorted(groups):
          cleaned.extend(clean_resources_same_name(groups[name]))
      return cleaned