api_provider.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import json as json_lib
  2. import re
  3. from http import HTTPStatus
  4. from typing import Callable
  5. from typing import Optional
  6. from urllib.parse import urlencode
  7. from urllib.parse import urljoin
  8. from pydantic import AnyHttpUrl
  9. from urllib3 import PoolManager
  10. from urllib3 import Retry
  11. from clean_python import ctx
  12. from clean_python import Json
  13. from .exceptions import ApiException
# Public API of this module: only the provider class is exported by
# ``from ... import *``; the helper functions below are not listed here.
__all__ = ["SyncApiProvider"]
  15. def is_success(status: HTTPStatus) -> bool:
  16. """Returns True on 2xx status"""
  17. return (int(status) // 100) == 2
  18. JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")
  19. def is_json_content_type(content_type: Optional[str]) -> bool:
  20. if not content_type:
  21. return False
  22. return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))
  23. def join(url: str, path: str) -> str:
  24. """Results in a full url without trailing slash"""
  25. assert url.endswith("/")
  26. assert not path.startswith("/")
  27. result = urljoin(url, path)
  28. if result.endswith("/"):
  29. result = result[:-1]
  30. return result
  31. def add_query_params(url: str, params: Optional[Json]) -> str:
  32. if params is None:
  33. return url
  34. return url + "?" + urlencode(params, doseq=True)
  35. class SyncApiProvider:
  36. """Basic JSON API provider with retry policy and bearer tokens.
  37. The default retry policy has 3 retries with 1, 2, 4 second intervals.
  38. Args:
  39. url: The url of the API (with trailing slash)
  40. fetch_token: Callable that returns a token for a tenant id
  41. retries: Total number of retries per request
  42. backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
  43. """
  44. def __init__(
  45. self,
  46. url: AnyHttpUrl,
  47. fetch_token: Callable[[PoolManager, int], Optional[str]],
  48. retries: int = 3,
  49. backoff_factor: float = 1.0,
  50. ):
  51. self._url = str(url)
  52. assert self._url.endswith("/")
  53. self._fetch_token = fetch_token
  54. self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))
  55. def request(
  56. self,
  57. method: str,
  58. path: str,
  59. params: Optional[Json] = None,
  60. json: Optional[Json] = None,
  61. fields: Optional[Json] = None,
  62. timeout: float = 5.0,
  63. ) -> Optional[Json]:
  64. assert ctx.tenant is not None
  65. headers = {}
  66. request_kwargs = {
  67. "method": method,
  68. "url": add_query_params(join(self._url, path), params),
  69. "timeout": timeout,
  70. }
  71. # for urllib3<2, we dump json ourselves
  72. if json is not None and fields is not None:
  73. raise ValueError("Cannot both specify 'json' and 'fields'")
  74. elif json is not None:
  75. request_kwargs["body"] = json_lib.dumps(json).encode()
  76. headers["Content-Type"] = "application/json"
  77. elif fields is not None:
  78. request_kwargs["fields"] = fields
  79. token = self._fetch_token(self._pool, ctx.tenant.id)
  80. if token is not None:
  81. headers["Authorization"] = f"Bearer {token}"
  82. response = self._pool.request(headers=headers, **request_kwargs)
  83. status = HTTPStatus(response.status)
  84. content_type = response.headers.get("Content-Type")
  85. if status is HTTPStatus.NO_CONTENT:
  86. return None
  87. if not is_json_content_type(content_type):
  88. raise ApiException(
  89. f"Unexpected content type '{content_type}'", status=status
  90. )
  91. body = json_lib.loads(response.data.decode())
  92. if is_success(status):
  93. return body
  94. else:
  95. raise ApiException(body, status=status)