# -*- coding: utf-8 -*-
# (c) Nelen & Schuurmans
import logging
from pathlib import Path
from typing import List
from typing import Optional

import inject
from botocore.exceptions import ClientError
from pydantic import AnyHttpUrl

from clean_python import ctx
from clean_python import DoesNotExist
from clean_python import Filter
from clean_python import Gateway
from clean_python import Id
from clean_python import Json
from clean_python import PageOptions

from .s3_provider import S3BucketProvider

DEFAULT_EXPIRY = 3600  # in seconds
DEFAULT_TIMEOUT = 1.0
AWS_LIMIT = 1000  # max S3 keys per request

__all__ = ["S3Gateway"]

logger = logging.getLogger(__name__)


class S3Gateway(Gateway):
  24. """The interface to S3 Buckets.
  25. The standard Gateway interface is only partially implemented:
  26. - get() and filter() return metadata
  27. - add(), update(), upsert() are not implemented
  28. - remove() works as expected
  29. For actually getting the object data either use the download_file()
  30. or upload_file() or create a presigned url and hand that over to
  31. the client.
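
    Example (a minimal sketch; assumes an S3BucketProvider has been
    registered with `inject` during application setup, and ids and
    paths here are illustrative):

        gateway = S3Gateway()
        await gateway.upload_file("some-id", Path("/tmp/data.bin"))
        metadata = await gateway.get("some-id")
        download_url = await gateway.create_download_url("some-id")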
  32. """

    def __init__(
        self,
        provider_override: Optional[S3BucketProvider] = None,
        multitenant: bool = False,
    ):
        self.provider_override = provider_override
        self.multitenant = multitenant

    @property
    def provider(self) -> S3BucketProvider:
        return self.provider_override or inject.instance(S3BucketProvider)

    def _id_to_key(self, id: Id) -> str:
        if not self.multitenant:
            return str(id)
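        # In multitenant mode, every key is namespaced per tenant:
        # "tenant-<tenant id>/<object id>".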
        if ctx.tenant is None:
            raise RuntimeError(f"{self.__class__} requires a tenant in the context")
        return f"tenant-{ctx.tenant.id}/{id}"

    def _key_to_id(self, key: str) -> Id:
        return key.split("/", 1)[1] if self.multitenant else key

    async def get(self, id: Id) -> Optional[Json]:
        async with self.provider.client as client:
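            # HeadObject fetches the object's metadata without
            # transferring its body.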
            try:
                result = await client.head_object(
                    Bucket=self.provider.bucket, Key=self._id_to_key(id)
                )
            except ClientError as e:
                if e.response["Error"]["Code"] == "404":
                    return None
                raise
        return {
            "id": str(id),
            "last_modified": result["LastModified"],
            "etag": result["ETag"].strip('"'),
            "size": result["ContentLength"],
        }

    async def filter(
        self,
        filters: List[Filter],
        params: Optional[PageOptions] = PageOptions(limit=AWS_LIMIT),
    ) -> List[Json]:
        assert params is not None, "pagination is required for S3Gateway"
        assert params.limit <= AWS_LIMIT, f"max {AWS_LIMIT} keys for S3Gateway"
        assert params.offset == 0, "no 'offset' pagination for S3Gateway"
        assert params.order_by == "id", "can order by 'id' only for S3Gateway"
        kwargs = {
            "Bucket": self.provider.bucket,
            "MaxKeys": params.limit,
            "Prefix": self.filters_to_prefix(filters),
        }
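        # Cursor pagination: params.cursor holds the id of the last item on
        # the previous page; it maps onto S3's StartAfter listing parameter.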
        if params.cursor is not None:
            kwargs["StartAfter"] = self._id_to_key(params.cursor)
        async with self.provider.client as client:
            result = await client.list_objects_v2(**kwargs)
        # Example entry in result["Contents"]:
        # {
        #     "Key": "object-in-s3",
        #     "LastModified": datetime.datetime(..., tzinfo=utc),
        #     "ETag": '"acbd18db4cc2f85cedef654fccc4a4d8"',
        #     "Size": 3,
        #     "StorageClass": "STANDARD",
        #     "Owner": {...},
        # }
        return [
            {
                "id": self._key_to_id(x["Key"]),
                "last_modified": x["LastModified"],
                "etag": x["ETag"].strip('"'),
                "size": x["Size"],
            }
            for x in result.get("Contents", [])
        ]

    async def remove(self, id: Id) -> bool:
        async with self.provider.client as client:
            await client.delete_object(
                Bucket=self.provider.bucket,
                Key=self._id_to_key(id),
            )
        # S3 doesn't tell us if the object was there in the first place
        return True

    async def remove_multiple(self, ids: List[Id]) -> None:
        if len(ids) == 0:
            return
        assert len(ids) <= AWS_LIMIT, f"max {AWS_LIMIT} keys for S3Gateway"
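        # DeleteObjects removes up to 1000 keys in a single request;
        # "Quiet" suppresses the per-key result entries in the response.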
        async with self.provider.client as client:
            await client.delete_objects(
                Bucket=self.provider.bucket,
                Delete={
                    "Objects": [{"Key": self._id_to_key(x)} for x in ids],
                    "Quiet": True,
                },
            )

    async def _create_presigned_url(
        self,
        id: Id,
        client_method: str,
    ) -> AnyHttpUrl:
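        # The returned URL embeds a signature: whoever holds it can perform
        # client_method on this key without further credentials, until the
        # URL expires (DEFAULT_EXPIRY seconds).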
        async with self.provider.client as client:
            return await client.generate_presigned_url(
                client_method,
                Params={"Bucket": self.provider.bucket, "Key": self._id_to_key(id)},
                ExpiresIn=DEFAULT_EXPIRY,
            )

    async def create_download_url(self, id: Id) -> AnyHttpUrl:
        return await self._create_presigned_url(id, "get_object")

    async def create_upload_url(self, id: Id) -> AnyHttpUrl:
        return await self._create_presigned_url(id, "put_object")

    async def download_file(self, id: Id, file_path: Path) -> None:
        if file_path.exists():
            raise FileExistsError()
        try:
            async with self.provider.client as client:
                await client.download_file(
                    Bucket=self.provider.bucket,
                    Key=self._id_to_key(id),
                    Filename=str(file_path),
                )
        except ClientError as e:
            if e.response["Error"]["Code"] == "404":
                # the file may have been partially written; clean it up
                file_path.unlink(missing_ok=True)
                raise DoesNotExist("object")
            raise

    async def upload_file(self, id: Id, file_path: Path) -> None:
        if not file_path.is_file():
            raise FileNotFoundError()
        async with self.provider.client as client:
            await client.upload_file(
                Bucket=self.provider.bucket,
                Key=self._id_to_key(id),
                Filename=str(file_path),
            )

    def filters_to_prefix(self, filters: List[Filter]) -> str:
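        # Only a single "prefix" filter is supported; it is translated into
        # the (tenant-scoped) S3 Prefix parameter used by list operations.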
        if len(filters) == 0:
            return self._id_to_key("")
        elif len(filters) > 1:
            raise NotImplementedError("More than 1 filter is not supported")
        (filter,) = filters
        if filter.field == "prefix":
            assert len(filter.values) == 1
            return self._id_to_key(filter.values[0])
        else:
            raise NotImplementedError(f"Unsupported filter '{filter.field}'")

    async def remove_filtered(self, filters: List[Filter]) -> None:
        kwargs = {
            "Bucket": self.provider.bucket,
            "MaxKeys": AWS_LIMIT,
            "Prefix": self.filters_to_prefix(filters),
        }
        async with self.provider.client as client:
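            # Page through the listing: delete up to AWS_LIMIT keys per
            # iteration, advancing StartAfter past the last key seen, until
            # a page comes back short.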
            while True:
                result = await client.list_objects_v2(**kwargs)
                contents = result.get("Contents", [])
                if contents:
                    await client.delete_objects(
                        Bucket=self.provider.bucket,
                        Delete={
                            "Objects": [{"Key": x["Key"]} for x in contents],
                            "Quiet": True,
                        },
                    )
                if len(contents) < AWS_LIMIT:
                    break
                kwargs["StartAfter"] = contents[-1]["Key"]