import base64
import hashlib
import logging
import os
import re
from http import HTTPStatus
from pathlib import Path
from typing import BinaryIO
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
from urllib.parse import urlparse

import urllib3

from .exceptions import ApiException

__all__ = ["download_file", "download_fileobj", "upload_file", "upload_fileobj"]

CONTENT_RANGE_REGEXP = re.compile(r"^bytes (\d+)-(\d+)/(\d+|\*)$")

# The default upload timeout has an increased socket read timeout, because
# MinIO can take a long time to complete the upload of a larger file. The
# limit of 10 minutes should accommodate files up to 150 GB.
DEFAULT_UPLOAD_TIMEOUT = urllib3.Timeout(connect=5.0, read=600.0)

logger = logging.getLogger(__name__)

def get_pool(retries: int = 3, backoff_factor: float = 1.0) -> urllib3.PoolManager:
    """Create a PoolManager with a retry policy.

    The default retry policy has 3 retries with 1, 2, 4 second intervals.

    Args:
        retries: Total number of retries per request.
        backoff_factor: Multiplier for the retry delay times (1, 2, 4, ...).
    """
    return urllib3.PoolManager(
        retries=urllib3.util.retry.Retry(retries, backoff_factor=backoff_factor)
    )
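
# Illustrative sketch (not part of the original module): reusing one pool for
# several transfers to the same host avoids re-establishing connections. The
# retry count, backoff factor, url, and filenames below are example values.
def _example_shared_pool() -> None:
    pool = get_pool(retries=5, backoff_factor=0.5)
    for name in ("a.tif", "b.tif"):
        # Path(".") is a directory, so the filename is taken from the url
        download_file("https://example.com/files/" + name, Path("."), pool=pool)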

def compute_md5(fileobj: BinaryIO, chunk_size: int = 16777216) -> bytes:
    """Compute the MD5 checksum of a file object, returning the raw digest."""
    fileobj.seek(0)
    hasher = hashlib.md5()
    for chunk in _iter_chunks(fileobj, chunk_size=chunk_size):
        hasher.update(chunk)
    return hasher.digest()
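
# Illustrative sketch (not part of the original module): computing the digest
# before an upload so that the server can verify integrity. The path and url
# are hypothetical.
def _example_compute_md5() -> None:
    with Path("example.dat").open("rb") as fileobj:
        digest = compute_md5(fileobj)  # 16 raw bytes, not a hex string
        upload_fileobj("https://example.com/upload/example.dat", fileobj, md5=digest)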

def download_file(
    url: str,
    target: Path,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
    pool: Optional[urllib3.PoolManager] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> Tuple[Path, int]:
    """Download a file to a specified path on disk.

    It is assumed that the file server supports multipart downloads (range
    requests).

    Args:
        url: The url to retrieve.
        target: The location to copy to. If this is an existing file, it is
            overwritten. If it is a directory, a filename is generated from
            the filename in the url.
        chunk_size: The number of bytes per request. Default: 16MB.
        timeout: The total timeout in seconds.
        pool: If not supplied, a default connection pool will be created with
            a retry policy of 3 retries after 1, 2, 4 seconds.
        callback_func: Optional function that receives (bytes_downloaded,
            total_bytes), for example:
            def callback(bytes_downloaded: int, total_bytes: int) -> None
        headers_factory: Optional function to inject headers.

    Returns:
        Tuple of file path, total number of downloaded bytes.

    Raises:
        ApiException: Raised on unexpected server responses (HTTP status
            codes other than 206, 413, 429, 503).
        urllib3.exceptions.HTTPError: Various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413), too
            many requests (HTTP 429), service unavailable (HTTP 503).
    """
    # cast string to Path if necessary
    if isinstance(target, str):
        target = Path(target)

    # if it is a directory, take the filename from the url
    if target.is_dir():
        target = target / urlparse(url).path.rsplit("/", 1)[-1]

    # open the file
    try:
        with target.open("wb") as fileobj:
            size = download_fileobj(
                url,
                fileobj,
                chunk_size=chunk_size,
                timeout=timeout,
                pool=pool,
                callback_func=callback_func,
                headers_factory=headers_factory,
            )
    except Exception:
        # clean up a partially downloaded file
        try:
            os.remove(target)
        except FileNotFoundError:
            pass
        raise

    return target, size
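
# Illustrative sketch (not part of the original module): a basic download with
# a progress callback. The url and target directory are hypothetical.
def _example_download_file() -> None:
    def progress(bytes_downloaded: int, total_bytes: int) -> None:
        logger.info("downloaded %d of %d bytes", bytes_downloaded, total_bytes)

    target, size = download_file(
        "https://example.com/files/raster.tif",
        Path("/tmp"),  # a directory: the filename is taken from the url
        callback_func=progress,
    )
    logger.info("saved %d bytes to %s", size, target)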

def download_fileobj(
    url: str,
    fileobj: BinaryIO,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
    pool: Optional[urllib3.PoolManager] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> int:
    """Download a url to a file object using multiple requests.

    It is assumed that the file server supports multipart downloads (range
    requests).

    Args:
        url: The url to retrieve.
        fileobj: The (binary) file object to write into.
        chunk_size: The number of bytes per request. Default: 16MB.
        timeout: The total timeout in seconds.
        pool: If not supplied, a default connection pool will be created with
            a retry policy of 3 retries after 1, 2, 4 seconds.
        callback_func: Optional function that receives (bytes_downloaded,
            total_bytes), for example:
            def callback(bytes_downloaded: int, total_bytes: int) -> None
        headers_factory: Optional function to inject headers.

    Returns:
        The total number of downloaded bytes.

    Raises:
        ApiException: Raised on unexpected server responses (HTTP status
            codes other than 206, 413, 429, 503).
        urllib3.exceptions.HTTPError: Various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413), too
            many requests (HTTP 429), service unavailable (HTTP 503).

    Note that the fileobj might be partially filled with data in case of
    an exception.
    """
    if pool is None:
        pool = get_pool()
    if headers_factory is not None:
        base_headers = headers_factory()
        if any(x.lower() == "range" for x in base_headers):
            raise ValueError("Cannot set the Range header through headers_factory")
    else:
        base_headers = {}

    # Our strategy here is to just start downloading chunks while monitoring
    # the Content-Range header to check if we're done. Although we could get
    # the total Content-Length from a HEAD request, not all servers support
    # that (e.g. MinIO).
    start = 0
    while True:
        # download a chunk
        stop = start + chunk_size - 1
        headers = {"Range": "bytes={}-{}".format(start, stop), **base_headers}
        response = pool.request(
            "GET",
            url,
            headers=headers,
            timeout=timeout,
        )
        if response.status == HTTPStatus.OK:
            raise ApiException(
                "The file server does not support multipart downloads.",
                status=response.status,
            )
        elif response.status != HTTPStatus.PARTIAL_CONTENT:
            raise ApiException("Unexpected status", status=response.status)

        # write to file
        fileobj.write(response.data)

        # parse the Content-Range header (e.g. "bytes 0-3/7") for the next
        # iteration; a total of "*" (size unknown) cannot be handled here
        match = CONTENT_RANGE_REGEXP.match(response.headers["Content-Range"])
        if match is None or match.group(3) == "*":
            raise ApiException(
                "Unexpected Content-Range in server response",
                status=response.status,
            )
        start, stop, total = (int(x) for x in match.groups())
        if callable(callback_func):
            # stop is a zero-based inclusive index, so stop + 1 bytes are done
            callback_func(min(stop + 1, total), total)
        if stop + 1 >= total:
            break
        # continue after the last byte the server actually returned (which
        # may be fewer than the number of bytes requested)
        start = stop + 1

    return total
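
# Illustrative sketch (not part of the original module): downloading into an
# in-memory buffer instead of a file on disk. The url is hypothetical.
def _example_download_to_memory() -> None:
    import io

    buffer = io.BytesIO()
    size = download_fileobj("https://example.com/files/small.json", buffer)
    data = buffer.getvalue()
    assert len(data) == size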

def upload_file(
    url: str,
    file_path: Path,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = None,
    pool: Optional[urllib3.PoolManager] = None,
    md5: Optional[bytes] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> int:
    """Upload the file at a specified path to a url.

    The upload can be accompanied by an MD5 hash so that the file server
    checks the integrity of the file.

    Args:
        url: The url to upload to.
        file_path: The file path to read data from.
        chunk_size: The size of the chunks in the streaming upload. Note that
            this function does not do multipart upload. Default: 16MB.
        timeout: The total timeout in seconds. The default is a connect
            timeout of 5 seconds and a read timeout of 10 minutes.
        pool: If not supplied, a default connection pool will be created with
            a retry policy of 3 retries after 1, 2, 4 seconds.
        md5: The MD5 digest (binary) of the file. Supply the MD5 to enable a
            server-side integrity check. Note that when using presigned urls
            in AWS S3, the md5 hash should be included in the signing
            procedure.
        callback_func: Optional function that receives (bytes_uploaded,
            total_bytes), for example:
            def callback(bytes_uploaded: int, total_bytes: int) -> None
        headers_factory: Optional function to inject headers.

    Returns:
        The total number of uploaded bytes.

    Raises:
        IOError: Raised if the provided file is incompatible or empty.
        ApiException: Raised on unexpected server responses (HTTP status
            codes other than 200, 201, 413, 429, 503).
        urllib3.exceptions.HTTPError: Various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413), too
            many requests (HTTP 429), service unavailable (HTTP 503).
    """
    # cast string to Path if necessary
    if isinstance(file_path, str):
        file_path = Path(file_path)

    # open the file
    with file_path.open("rb") as fileobj:
        size = upload_fileobj(
            url,
            fileobj,
            chunk_size=chunk_size,
            timeout=timeout,
            pool=pool,
            md5=md5,
            callback_func=callback_func,
            headers_factory=headers_factory,
        )

    return size
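
# Illustrative sketch (not part of the original module): uploading with a
# server-side integrity check. The path and url are hypothetical; for
# presigned S3 urls the md5 must also be part of the signing procedure.
def _example_upload_with_md5() -> None:
    path = Path("example.dat")
    with path.open("rb") as fileobj:
        digest = compute_md5(fileobj)
    upload_file("https://example.com/upload/example.dat", path, md5=digest)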

def _iter_chunks(
    fileobj: BinaryIO,
    chunk_size: int,
    callback_func: Optional[Callable[[int], None]] = None,
):
    """Yield chunks from a file stream."""
    assert chunk_size > 0
    uploaded_bytes: int = 0
    while True:
        data = fileobj.read(chunk_size)
        if len(data) == 0:
            break
        # count the bytes actually read; the last chunk may be smaller than
        # chunk_size
        uploaded_bytes += len(data)
        if callable(callback_func):
            callback_func(uploaded_bytes)
        yield data

class _SeekableChunkIterator:
    """A chunk iterator that can be rewound in case of urllib3 retries."""

    def __init__(
        self,
        fileobj: BinaryIO,
        chunk_size: int,
        callback_func: Optional[Callable[[int], None]] = None,
    ):
        self.fileobj = fileobj
        self.chunk_size = chunk_size
        self.callback_func = callback_func

    def seek(self, pos: int):
        return self.fileobj.seek(pos)

    def tell(self):
        return self.fileobj.tell()

    def __iter__(self):
        return _iter_chunks(self.fileobj, self.chunk_size, self.callback_func)

def upload_fileobj(
    url: str,
    fileobj: BinaryIO,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = None,
    pool: Optional[urllib3.PoolManager] = None,
    md5: Optional[bytes] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> int:
    """Upload a file object to a url.

    The upload can be accompanied by an MD5 hash so that the file server
    checks the integrity of the file.

    Args:
        url: The url to upload to.
        fileobj: The (binary) file object to read from.
        chunk_size: The size of the chunks in the streaming upload. Note that
            this function does not do multipart upload. Default: 16MB.
        timeout: The total timeout in seconds. The default is a connect
            timeout of 5 seconds and a read timeout of 10 minutes.
        pool: If not supplied, a default connection pool will be created with
            a retry policy of 3 retries after 1, 2, 4 seconds.
        md5: The MD5 digest (binary) of the file. Supply the MD5 to enable a
            server-side integrity check. Note that when using presigned urls
            in AWS S3, the md5 hash should be included in the signing
            procedure.
        callback_func: Optional function that receives (bytes_uploaded,
            total_bytes), for example:
            def callback(bytes_uploaded: int, total_bytes: int) -> None
        headers_factory: Optional function to inject headers.

    Returns:
        The total number of uploaded bytes.

    Raises:
        IOError: Raised if the provided file is incompatible or empty.
        ApiException: Raised on unexpected server responses (HTTP status
            codes other than 200, 201, 413, 429, 503).
        urllib3.exceptions.HTTPError: Various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413), too
            many requests (HTTP 429), service unavailable (HTTP 503).
    """
    # There are two ways to upload to S3 (MinIO):
    # - PutObject: put the whole object in one request
    # - multipart upload: requires presigned urls for every part
    # We can only do the first option, as we have no presigned urls for the
    # parts. We do, however, stream the request body in chunks.

    # We will get hard-to-understand tracebacks if the fileobj is not in
    # binary mode, so use a trick to check whether it is:
    if not isinstance(fileobj.read(0), bytes):
        raise IOError(
            "The file object is not in binary mode. Please open with mode='rb'."
        )

    file_size = fileobj.seek(0, 2)  # go to EOF to get the file size
    if file_size == 0:
        raise IOError("The file object is empty.")

    if pool is None:
        pool = get_pool()

    fileobj.seek(0)

    def callback(uploaded_bytes: int):
        if callable(callback_func):
            # clamp to the file size as a safeguard
            if uploaded_bytes > file_size:
                uploaded_bytes = file_size
            callback_func(uploaded_bytes, file_size)

    iterable = _SeekableChunkIterator(
        fileobj,
        chunk_size=chunk_size,
        callback_func=callback,
    )

    # Tested: both Content-Length and Content-MD5 are checked by MinIO.
    headers = {
        "Content-Length": str(file_size),
    }
    if md5 is not None:
        headers["Content-MD5"] = base64.b64encode(md5).decode()
    if headers_factory is not None:
        headers.update(headers_factory())

    response = pool.request(
        "PUT",
        url,
        body=iterable,
        headers=headers,
        timeout=DEFAULT_UPLOAD_TIMEOUT if timeout is None else timeout,
    )
    if response.status not in {HTTPStatus.OK, HTTPStatus.CREATED}:
        raise ApiException("Unexpected status", status=response.status)

    return file_size
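
# Illustrative sketch (not part of the original module): injecting freshly
# generated headers (e.g. a short-lived auth token) per request through
# headers_factory. The token value and url are hypothetical.
def _example_upload_with_auth() -> None:
    import io

    def make_headers() -> Dict[str, str]:
        token = "..."  # fetch or refresh a token here
        return {"Authorization": "Bearer " + token}

    upload_fileobj(
        "https://example.com/upload/blob.bin",
        io.BytesIO(b"some payload"),
        headers_factory=make_headers,
    )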