api_provider.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import asyncio
  2. import re
  3. from http import HTTPStatus
  4. from io import BytesIO
  5. from typing import Any
  6. from typing import Awaitable
  7. from typing import Callable
  8. from typing import Dict
  9. from typing import Optional
  10. from urllib.parse import quote
  11. from urllib.parse import urlencode
  12. from urllib.parse import urljoin
  13. import aiohttp
  14. from aiohttp import ClientResponse
  15. from aiohttp import ClientSession
  16. from pydantic import AnyHttpUrl
  17. from pydantic import field_validator
  18. from clean_python import Json
  19. from clean_python import ValueObject
  20. from .exceptions import ApiException
  21. from .response import Response
# Names exported by a star-import of this module.
__all__ = ["ApiProvider", "FileFormPost"]
# Response statuses for which ApiProvider makes another attempt instead of
# returning the response right away.
RETRY_STATUSES = frozenset({413, 429, 503})  # like in urllib3
  24. def is_success(status: HTTPStatus) -> bool:
  25. """Returns True on 2xx status"""
  26. return (int(status) // 100) == 2
  27. JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")
  28. def is_json_content_type(content_type: Optional[str]) -> bool:
  29. if not content_type:
  30. return False
  31. return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))
  32. def join(url: str, path: str, trailing_slash: bool = False) -> str:
  33. """Results in a full url without trailing slash"""
  34. assert url.endswith("/")
  35. assert not path.startswith("/")
  36. result = urljoin(url, path)
  37. if trailing_slash and not result.endswith("/"):
  38. result = result + "/"
  39. elif not trailing_slash and result.endswith("/"):
  40. result = result[:-1]
  41. return result
  42. def add_query_params(url: str, params: Optional[Json]) -> str:
  43. if params is None:
  44. return url
  45. return url + "?" + urlencode(params, doseq=True)
  46. class FileFormPost(ValueObject):
  47. file_name: str
  48. file: Any # typing of BinaryIO / BytesIO is hard!
  49. field_name: str = "file"
  50. content_type: str = "application/octet-stream"
  51. @field_validator("file")
  52. @classmethod
  53. def validate_file(cls, v):
  54. if isinstance(v, bytes):
  55. return BytesIO(v)
  56. assert hasattr(v, "read") # poor-mans BinaryIO validation
  57. return v
  58. class ApiProvider:
  59. """Basic JSON API provider with retry policy and bearer tokens.
  60. The default retry policy has 3 retries with 1, 2, 4 second intervals.
  61. Args:
  62. url: The url of the API (with trailing slash)
  63. fetch_token: Coroutine that returns headers for authorization
  64. retries: Total number of retries per request
  65. backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
  66. """
  67. def __init__(
  68. self,
  69. url: AnyHttpUrl,
  70. fetch_token: Callable[[], Awaitable[Dict[str, str]]],
  71. retries: int = 3,
  72. backoff_factor: float = 1.0,
  73. trailing_slash: bool = False,
  74. ):
  75. self._url = str(url)
  76. if not self._url.endswith("/"):
  77. self._url += "/"
  78. self._fetch_token = fetch_token
  79. assert retries > 0
  80. self._retries = retries
  81. self._backoff_factor = backoff_factor
  82. self._trailing_slash = trailing_slash
  83. self._session = ClientSession()
  84. async def _request_with_retry(
  85. self,
  86. method: str,
  87. path: str,
  88. params: Optional[Json],
  89. json: Optional[Json],
  90. fields: Optional[Json],
  91. file: Optional[FileFormPost],
  92. timeout: float,
  93. ) -> ClientResponse:
  94. if file is not None:
  95. raise NotImplementedError("ApiProvider doesn't yet support file uploads")
  96. request_kwargs = {
  97. "method": method,
  98. "url": add_query_params(
  99. join(self._url, quote(path), self._trailing_slash), params
  100. ),
  101. "timeout": timeout,
  102. "json": json,
  103. "data": fields,
  104. "headers": await self._fetch_token(),
  105. }
  106. for attempt in range(self._retries):
  107. if attempt > 0:
  108. backoff = self._backoff_factor * 2 ** (attempt - 1)
  109. await asyncio.sleep(backoff)
  110. try:
  111. response = await self._session.request(**request_kwargs)
  112. await response.read()
  113. except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
  114. if attempt == self._retries - 1:
  115. raise # propagate ClientError in case no retries left
  116. else:
  117. if response.status not in RETRY_STATUSES:
  118. return response # on all non-retry statuses: return response
  119. return response # retries exceeded; return the (possibly error) response
  120. async def request(
  121. self,
  122. method: str,
  123. path: str,
  124. params: Optional[Json] = None,
  125. json: Optional[Json] = None,
  126. fields: Optional[Json] = None,
  127. file: Optional[FileFormPost] = None,
  128. timeout: float = 5.0,
  129. ) -> Optional[Json]:
  130. response = await self._request_with_retry(
  131. method, path, params, json, fields, file, timeout
  132. )
  133. status = HTTPStatus(response.status)
  134. content_type = response.headers.get("Content-Type")
  135. if status is HTTPStatus.NO_CONTENT:
  136. return None
  137. if not is_json_content_type(content_type):
  138. raise ApiException(
  139. f"Unexpected content type '{content_type}'", status=status
  140. )
  141. body = await response.json()
  142. if is_success(status):
  143. return body
  144. else:
  145. raise ApiException(body, status=status)
  146. async def request_raw(
  147. self,
  148. method: str,
  149. path: str,
  150. params: Optional[Json] = None,
  151. json: Optional[Json] = None,
  152. fields: Optional[Json] = None,
  153. file: Optional[FileFormPost] = None,
  154. timeout: float = 5.0,
  155. ) -> Response:
  156. response = await self._request_with_retry(
  157. method, path, params, json, fields, file, timeout
  158. )
  159. return Response(
  160. status=response.status,
  161. data=await response.read(),
  162. content_type=response.headers.get("Content-Type"),
  163. )