files.py

import base64
import hashlib
import logging
import os
import re
from http import HTTPStatus
from pathlib import Path
from typing import BinaryIO
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
from urllib.parse import urlparse

import urllib3

from .exceptions import ApiException

__all__ = ["download_file", "download_fileobj", "upload_file", "upload_fileobj"]

CONTENT_RANGE_REGEXP = re.compile(r"^bytes (\d+)-(\d+)/(\d+|\*)$")
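# Example: "bytes 0-16777215/52428800" parses into the groups
# ("0", "16777215", "52428800"). The "/*" alternative is allowed by RFC 7233
# for an unknown total size, although the download loop below expects a
# numeric total.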

# The default upload timeout has an increased socket read timeout, because
# MinIO takes very long to complete the upload of larger files. The limit of
# 10 minutes should accommodate files up to 150 GB.
DEFAULT_UPLOAD_TIMEOUT = urllib3.Timeout(connect=5.0, read=600.0)

logger = logging.getLogger(__name__)


def get_pool(retries: int = 3, backoff_factor: float = 1.0) -> urllib3.PoolManager:
    """Create a PoolManager with a retry policy.

    The default retry policy has 3 retries with 1, 2, 4 second intervals.

    Args:
        retries: Total number of retries per request.
        backoff_factor: Multiplier for the retry delay times (1, 2, 4, ...).
    """
    return urllib3.PoolManager(
        retries=urllib3.util.retry.Retry(retries, backoff_factor=backoff_factor)
    )
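
# A more patient pool can be passed into the download/upload functions below,
# e.g. for flaky connections (a sketch; the numbers are illustrative):
#
#     pool = get_pool(retries=5, backoff_factor=2.0)
#     path, size = download_file(url, Path("."), pool=pool)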


def compute_md5(fileobj: BinaryIO, chunk_size: int = 16777216) -> bytes:
    """Compute the MD5 checksum of a file object."""
    fileobj.seek(0)
    hasher = hashlib.md5()
    for chunk in _iter_chunks(fileobj, chunk_size=chunk_size):
        hasher.update(chunk)
    return hasher.digest()
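
# The binary digest fits the md5= parameter of upload_file / upload_fileobj
# (a sketch; "results.zip" is a placeholder path):
#
#     with open("results.zip", "rb") as f:
#         digest = compute_md5(f)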


def download_file(
    url: str,
    target: Path,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
    pool: Optional[urllib3.PoolManager] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> Tuple[Path, int]:
    """Download a file to a specified path on disk.

    It is assumed that the file server supports multipart downloads (range
    requests).

    Args:
        url: The url to retrieve.
        target: The location to copy to. If this is an existing file, it is
            overwritten. If it is a directory, a filename is generated from
            the filename in the url.
        chunk_size: The number of bytes per request. Default: 16MB.
        timeout: The total timeout in seconds.
        pool: If not supplied, a default connection pool will be
            created with a retry policy of 3 retries after 1, 2, 4 seconds.
        callback_func: Optional function that receives (bytes_downloaded,
            total_bytes), for example:
            def callback(bytes_downloaded: int, total_bytes: int) -> None
        headers_factory: Optional function that returns extra request headers.

    Returns:
        Tuple of file path, total number of downloaded bytes.

    Raises:
        ApiException: raised on unexpected server
            responses (HTTP status codes other than 206, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that persist
            after retrying: connection errors, timeouts, decode errors,
            invalid HTTP headers, payload too large (HTTP 413), too many
            requests (HTTP 429), service unavailable (HTTP 503)
    """
    # cast string to Path if necessary
    if isinstance(target, str):
        target = Path(target)

    # if it is a directory, take the filename from the url
    if target.is_dir():
        target = target / urlparse(url).path.rsplit("/", 1)[-1]

    # open the file
    try:
        with target.open("wb") as fileobj:
            size = download_fileobj(
                url,
                fileobj,
                chunk_size=chunk_size,
                timeout=timeout,
                pool=pool,
                callback_func=callback_func,
                headers_factory=headers_factory,
            )
    except Exception:
        # Clean up a partially downloaded file
        try:
            os.remove(target)
        except FileNotFoundError:
            pass
        raise
    return target, size
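
# Example usage (a sketch; the url is a placeholder):
#
#     path, nbytes = download_file(
#         "https://files.example.com/rasters/dem.tif",
#         Path("/tmp"),  # a directory: the filename is taken from the url
#         callback_func=lambda done, total: print(f"{done}/{total} bytes"),
#     )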


def download_fileobj(
    url: str,
    fileobj: BinaryIO,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
    pool: Optional[urllib3.PoolManager] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> int:
    """Download a url to a file object using multiple requests.

    It is assumed that the file server supports multipart downloads (range
    requests).

    Args:
        url: The url to retrieve.
        fileobj: The (binary) file object to write into.
        chunk_size: The number of bytes per request. Default: 16MB.
        timeout: The total timeout in seconds.
        pool: If not supplied, a default connection pool will be
            created with a retry policy of 3 retries after 1, 2, 4 seconds.
        callback_func: Optional function that receives (bytes_downloaded,
            total_bytes), for example:
            def callback(bytes_downloaded: int, total_bytes: int) -> None
        headers_factory: Optional function that returns extra request headers.

    Returns:
        The total number of downloaded bytes.

    Raises:
        ApiException: raised on unexpected server
            responses (HTTP status codes other than 206, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that persist
            after retrying: connection errors, timeouts, decode errors,
            invalid HTTP headers, payload too large (HTTP 413), too many
            requests (HTTP 429), service unavailable (HTTP 503)

    Note that the fileobj might be partially filled with data in case of
    an exception.
    """
    if pool is None:
        pool = get_pool()
    if headers_factory is not None:
        base_headers = headers_factory()
        if any(x.lower() == "range" for x in base_headers):
            raise ValueError("Cannot set the Range header through headers_factory")
    else:
        base_headers = {}

    # Our strategy here is to just start downloading chunks while monitoring
    # the Content-Range header to check if we're done. Although we could get
    # the total Content-Length from a HEAD request, not all servers support
    # that (e.g. MinIO).
    start = 0
    while True:
        # download a chunk
        stop = start + chunk_size - 1
        headers = {"Range": "bytes={}-{}".format(start, stop), **base_headers}
        response = pool.request(
            "GET",
            url,
            headers=headers,
            timeout=timeout,
        )
        if response.status == HTTPStatus.OK:
            raise ApiException(
                "The file server does not support multipart downloads.",
                status=response.status,
            )
        elif response.status != HTTPStatus.PARTIAL_CONTENT:
            raise ApiException("Unexpected status", status=response.status)

        # write to file
        fileobj.write(response.data)

        # parse content-range header (e.g. "bytes 0-3/7") for next iteration
        content_range = response.headers["Content-Range"]
        start, stop, total = [
            int(x) for x in CONTENT_RANGE_REGEXP.findall(content_range)[0]
        ]
        if callable(callback_func):
            # stop is an inclusive byte index, so stop + 1 bytes are done
            downloaded_bytes: int = min(stop + 1, total)
            callback_func(downloaded_bytes, total)
        if stop + 1 >= total:
            break
        start += chunk_size
    return total
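
# Example usage (a sketch; io.BytesIO keeps the download in memory):
#
#     import io
#
#     buf = io.BytesIO()
#     nbytes = download_fileobj("https://files.example.com/rasters/dem.tif", buf)
#     data = buf.getvalue()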


def upload_file(
    url: str,
    file_path: Path,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = None,
    pool: Optional[urllib3.PoolManager] = None,
    md5: Optional[bytes] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> int:
    """Upload a file at a specified file path to a url.

    The upload is accompanied by an MD5 hash so that the file server checks
    the integrity of the file.

    Args:
        url: The url to upload to.
        file_path: The file path to read data from.
        chunk_size: The size of the chunk in the streaming upload. Note that this
            function does not do multipart upload. Default: 16MB.
        timeout: The total timeout in seconds. The default is a connect timeout of
            5 seconds and a read timeout of 10 minutes.
        pool: If not supplied, a default connection pool will be
            created with a retry policy of 3 retries after 1, 2, 4 seconds.
        md5: The MD5 digest (binary) of the file. Supply the MD5 to enable a
            server-side integrity check. Note that when using presigned urls in
            AWS S3, the md5 hash should be included in the signing procedure.
        callback_func: Optional function that receives (bytes_uploaded,
            total_bytes), for example:
            def callback(bytes_uploaded: int, total_bytes: int) -> None
        headers_factory: Optional function that returns extra request headers.

    Returns:
        The total number of uploaded bytes.

    Raises:
        IOError: Raised if the provided file is incompatible or empty.
        ApiException: raised on unexpected server
            responses (HTTP status codes other than 200, 201, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that persist
            after retrying: connection errors, timeouts, decode errors,
            invalid HTTP headers, payload too large (HTTP 413), too many
            requests (HTTP 429), service unavailable (HTTP 503)
    """
    # cast string to Path if necessary
    if isinstance(file_path, str):
        file_path = Path(file_path)

    # open the file
    with file_path.open("rb") as fileobj:
        size = upload_fileobj(
            url,
            fileobj,
            chunk_size=chunk_size,
            timeout=timeout,
            pool=pool,
            md5=md5,
            callback_func=callback_func,
            headers_factory=headers_factory,
        )
    return size
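
# Example usage (a sketch; the presigned url is a placeholder):
#
#     with open("results.zip", "rb") as f:
#         digest = compute_md5(f)
#     upload_file(
#         "https://bucket.example.com/results.zip?X-Amz-Signature=...",
#         Path("results.zip"),
#         md5=digest,
#     )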


def _iter_chunks(
    fileobj: BinaryIO,
    chunk_size: int,
    callback_func: Optional[Callable[[int], None]] = None,
):
    """Yield chunks from a file stream."""
    uploaded_bytes: int = 0
    assert chunk_size > 0
    while True:
        data = fileobj.read(chunk_size)
        if len(data) == 0:
            break
        # count what was actually read; the last chunk may be smaller
        # than chunk_size
        uploaded_bytes += len(data)
        if callable(callback_func):
            callback_func(uploaded_bytes)
        yield data


class _SeekableChunkIterator:
    """A chunk iterator that can be rewound in case of urllib3 retries."""

    def __init__(
        self,
        fileobj: BinaryIO,
        chunk_size: int,
        callback_func: Optional[Callable[[int], None]] = None,
    ):
        self.fileobj = fileobj
        self.chunk_size = chunk_size
        self.callback_func = callback_func

    def seek(self, pos: int):
        return self.fileobj.seek(pos)

    def tell(self):
        return self.fileobj.tell()

    def __iter__(self):
        return _iter_chunks(self.fileobj, self.chunk_size, self.callback_func)
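
# urllib3 rewinds a file-like request body via tell()/seek() when it retries
# a request; exposing both is what makes the streamed upload below retry-safe.
# Illustrative sequence (a sketch, not a real transport):
#
#     it = _SeekableChunkIterator(fileobj, chunk_size=4096)
#     pos = it.tell()         # position recorded before the first attempt
#     first = b"".join(it)    # body streamed chunk by chunk
#     it.seek(pos)            # what urllib3 does before a retry
#     second = b"".join(it)   # the retry re-reads the same bytes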


def upload_fileobj(
    url: str,
    fileobj: BinaryIO,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = None,
    pool: Optional[urllib3.PoolManager] = None,
    md5: Optional[bytes] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
    headers_factory: Optional[Callable[[], Dict[str, str]]] = None,
) -> int:
    """Upload a file object to a url.

    The upload is accompanied by an MD5 hash so that the file server checks
    the integrity of the file.

    Args:
        url: The url to upload to.
        fileobj: The (binary) file object to read from.
        chunk_size: The size of the chunk in the streaming upload. Note that this
            function does not do multipart upload. Default: 16MB.
        timeout: The total timeout in seconds. The default is a connect timeout of
            5 seconds and a read timeout of 10 minutes.
        pool: If not supplied, a default connection pool will be
            created with a retry policy of 3 retries after 1, 2, 4 seconds.
        md5: The MD5 digest (binary) of the file. Supply the MD5 to enable a
            server-side integrity check. Note that when using presigned urls in
            AWS S3, the md5 hash should be included in the signing procedure.
        callback_func: Optional function that receives (bytes_uploaded,
            total_bytes), for example:
            def callback(bytes_uploaded: int, total_bytes: int) -> None
        headers_factory: Optional function that returns extra request headers.

    Returns:
        The total number of uploaded bytes.

    Raises:
        IOError: Raised if the provided file is incompatible or empty.
        ApiException: raised on unexpected server
            responses (HTTP status codes other than 200, 201, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that persist
            after retrying: connection errors, timeouts, decode errors,
            invalid HTTP headers, payload too large (HTTP 413), too many
            requests (HTTP 429), service unavailable (HTTP 503)
    """
    # There are two ways to upload in S3 (MinIO):
    # - PutObject: put the whole object in one request
    # - multipart upload: requires presigned urls for every part
    # We can only do the first option, as we have no presigned urls for the
    # individual parts. But we do stream the request body in chunks.

    # We will get hard-to-understand tracebacks if the fileobj is not
    # in binary mode. So use a trick to see if fileobj is in binary mode:
    if not isinstance(fileobj.read(0), bytes):
        raise IOError(
            "The file object is not in binary mode. Please open with mode='rb'."
        )

    file_size = fileobj.seek(0, os.SEEK_END)  # go to EOF
    if file_size == 0:
        raise IOError("The file object is empty.")

    if pool is None:
        pool = get_pool()

    fileobj.seek(0)

    def callback(uploaded_bytes: int):
        if callable(callback_func):
            if uploaded_bytes > file_size:
                uploaded_bytes = file_size
            callback_func(uploaded_bytes, file_size)

    iterable = _SeekableChunkIterator(
        fileobj,
        chunk_size=chunk_size,
        callback_func=callback,
    )

    # Tested: both Content-Length and Content-MD5 are checked by MinIO
    headers = {
        "Content-Length": str(file_size),
    }
    if md5 is not None:
        headers["Content-MD5"] = base64.b64encode(md5).decode()
    if headers_factory is not None:
        headers.update(headers_factory())
    response = pool.request(
        "PUT",
        url,
        body=iterable,
        headers=headers,
        timeout=DEFAULT_UPLOAD_TIMEOUT if timeout is None else timeout,
    )
    if response.status not in {HTTPStatus.OK, HTTPStatus.CREATED}:
        raise ApiException("Unexpected status", status=response.status)
    return file_size
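
# Example usage (a sketch; the url is a placeholder):
#
#     import io
#
#     payload = io.BytesIO(b"some binary payload")
#     digest = compute_md5(payload)
#     upload_fileobj(
#         "https://bucket.example.com/payload.bin?X-Amz-Signature=...",
#         payload,
#         md5=digest,
#     )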