api_provider.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import asyncio
  2. import re
  3. from http import HTTPStatus
  4. from typing import Awaitable
  5. from typing import Callable
  6. from typing import Dict
  7. from typing import Optional
  8. from urllib.parse import quote
  9. from urllib.parse import urlencode
  10. from urllib.parse import urljoin
  11. import aiohttp
  12. from aiohttp import ClientResponse
  13. from aiohttp import ClientSession
  14. from pydantic import AnyHttpUrl
  15. from clean_python import Json
  16. from .exceptions import ApiException
  17. from .response import Response
  18. __all__ = ["ApiProvider"]
  19. RETRY_STATUSES = frozenset({413, 429, 503}) # like in urllib3
  20. def is_success(status: HTTPStatus) -> bool:
  21. """Returns True on 2xx status"""
  22. return (int(status) // 100) == 2
  23. JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")
  24. def is_json_content_type(content_type: Optional[str]) -> bool:
  25. if not content_type:
  26. return False
  27. return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))
  28. def join(url: str, path: str) -> str:
  29. """Results in a full url without trailing slash"""
  30. assert url.endswith("/")
  31. assert not path.startswith("/")
  32. result = urljoin(url, path)
  33. if result.endswith("/"):
  34. result = result[:-1]
  35. return result
  36. def add_query_params(url: str, params: Optional[Json]) -> str:
  37. if params is None:
  38. return url
  39. return url + "?" + urlencode(params, doseq=True)
  40. class ApiProvider:
  41. """Basic JSON API provider with retry policy and bearer tokens.
  42. The default retry policy has 3 retries with 1, 2, 4 second intervals.
  43. Args:
  44. url: The url of the API (with trailing slash)
  45. fetch_token: Coroutine that returns headers for authorization
  46. retries: Total number of retries per request
  47. backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
  48. """
  49. def __init__(
  50. self,
  51. url: AnyHttpUrl,
  52. fetch_token: Callable[[], Awaitable[Dict[str, str]]],
  53. retries: int = 3,
  54. backoff_factor: float = 1.0,
  55. ):
  56. self._url = str(url)
  57. if not self._url.endswith("/"):
  58. self._url += "/"
  59. self._fetch_token = fetch_token
  60. assert retries > 0
  61. self._retries = retries
  62. self._backoff_factor = backoff_factor
  63. self._session = ClientSession()
  64. async def _request_with_retry(
  65. self,
  66. method: str,
  67. path: str,
  68. params: Optional[Json],
  69. json: Optional[Json],
  70. fields: Optional[Json],
  71. timeout: float,
  72. ) -> ClientResponse:
  73. request_kwargs = {
  74. "method": method,
  75. "url": add_query_params(join(self._url, quote(path)), params),
  76. "timeout": timeout,
  77. "json": json,
  78. "data": fields,
  79. "headers": await self._fetch_token(),
  80. }
  81. for attempt in range(self._retries):
  82. if attempt > 0:
  83. backoff = self._backoff_factor * 2 ** (attempt - 1)
  84. await asyncio.sleep(backoff)
  85. try:
  86. response = await self._session.request(**request_kwargs)
  87. await response.read()
  88. except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
  89. if attempt == self._retries - 1:
  90. raise # propagate ClientError in case no retries left
  91. else:
  92. if response.status not in RETRY_STATUSES:
  93. return response # on all non-retry statuses: return response
  94. return response # retries exceeded; return the (possibly error) response
  95. async def request(
  96. self,
  97. method: str,
  98. path: str,
  99. params: Optional[Json] = None,
  100. json: Optional[Json] = None,
  101. fields: Optional[Json] = None,
  102. timeout: float = 5.0,
  103. ) -> Optional[Json]:
  104. response = await self._request_with_retry(
  105. method, path, params, json, fields, timeout
  106. )
  107. status = HTTPStatus(response.status)
  108. content_type = response.headers.get("Content-Type")
  109. if status is HTTPStatus.NO_CONTENT:
  110. return None
  111. if not is_json_content_type(content_type):
  112. raise ApiException(
  113. f"Unexpected content type '{content_type}'", status=status
  114. )
  115. body = await response.json()
  116. if is_success(status):
  117. return body
  118. else:
  119. raise ApiException(body, status=status)
  120. async def request_raw(
  121. self,
  122. method: str,
  123. path: str,
  124. params: Optional[Json] = None,
  125. json: Optional[Json] = None,
  126. fields: Optional[Json] = None,
  127. timeout: float = 5.0,
  128. ) -> Response:
  129. response = await self._request_with_retry(
  130. method, path, params, json, fields, timeout
  131. )
  132. return Response(
  133. status=response.status,
  134. data=await response.read(),
  135. content_type=response.headers.get("Content-Type"),
  136. )