# files.py
import base64
import hashlib
import logging
import os
import re
from pathlib import Path
from typing import BinaryIO
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import Union
from urllib.parse import urlparse

import urllib3

from .exceptions import ApiException

__all__ = ["download_file", "download_fileobj", "upload_file", "upload_fileobj"]

CONTENT_RANGE_REGEXP = re.compile(r"^bytes (\d+)-(\d+)/(\d+|\*)$")

# The default upload timeout has an increased socket read timeout, because
# MinIO can take very long to complete the upload of larger files. The limit
# of 10 minutes should accommodate files up to 150 GB.
DEFAULT_UPLOAD_TIMEOUT = urllib3.Timeout(connect=5.0, read=600.0)

logger = logging.getLogger(__name__)


def get_pool(retries: int = 3, backoff_factor: float = 1.0) -> urllib3.PoolManager:
    """Create a PoolManager with a retry policy.

    The default retry policy has 3 retries with 1, 2, 4 second intervals.

    Args:
        retries: Total number of retries per request.
        backoff_factor: Multiplier for retry delay times (1, 2, 4, ...).
    """
    return urllib3.PoolManager(
        retries=urllib3.util.retry.Retry(retries, backoff_factor=backoff_factor)
    )
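
# A minimal usage sketch: a single pool can be shared between transfers to
# reuse connections. The retry values and urls here are illustrative, not
# recommendations.
#
#     pool = get_pool(retries=5, backoff_factor=0.5)
#     download_file(url_a, Path("/tmp"), pool=pool)
#     download_file(url_b, Path("/tmp"), pool=pool)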


def compute_md5(fileobj: BinaryIO, chunk_size: int = 16777216) -> bytes:
    """Compute the MD5 checksum of a file object, reading it in chunks."""
    fileobj.seek(0)
    hasher = hashlib.md5()
    for chunk in _iter_chunks(fileobj, chunk_size=chunk_size):
        hasher.update(chunk)
    return hasher.digest()
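
# A minimal sketch of computing a digest for upload; "raster.tif" is a
# hypothetical path. The binary digest goes into the ``md5`` argument of
# upload_file / upload_fileobj, which base64-encode it for the Content-MD5
# header.
#
#     with Path("raster.tif").open("rb") as f:
#         digest = compute_md5(f)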


def download_file(
    url: str,
    target: Path,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
    pool: Optional[urllib3.PoolManager] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
) -> Tuple[Path, int]:
    """Download a file to a specified path on disk.

    It is assumed that the file server supports multipart downloads (range
    requests).

    Args:
        url: The url to retrieve.
        target: The location to copy to. If this is an existing file, it is
            overwritten. If it is a directory, a filename is generated from
            the filename in the url.
        chunk_size: The number of bytes per request. Default: 16MB.
        timeout: The timeout of each request, in seconds.
        pool: If not supplied, a default connection pool will be created
            with a retry policy of 3 retries after 1, 2, 4 seconds.
        callback_func: Optional function that receives (bytes_downloaded,
            total_bytes), for example:
            def callback(bytes_downloaded: int, total_bytes: int) -> None

    Returns:
        Tuple of file path, total number of downloaded bytes.

    Raises:
        ApiException: raised on unexpected server responses (HTTP status
            codes other than 206, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413),
            too many requests (HTTP 429), service unavailable (HTTP 503)
    """
    # Cast string to Path if necessary.
    if isinstance(target, str):
        target = Path(target)

    # If the target is a directory, take the filename from the url.
    if target.is_dir():
        target = target / urlparse(url).path.rsplit("/", 1)[-1]

    # Open the file and stream the download into it.
    try:
        with target.open("wb") as fileobj:
            size = download_fileobj(
                url,
                fileobj,
                chunk_size=chunk_size,
                timeout=timeout,
                pool=pool,
                callback_func=callback_func,
            )
    except Exception:
        # Clean up a partially downloaded file.
        try:
            os.remove(target)
        except FileNotFoundError:
            pass
        raise

    return target, size
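
# A minimal usage sketch; the url, target directory and progress callback
# are illustrative. Because the target is a directory, the filename is taken
# from the url (here: "data.tif").
#
#     def progress(done: int, total: int) -> None:
#         print(f"downloaded {done} of {total} bytes")
#
#     path, size = download_file(
#         "https://files.example.com/data.tif",
#         Path("/tmp"),
#         callback_func=progress,
#     )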


def download_fileobj(
    url: str,
    fileobj: BinaryIO,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = 5.0,
    pool: Optional[urllib3.PoolManager] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
) -> int:
    """Download a url to a file object using multiple requests.

    It is assumed that the file server supports multipart downloads (range
    requests).

    Args:
        url: The url to retrieve.
        fileobj: The (binary) file object to write into.
        chunk_size: The number of bytes per request. Default: 16MB.
        timeout: The timeout of each request, in seconds.
        pool: If not supplied, a default connection pool will be created
            with a retry policy of 3 retries after 1, 2, 4 seconds.
        callback_func: Optional function that receives (bytes_downloaded,
            total_bytes), for example:
            def callback(bytes_downloaded: int, total_bytes: int) -> None

    Returns:
        The total number of downloaded bytes.

    Raises:
        ApiException: raised on unexpected server responses (HTTP status
            codes other than 206, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413),
            too many requests (HTTP 429), service unavailable (HTTP 503)

    Note that the fileobj might be partially filled with data in case of
    an exception.
    """
    if pool is None:
        pool = get_pool()

    # Our strategy here is to just start downloading chunks while monitoring
    # the Content-Range header to check if we're done. Although we could get
    # the total Content-Length from a HEAD request, not all servers support
    # that (e.g. MinIO).
    start = 0
    while True:
        # Download a chunk.
        stop = start + chunk_size - 1
        headers = {"Range": "bytes={}-{}".format(start, stop)}
        response = pool.request(
            "GET",
            url,
            headers=headers,
            timeout=timeout,
        )
        if response.status == 200:
            raise ApiException(
                "The file server does not support multipart downloads.",
                status=response.status,
            )
        elif response.status != 206:
            raise ApiException("Unexpected status", status=response.status)

        # Write to the file object.
        fileobj.write(response.data)

        # Parse the Content-Range header (e.g. "bytes 0-3/7") to decide
        # whether we are done. A missing header or an unknown total size
        # ("bytes 0-3/*") cannot be handled by this strategy.
        content_range = response.headers.get("Content-Range", "")
        match = CONTENT_RANGE_REGEXP.match(content_range)
        if match is None or match.group(3) == "*":
            raise ApiException(
                "The file server returned an unexpected Content-Range.",
                status=response.status,
            )
        start, stop, total = (int(x) for x in match.groups())
        if callable(callback_func):
            # Report the number of bytes downloaded so far, capped at total.
            callback_func(min(stop + 1, total), total)
        if stop + 1 >= total:
            break
        start += chunk_size

    return total
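
# A minimal sketch of downloading into an in-memory buffer; the url is
# illustrative.
#
#     import io
#
#     buf = io.BytesIO()
#     nbytes = download_fileobj("https://files.example.com/data.tif", buf)
#     buf.seek(0)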


def upload_file(
    url: str,
    file_path: Path,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = None,
    pool: Optional[urllib3.PoolManager] = None,
    md5: Optional[bytes] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
) -> int:
    """Upload a file at a specified file path to a url.

    The upload is accompanied by an MD5 hash so that the file server can
    check the integrity of the file.

    Args:
        url: The url to upload to.
        file_path: The file path to read data from.
        chunk_size: The size of the chunks in the streaming upload. Note
            that this function does not do multipart upload. Default: 16MB.
        timeout: The total timeout in seconds. The default is a connect
            timeout of 5 seconds and a read timeout of 10 minutes.
        pool: If not supplied, a default connection pool will be created
            with a retry policy of 3 retries after 1, 2, 4 seconds.
        md5: The MD5 digest (binary) of the file. Supply the MD5 to enable a
            server-side integrity check. Note that when using presigned urls
            in AWS S3, the md5 hash should be included in the signing
            procedure.
        callback_func: Optional function that receives (bytes_uploaded,
            total_bytes), for example:
            def callback(bytes_uploaded: int, total_bytes: int) -> None

    Returns:
        The total number of uploaded bytes.

    Raises:
        IOError: Raised if the provided file is incompatible or empty.
        ApiException: raised on unexpected server responses (HTTP status
            codes other than 200, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413),
            too many requests (HTTP 429), service unavailable (HTTP 503)
    """
    # Cast string to Path if necessary.
    if isinstance(file_path, str):
        file_path = Path(file_path)

    # Open the file and stream it to the server.
    with file_path.open("rb") as fileobj:
        size = upload_fileobj(
            url,
            fileobj,
            chunk_size=chunk_size,
            timeout=timeout,
            pool=pool,
            md5=md5,
            callback_func=callback_func,
        )

    return size
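
# A minimal sketch combining compute_md5 and upload_file; the path and the
# presigned url are illustrative.
#
#     path = Path("raster.tif")
#     with path.open("rb") as f:
#         digest = compute_md5(f)
#     upload_file(
#         "https://files.example.com/presigned-put-url", path, md5=digest
#     )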


def _iter_chunks(
    fileobj: BinaryIO,
    chunk_size: int,
    callback_func: Optional[Callable[[int], None]] = None,
):
    """Yield chunks from a file stream."""
    assert chunk_size > 0
    uploaded_bytes: int = 0
    while True:
        data = fileobj.read(chunk_size)
        if len(data) == 0:
            break
        # Count the actual number of bytes read; the last chunk may be
        # smaller than chunk_size.
        uploaded_bytes += len(data)
        if callable(callback_func):
            callback_func(uploaded_bytes)
        yield data


class _SeekableChunkIterator:
    """A chunk iterator that can be rewound in case of urllib3 retries."""

    def __init__(
        self,
        fileobj: BinaryIO,
        chunk_size: int,
        callback_func: Optional[Callable[[int], None]] = None,
    ):
        self.fileobj = fileobj
        self.chunk_size = chunk_size
        self.callback_func = callback_func

    def seek(self, pos: int):
        return self.fileobj.seek(pos)

    def tell(self):
        return self.fileobj.tell()

    def __iter__(self):
        return _iter_chunks(self.fileobj, self.chunk_size, self.callback_func)


def upload_fileobj(
    url: str,
    fileobj: BinaryIO,
    chunk_size: int = 16777216,
    timeout: Optional[Union[float, urllib3.Timeout]] = None,
    pool: Optional[urllib3.PoolManager] = None,
    md5: Optional[bytes] = None,
    callback_func: Optional[Callable[[int, int], None]] = None,
) -> int:
    """Upload a file object to a url.

    The upload is accompanied by an MD5 hash so that the file server can
    check the integrity of the file.

    Args:
        url: The url to upload to.
        fileobj: The (binary) file object to read from.
        chunk_size: The size of the chunks in the streaming upload. Note
            that this function does not do multipart upload. Default: 16MB.
        timeout: The total timeout in seconds. The default is a connect
            timeout of 5 seconds and a read timeout of 10 minutes.
        pool: If not supplied, a default connection pool will be created
            with a retry policy of 3 retries after 1, 2, 4 seconds.
        md5: The MD5 digest (binary) of the file. Supply the MD5 to enable a
            server-side integrity check. Note that when using presigned urls
            in AWS S3, the md5 hash should be included in the signing
            procedure.
        callback_func: Optional function that receives (bytes_uploaded,
            total_bytes), for example:
            def callback(bytes_uploaded: int, total_bytes: int) -> None

    Returns:
        The total number of uploaded bytes.

    Raises:
        IOError: Raised if the provided file is incompatible or empty.
        ApiException: raised on unexpected server responses (HTTP status
            codes other than 200, 413, 429, 503)
        urllib3.exceptions.HTTPError: various low-level HTTP errors that
            persist after retrying: connection errors, timeouts, decode
            errors, invalid HTTP headers, payload too large (HTTP 413),
            too many requests (HTTP 429), service unavailable (HTTP 503)
    """
    # There are two ways to upload in S3 (MinIO):
    # - PutObject: put the whole object in one request
    # - multipart upload: requires presigned urls for every part
    # We can only do the first option as we have no other presigned urls.
    # So we take the first option, but we do stream the request body in
    # chunks.

    # We will get hard-to-understand tracebacks if the fileobj is not in
    # binary mode. So use a trick to see if fileobj is in binary mode:
    if not isinstance(fileobj.read(0), bytes):
        raise IOError(
            "The file object is not in binary mode. Please open with mode='rb'."
        )

    file_size = fileobj.seek(0, 2)  # go to EOF to determine the file size
    if file_size == 0:
        raise IOError("The file object is empty.")

    if pool is None:
        pool = get_pool()

    fileobj.seek(0)

    def callback(uploaded_bytes: int):
        if callable(callback_func):
            if uploaded_bytes > file_size:
                uploaded_bytes = file_size
            callback_func(uploaded_bytes, file_size)

    iterable = _SeekableChunkIterator(
        fileobj,
        chunk_size=chunk_size,
        callback_func=callback,
    )

    # Tested: both Content-Length and Content-MD5 are checked by MinIO.
    headers = {
        "Content-Length": str(file_size),
    }
    if md5 is not None:
        headers["Content-MD5"] = base64.b64encode(md5).decode()
    response = pool.request(
        "PUT",
        url,
        body=iterable,
        headers=headers,
        timeout=DEFAULT_UPLOAD_TIMEOUT if timeout is None else timeout,
    )
    if response.status != 200:
        raise ApiException("Unexpected status", status=response.status)

    return file_size
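
# A minimal sketch of uploading an in-memory buffer; the url is illustrative.
#
#     import io
#
#     data = io.BytesIO(b"some binary payload")
#     upload_fileobj("https://files.example.com/presigned-put-url", data)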