import asyncio
import re
from http import HTTPStatus
from io import BytesIO
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import Optional
from urllib.parse import quote
from urllib.parse import urlencode
from urllib.parse import urljoin

import aiohttp
from aiohttp import ClientResponse
from aiohttp import ClientSession
from pydantic import AnyHttpUrl
from pydantic import field_validator

from clean_python import Json
from clean_python import ValueObject

from .exceptions import ApiException
from .response import Response

__all__ = ["ApiProvider", "FileFormPost"]


RETRY_STATUSES = frozenset({413, 429, 503})  # like in urllib3


def is_success(status: HTTPStatus) -> bool:
    """Returns True on 2xx status"""
    return (int(status) // 100) == 2


JSON_CONTENT_TYPE_REGEX = re.compile(r"^application\/[^+]*[+]?(json);?.*$")


def is_json_content_type(content_type: Optional[str]) -> bool:
    """Returns True if the content type matches JSON_CONTENT_TYPE_REGEX.

    A missing (None) or empty content type is never considered JSON.
    """
    if not content_type:
        return False
    return bool(JSON_CONTENT_TYPE_REGEX.match(content_type))


def join(url: str, path: str, trailing_slash: bool = False) -> str:
    """Join a base url and a relative path into a full url.

    Args:
        url: The base url; must end with a slash.
        path: The path relative to ``url``; must not start with a slash.
        trailing_slash: If True, the result always ends with a slash;
            if False, a single trailing slash is stripped from the result.
    """
    assert url.endswith("/")
    assert not path.startswith("/")
    result = urljoin(url, path)
    if trailing_slash and not result.endswith("/"):
        result = result + "/"
    elif not trailing_slash and result.endswith("/"):
        result = result[:-1]
    return result


def add_query_params(url: str, params: Optional[Json]) -> str:
    """Append url-encoded query parameters to a url.

    Sequence values are expanded into repeated keys (``doseq=True``).
    The url is returned unchanged if ``params`` is None or encodes to an
    empty query string, so that no dangling "?" is appended.
    """
    if params is None:
        return url
    query = urlencode(params, doseq=True)
    if not query:
        return url
    return url + "?" + query


class FileFormPost(ValueObject):
    """Describes a file to be posted as a form field."""

    file_name: str
    file: Any  # typing of BinaryIO / BytesIO is hard!
    field_name: str = "file"
    content_type: str = "application/octet-stream"

    @field_validator("file")
    @classmethod
    def validate_file(cls, v):
        """Wrap raw bytes in a BytesIO; otherwise require a ``read`` attribute."""
        if isinstance(v, bytes):
            return BytesIO(v)
        assert hasattr(v, "read")  # poor-mans BinaryIO validation
        return v


class ApiProvider:
    """Basic JSON API provider with retry policy and bearer tokens.

    A request is attempted up to ``retries`` times in total. Before every
    attempt except the first, the provider sleeps for
    ``backoff_factor * 2 ** (attempt - 1)`` seconds, so the default policy
    makes at most 3 attempts with waits of 1 and 2 seconds in between.

    The provider owns an aiohttp ClientSession; call ``disconnect()`` when
    done with the provider to close it.

    Args:
        url: The url of the API (a trailing slash is added if missing)
        fetch_token: Coroutine that returns headers for authorization
        retries: Total number of attempts per request (must be > 0)
        backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
        trailing_slash: Whether request urls should end with a slash
    """

    def __init__(
        self,
        url: AnyHttpUrl,
        fetch_token: Callable[[], Awaitable[Dict[str, str]]],
        retries: int = 3,
        backoff_factor: float = 1.0,
        trailing_slash: bool = False,
    ):
        self._url = str(url)
        if not self._url.endswith("/"):
            self._url += "/"
        self._fetch_token = fetch_token
        assert retries > 0
        self._retries = retries
        self._backoff_factor = backoff_factor
        self._trailing_slash = trailing_slash
        self._session = ClientSession()

    async def disconnect(self) -> None:
        """Close the underlying client session."""
        await self._session.close()

    async def _request_with_retry(
        self,
        method: str,
        path: str,
        params: Optional[Json],
        json: Optional[Json],
        fields: Optional[Json],
        file: Optional[FileFormPost],
        timeout: float,
    ) -> ClientResponse:
        """Perform a request, retrying on client errors and retry statuses.

        The authorization headers are fetched once, before the first attempt.
        The response body is read before returning.

        Raises:
            NotImplementedError: if ``file`` is given.
            aiohttp.ClientError, asyncio.TimeoutError: if the last attempt
                fails with one of these errors.
        """
        if file is not None:
            raise NotImplementedError("ApiProvider doesn't yet support file uploads")
        request_kwargs = {
            "method": method,
            "url": add_query_params(
                join(self._url, quote(path), self._trailing_slash), params
            ),
            "timeout": timeout,
            "json": json,
            "data": fields,
            "headers": await self._fetch_token(),
        }
        for attempt in range(self._retries):
            if attempt > 0:
                backoff = self._backoff_factor * 2 ** (attempt - 1)
                await asyncio.sleep(backoff)

            try:
                response = await self._session.request(**request_kwargs)
                await response.read()
            except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
                if attempt == self._retries - 1:
                    raise  # propagate ClientError in case no retries left
            else:
                if response.status not in RETRY_STATUSES:
                    return response  # on all non-retry statuses: return response

        # The last attempt either raised above or bound ``response``.
        return response  # retries exceeded; return the (possibly error) response

    async def request(
        self,
        method: str,
        path: str,
        params: Optional[Json] = None,
        json: Optional[Json] = None,
        fields: Optional[Json] = None,
        file: Optional[FileFormPost] = None,
        timeout: float = 5.0,
    ) -> Optional[Json]:
        """Perform a request and return the decoded JSON body.

        Returns:
            The decoded JSON body on a 2xx status, or None on 204 No Content.

        Raises:
            ApiException: if the response is not JSON, or if the status is
                not 2xx (the decoded body is then passed to the exception).
        """
        response = await self._request_with_retry(
            method, path, params, json, fields, file, timeout
        )
        status = HTTPStatus(response.status)
        content_type = response.headers.get("Content-Type")
        if status is HTTPStatus.NO_CONTENT:
            return None
        if not is_json_content_type(content_type):
            raise ApiException(
                f"Unexpected content type '{content_type}'", status=status
            )
        body = await response.json()
        if is_success(status):
            return body
        else:
            raise ApiException(body, status=status)

    async def request_raw(
        self,
        method: str,
        path: str,
        params: Optional[Json] = None,
        json: Optional[Json] = None,
        fields: Optional[Json] = None,
        file: Optional[FileFormPost] = None,
        timeout: float = 5.0,
    ) -> Response:
        """Perform a request and return status, raw body and content type.

        Unlike ``request``, no status or content type checks are done here.
        """
        response = await self._request_with_retry(
            method, path, params, json, fields, file, timeout
        )
        return Response(
            status=response.status,
            data=await response.read(),
            content_type=response.headers.get("Content-Type"),
        )