api_provider.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import asyncio
  2. import re
  3. from http import HTTPStatus
  4. from typing import Callable
  5. from typing import Optional
  6. from urllib.parse import quote
  7. from urllib.parse import urlencode
  8. from urllib.parse import urljoin
  9. import aiohttp
  10. from aiohttp import ClientResponse
  11. from aiohttp import ClientSession
  12. from pydantic import AnyHttpUrl
  13. from clean_python import ctx
  14. from clean_python import Json
  15. from .exceptions import ApiException
  16. from .response import Response
  17. __all__ = ["ApiProvider"]
  18. RETRY_STATUSES = frozenset({413, 429, 503}) # like in urllib3
  19. def is_success(status: HTTPStatus) -> bool:
  20. """Returns True on 2xx status"""
  21. return (int(status) // 100) == 2
  22. JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")
  23. def is_json_content_type(content_type: Optional[str]) -> bool:
  24. if not content_type:
  25. return False
  26. return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))
  27. def join(url: str, path: str) -> str:
  28. """Results in a full url without trailing slash"""
  29. assert url.endswith("/")
  30. assert not path.startswith("/")
  31. result = urljoin(url, path)
  32. if result.endswith("/"):
  33. result = result[:-1]
  34. return result
  35. def add_query_params(url: str, params: Optional[Json]) -> str:
  36. if params is None:
  37. return url
  38. return url + "?" + urlencode(params, doseq=True)
  39. class ApiProvider:
  40. """Basic JSON API provider with retry policy and bearer tokens.
  41. The default retry policy has 3 retries with 1, 2, 4 second intervals.
  42. Args:
  43. url: The url of the API (with trailing slash)
  44. fetch_token: Callable that returns a token for a tenant id
  45. retries: Total number of retries per request
  46. backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
  47. """
  48. def __init__(
  49. self,
  50. url: AnyHttpUrl,
  51. fetch_token: Callable[[ClientSession, int], Optional[str]],
  52. retries: int = 3,
  53. backoff_factor: float = 1.0,
  54. ):
  55. self._url = str(url)
  56. assert self._url.endswith("/")
  57. self._fetch_token = fetch_token
  58. assert retries > 0
  59. self._retries = retries
  60. self._backoff_factor = backoff_factor
  61. self._session = ClientSession()
  62. async def _request_with_retry(
  63. self,
  64. method: str,
  65. path: str,
  66. params: Optional[Json],
  67. json: Optional[Json],
  68. fields: Optional[Json],
  69. timeout: float,
  70. ) -> ClientResponse:
  71. assert ctx.tenant is not None
  72. headers = {}
  73. request_kwargs = {
  74. "method": method,
  75. "url": add_query_params(join(self._url, quote(path)), params),
  76. "timeout": timeout,
  77. "json": json,
  78. "data": fields,
  79. }
  80. token = self._fetch_token(self._session, ctx.tenant.id)
  81. if token is not None:
  82. headers["Authorization"] = f"Bearer {token}"
  83. for attempt in range(self._retries):
  84. if attempt > 0:
  85. backoff = self._backoff_factor * 2 ** (attempt - 1)
  86. await asyncio.sleep(backoff)
  87. try:
  88. response = await self._session.request(
  89. headers=headers, **request_kwargs
  90. )
  91. await response.read()
  92. except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
  93. if attempt == self._retries - 1:
  94. raise # propagate ClientError in case no retries left
  95. else:
  96. if response.status not in RETRY_STATUSES:
  97. return response # on all non-retry statuses: return response
  98. return response # retries exceeded; return the (possibly error) response
  99. async def request(
  100. self,
  101. method: str,
  102. path: str,
  103. params: Optional[Json] = None,
  104. json: Optional[Json] = None,
  105. fields: Optional[Json] = None,
  106. timeout: float = 5.0,
  107. ) -> Optional[Json]:
  108. response = await self._request_with_retry(
  109. method, path, params, json, fields, timeout
  110. )
  111. status = HTTPStatus(response.status)
  112. content_type = response.headers.get("Content-Type")
  113. if status is HTTPStatus.NO_CONTENT:
  114. return None
  115. if not is_json_content_type(content_type):
  116. raise ApiException(
  117. f"Unexpected content type '{content_type}'", status=status
  118. )
  119. body = await response.json()
  120. if is_success(status):
  121. return body
  122. else:
  123. raise ApiException(body, status=status)
  124. async def request_raw(
  125. self,
  126. method: str,
  127. path: str,
  128. params: Optional[Json] = None,
  129. json: Optional[Json] = None,
  130. fields: Optional[Json] = None,
  131. timeout: float = 5.0,
  132. ) -> Response:
  133. response = await self._request_with_retry(
  134. method, path, params, json, fields, timeout
  135. )
  136. return Response(
  137. status=response.status,
  138. data=await response.read(),
  139. content_type=response.headers.get("Content-Type"),
  140. )