api_provider.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import asyncio
  2. import re
  3. from http import HTTPStatus
  4. from io import BytesIO
  5. from typing import Any
  6. from typing import Awaitable
  7. from typing import Callable
  8. from typing import Dict
  9. from typing import Optional
  10. from urllib.parse import quote
  11. from urllib.parse import urlencode
  12. from urllib.parse import urljoin
  13. import aiohttp
  14. from aiohttp import ClientResponse
  15. from aiohttp import ClientSession
  16. from pydantic import AnyHttpUrl
  17. from pydantic import field_validator
  18. from clean_python import Conflict
  19. from clean_python import Json
  20. from clean_python import ValueObject
  21. from .exceptions import ApiException
  22. from .response import Response
  23. __all__ = ["ApiProvider", "FileFormPost"]
  24. # Retry on 429 and all 5xx errors (because they are mostly temporary)
  25. RETRY_STATUSES = frozenset(
  26. {
  27. HTTPStatus.TOO_MANY_REQUESTS,
  28. HTTPStatus.INTERNAL_SERVER_ERROR,
  29. HTTPStatus.BAD_GATEWAY,
  30. HTTPStatus.SERVICE_UNAVAILABLE,
  31. HTTPStatus.GATEWAY_TIMEOUT,
  32. }
  33. )
  34. # PATCH is strictly not idempotent, because you could do advanced
  35. # JSON operations like 'add an array element'. mostly idempotent.
  36. # However we never do that and we always make PATCH idempotent.
  37. RETRY_METHODS = frozenset(["HEAD", "GET", "PATCH", "PUT", "DELETE", "OPTIONS", "TRACE"])
  38. def is_success(status: HTTPStatus) -> bool:
  39. """Returns True on 2xx status"""
  40. return (int(status) // 100) == 2
  41. def check_exception(status: HTTPStatus, body: Json) -> None:
  42. if status == HTTPStatus.CONFLICT:
  43. raise Conflict(body.get("message", str(body)))
  44. elif not is_success(status):
  45. raise ApiException(body, status=status)
  46. JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")
  47. def is_json_content_type(content_type: Optional[str]) -> bool:
  48. if not content_type:
  49. return False
  50. return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))
  51. def join(url: str, path: str, trailing_slash: bool = False) -> str:
  52. """Results in a full url without trailing slash"""
  53. assert url.endswith("/")
  54. assert not path.startswith("/")
  55. result = urljoin(url, path)
  56. if trailing_slash and not result.endswith("/"):
  57. result = result + "/"
  58. elif not trailing_slash and result.endswith("/"):
  59. result = result[:-1]
  60. return result
  61. def add_query_params(url: str, params: Optional[Json]) -> str:
  62. if params is None:
  63. return url
  64. return url + "?" + urlencode(params, doseq=True)
  65. class FileFormPost(ValueObject):
  66. file_name: str
  67. file: Any # typing of BinaryIO / BytesIO is hard!
  68. field_name: str = "file"
  69. content_type: str = "application/octet-stream"
  70. @field_validator("file")
  71. @classmethod
  72. def validate_file(cls, v):
  73. if isinstance(v, bytes):
  74. return BytesIO(v)
  75. assert hasattr(v, "read") # poor-mans BinaryIO validation
  76. return v
  77. class ApiProvider:
  78. """Basic JSON API provider with retry policy and bearer tokens.
  79. The default retry policy has 3 retries with 1, 2, 4 second intervals.
  80. Args:
  81. url: The url of the API (with trailing slash)
  82. headers_factory: Coroutine that returns headers (for e.g. authorization)
  83. retries: Total number of retries per request
  84. backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
  85. trailing_slash: Wether to automatically add or remove trailing slashes.
  86. """
  87. def __init__(
  88. self,
  89. url: AnyHttpUrl,
  90. headers_factory: Optional[Callable[[], Awaitable[Dict[str, str]]]] = None,
  91. retries: int = 3,
  92. backoff_factor: float = 1.0,
  93. trailing_slash: bool = False,
  94. ):
  95. self._url = str(url)
  96. if not self._url.endswith("/"):
  97. self._url += "/"
  98. self._headers_factory = headers_factory
  99. assert retries >= 0
  100. self._retries = retries
  101. self._backoff_factor = backoff_factor
  102. self._trailing_slash = trailing_slash
  103. @property
  104. def _session(self) -> ClientSession:
  105. # There seems to be an issue if the ClientSession is instantiated before
  106. # the event loop runs. So we do that delayed in a property. Use this property
  107. # in a context manager.
  108. # TODO It is more efficient to reuse the connection / connection pools. One idea
  109. # is to expose .session as a context manager (like with the SQLProvider.transaction)
  110. return ClientSession()
  111. async def _request_with_retry(
  112. self,
  113. method: str,
  114. path: str,
  115. params: Optional[Json],
  116. json: Optional[Json],
  117. fields: Optional[Json],
  118. file: Optional[FileFormPost],
  119. headers: Optional[Dict[str, str]],
  120. timeout: float,
  121. ) -> ClientResponse:
  122. if file is not None:
  123. raise NotImplementedError("ApiProvider doesn't yet support file uploads")
  124. request_kwargs = {
  125. "method": method,
  126. "url": add_query_params(
  127. join(self._url, quote(path), self._trailing_slash), params
  128. ),
  129. "timeout": timeout,
  130. "json": json,
  131. "data": fields,
  132. }
  133. actual_headers = {}
  134. if self._headers_factory is not None:
  135. actual_headers.update(await self._headers_factory())
  136. if headers:
  137. actual_headers.update(headers)
  138. retries = self._retries if method.upper() in RETRY_METHODS else 0
  139. for attempt in range(retries + 1):
  140. if attempt > 0:
  141. backoff = self._backoff_factor * 2 ** (attempt - 1)
  142. await asyncio.sleep(backoff)
  143. try:
  144. async with self._session as session:
  145. response = await session.request(
  146. headers=actual_headers, **request_kwargs
  147. )
  148. if response.status in RETRY_STATUSES:
  149. continue
  150. await response.read()
  151. return response
  152. except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
  153. if attempt == retries:
  154. raise # propagate ClientError in case no retries left
  155. return response # retries exceeded; return the (possibly error) response
  156. async def request(
  157. self,
  158. method: str,
  159. path: str,
  160. params: Optional[Json] = None,
  161. json: Optional[Json] = None,
  162. fields: Optional[Json] = None,
  163. file: Optional[FileFormPost] = None,
  164. headers: Optional[Dict[str, str]] = None,
  165. timeout: float = 5.0,
  166. ) -> Optional[Json]:
  167. response = await self._request_with_retry(
  168. method, path, params, json, fields, file, headers, timeout
  169. )
  170. status = HTTPStatus(response.status)
  171. content_type = response.headers.get("Content-Type")
  172. if status is HTTPStatus.NO_CONTENT:
  173. return None
  174. if not is_json_content_type(content_type):
  175. raise ApiException(
  176. f"Unexpected content type '{content_type}'", status=status
  177. )
  178. body = await response.json()
  179. check_exception(status, body)
  180. return body
  181. async def request_raw(
  182. self,
  183. method: str,
  184. path: str,
  185. params: Optional[Json] = None,
  186. json: Optional[Json] = None,
  187. fields: Optional[Json] = None,
  188. file: Optional[FileFormPost] = None,
  189. headers: Optional[Dict[str, str]] = None,
  190. timeout: float = 5.0,
  191. ) -> Response:
  192. response = await self._request_with_retry(
  193. method, path, params, json, fields, file, headers, timeout
  194. )
  195. return Response(
  196. status=response.status,
  197. data=await response.read(),
  198. content_type=response.headers.get("Content-Type"),
  199. )