api_provider.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from http import HTTPStatus
  2. from typing import Callable
  3. from typing import Optional
  4. from urllib.parse import urlencode
  5. from urllib.parse import urljoin
  6. from pydantic import AnyHttpUrl
  7. from urllib3 import PoolManager
  8. from urllib3 import Retry
  9. from clean_python import ctx
  10. from clean_python import Json
  11. from .exceptions import ApiException
  12. __all__ = ["SyncApiProvider"]
  def is_success(status: HTTPStatus) -> bool:
      """Tell whether ``status`` lies in the 2xx (successful) range."""
      code = int(status)
      return 200 <= code < 300
  def join(url: str, path: str) -> str:
      """Combine a base ``url`` and a relative ``path`` into a single url.

      The base must end with a slash and the path must not start with one
      (both are asserted). One trailing slash, if present on the combined
      url, is dropped.
      """
      assert url.endswith("/")
      assert not path.startswith("/")
      joined = urljoin(url, path)
      return joined[:-1] if joined.endswith("/") else joined
  def add_query_params(url: str, params: Optional[Json]) -> str:
      """Append ``params`` to ``url`` as a url-encoded query string.

      Args:
          url: The url to extend. No check is made for a query string that
              is already present in it.
          params: Query parameters. Sequence values are expanded into
              repeated keys (``urlencode`` is called with ``doseq=True``).

      Returns:
          ``url`` unchanged when ``params`` is None or empty, otherwise
          ``url`` followed by ``?`` and the encoded parameters.
      """
      # An empty mapping would otherwise leave a dangling "?" on the url.
      if not params:
          return url
      return url + "?" + urlencode(params, doseq=True)
  class SyncApiProvider:
      """Basic JSON API provider with retry policy and bearer tokens.

      Retrying is delegated to urllib3's ``Retry``: by default up to 3 retries
      are made, with delays scaled by ``backoff_factor`` (see the urllib3
      documentation for the exact schedule).

      Args:
          url: The url of the API (with trailing slash)
          fetch_token: Callable that receives the connection pool and a tenant
              id and returns a bearer token, or None to send no token
          retries: Total number of retries per request
          backoff_factor: Multiplier for retry delay times
      """

      def __init__(
          self,
          url: AnyHttpUrl,
          fetch_token: Callable[[PoolManager, int], Optional[str]],
          retries: int = 3,
          backoff_factor: float = 1.0,
      ):
          self._url = str(url)
          # join() relies on the base url ending with a slash.
          assert self._url.endswith("/")
          self._fetch_token = fetch_token
          self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))

      @staticmethod
      def _is_json_content_type(content_type: Optional[str]) -> bool:
          """Return True if a Content-Type header value denotes JSON.

          Only the media type is compared: parameters such as
          ``; charset=utf-8`` and letter case are ignored. Media types with a
          ``+json`` suffix (e.g. ``application/problem+json``) are accepted.
          """
          if not content_type:
              return False
          media_type = content_type.split(";", 1)[0].strip().lower()
          if media_type == "application/json":
              return True
          return media_type.startswith("application/") and media_type.endswith("+json")

      def request(
          self,
          method: str,
          path: str,
          params: Optional[Json] = None,
          json: Optional[Json] = None,
          fields: Optional[Json] = None,
          timeout: float = 5.0,
      ) -> Optional[Json]:
          """Send a request to the API on behalf of the current tenant.

          Args:
              method: HTTP method, passed through to the connection pool
              path: Path relative to the API url (without leading slash)
              params: Query parameters to encode into the url
              json: Passed through as ``json`` to the pool's ``request``
              fields: Passed through as ``fields`` to the pool's ``request``
              timeout: Request timeout, passed through to the connection pool

          Returns:
              The decoded JSON body for a 2xx response, ``{"status": 204}``
              for a 204 response that has no Content-Type header, or None
              for a 404 response with a JSON body.

          Raises:
              ApiException: If the response is not JSON, or if the status is
                  neither 2xx nor 404.
          """
          assert ctx.tenant is not None
          url = join(self._url, path)
          token = self._fetch_token(self._pool, ctx.tenant.id)
          headers = {}
          if token is not None:
              headers["Authorization"] = f"Bearer {token}"
          response = self._pool.request(
              method=method,
              url=add_query_params(url, params),
              json=json,
              fields=fields,
              headers=headers,
              timeout=timeout,
          )
          status = HTTPStatus(response.status)
          content_type = response.headers.get("Content-Type")
          if content_type is None and status is HTTPStatus.NO_CONTENT:
              return {"status": int(status)}  # we have to return something...
          # Compare the media type only, so that e.g.
          # "application/json; charset=utf-8" is not rejected.
          if not self._is_json_content_type(content_type):
              raise ApiException(
                  f"Unexpected content type '{content_type}'", status=status
              )
          # The body is decoded before the status is inspected, so a 404 only
          # maps to None when it carries a decodable JSON body.
          body = response.json()
          if status is HTTPStatus.NOT_FOUND:
              return None
          elif is_success(status):
              return body
          else:
              raise ApiException(body, status=status)