# -*- coding: utf-8 -*-
# (c) Nelen & Schuurmans

import logging
from pathlib import Path
from typing import List
from typing import Optional

import inject
from botocore.exceptions import ClientError
from pydantic import AnyHttpUrl

from clean_python import ctx
from clean_python import DoesNotExist
from clean_python import Filter
from clean_python import Gateway
from clean_python import Id
from clean_python import Json
from clean_python import PageOptions

from .s3_provider import S3BucketProvider

DEFAULT_EXPIRY = 3600  # in seconds
DEFAULT_TIMEOUT = 1.0
AWS_LIMIT = 1000  # max s3 keys per request

__all__ = ["S3Gateway"]

logger = logging.getLogger(__name__)


class S3Gateway(Gateway):
    """The interface to S3 Buckets.

    The standard Gateway interface is only partially implemented:

    - get() and filter() return metadata
    - add(), update(), upsert() are not implemented
    - remove() works as expected

    To actually transfer object data, use download_file() or
    upload_file(), or create a presigned url and hand that over to
    the client.
    """

    def __init__(
        self,
        provider_override: Optional[S3BucketProvider] = None,
        multitenant: bool = False,
    ):
        self.provider_override = provider_override
        self.multitenant = multitenant

    @property
    def provider(self):
        return self.provider_override or inject.instance(S3BucketProvider)

    def _id_to_key(self, id: Id) -> str:
        if not self.multitenant:
            return str(id)
        if ctx.tenant is None:
            raise RuntimeError(f"{self.__class__} requires a tenant in the context")
        return f"tenant-{ctx.tenant.id}/{id}"

    def _key_to_id(self, key: str) -> Id:
        return key.split("/", 1)[1] if self.multitenant else key

    async def get(self, id: Id) -> Optional[Json]:
        async with self.provider.client as client:
            try:
                result = await client.head_object(
                    Bucket=self.provider.bucket, Key=self._id_to_key(id)
                )
            except ClientError as e:
                if e.response["Error"]["Code"] == "404":
                    return None
                else:
                    raise e
        return {
            "id": str(id),
            "last_modified": result["LastModified"],
            "etag": result["ETag"].strip('"'),
            "size": result["ContentLength"],
        }

    async def filter(
        self,
        filters: List[Filter],
        params: Optional[PageOptions] = PageOptions(limit=AWS_LIMIT),
    ) -> List[Json]:
        assert params is not None, "pagination is required for S3Gateway"
        assert params.limit <= AWS_LIMIT, f"max {AWS_LIMIT} keys for S3Gateway"
        assert params.offset == 0, "no 'offset' pagination for S3Gateway"
        assert params.order_by == "id", "can order by 'id' only for S3Gateway"
        kwargs = {
            "Bucket": self.provider.bucket,
            "MaxKeys": params.limit,
            "Prefix": self.filters_to_prefix(filters),
        }
        if params.cursor is not None:
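            # keyset pagination: S3's StartAfter lists keys strictly after the
            # given key (keys come back in ascending UTF-8 binary order), which
            # matches a cursor on "id"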
            kwargs["StartAfter"] = self._id_to_key(params.cursor)
        async with self.provider.client as client:
            result = await client.list_objects_v2(**kwargs)
        # Example response:
        #     {
        #         'Key': 'object-in-s3',
        #         'LastModified': datetime.datetime(..., tzinfo=utc),
        #         'ETag': '"acbd18db4cc2f85cedef654fccc4a4d8"',
        #         'Size': 3,
        #         'StorageClass': 'STANDARD',
        #         'Owner': {...}
        #     }
        return [
            {
                "id": self._key_to_id(x["Key"]),
                "last_modified": x["LastModified"],
                "etag": x["ETag"].strip('"'),
                "size": x["Size"],
            }
            for x in result.get("Contents", [])
        ]

    async def remove(self, id: Id) -> bool:
        async with self.provider.client as client:
            await client.delete_object(
                Bucket=self.provider.bucket,
                Key=self._id_to_key(id),
            )
        # S3 doesn't tell us if the object was there in the first place
        return True

    async def remove_multiple(self, ids: List[Id]) -> None:
        if len(ids) == 0:
            return
        assert len(ids) <= AWS_LIMIT, f"max {AWS_LIMIT} keys for S3Gateway"
        async with self.provider.client as client:
            await client.delete_objects(
                Bucket=self.provider.bucket,
                Delete={
                    "Objects": [{"Key": self._id_to_key(x)} for x in ids],
                    "Quiet": True,
                },
            )

    async def _create_presigned_url(
        self,
        id: Id,
        client_method: str,
    ) -> AnyHttpUrl:
        async with self.provider.client as client:
            return await client.generate_presigned_url(
                client_method,
                Params={"Bucket": self.provider.bucket, "Key": self._id_to_key(id)},
                ExpiresIn=DEFAULT_EXPIRY,
            )

    async def create_download_url(self, id: Id) -> AnyHttpUrl:
        return await self._create_presigned_url(id, "get_object")

    async def create_upload_url(self, id: Id) -> AnyHttpUrl:
        return await self._create_presigned_url(id, "put_object")

    async def download_file(self, id: Id, file_path: Path) -> None:
        if file_path.exists():
            raise FileExistsError()
        try:
            async with self.provider.client as client:
                await client.download_file(
                    Bucket=self.provider.bucket,
                    Key=self._id_to_key(id),
                    Filename=str(file_path),
                )
        except ClientError as e:
            if e.response["Error"]["Code"] == "404":
                file_path.unlink(missing_ok=True)
                raise DoesNotExist("object")
            else:
                raise e

    async def upload_file(self, id: Id, file_path: Path) -> None:
        if not file_path.is_file():
            raise FileNotFoundError()
        async with self.provider.client as client:
            await client.upload_file(
                Bucket=self.provider.bucket,
                Key=self._id_to_key(id),
                Filename=str(file_path),
            )

    def filters_to_prefix(self, filters: List[Filter]) -> str:
        if len(filters) == 0:
            return self._id_to_key("")
        elif len(filters) > 1:
            raise NotImplementedError("More than 1 filter is not supported")
        (filter,) = filters
        if filter.field == "prefix":
            assert len(filter.values) == 1
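            # a "prefix" filter maps directly onto S3's Prefix parameter;
            # _id_to_key also prepends the tenant prefix in multitenant mode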
            return self._id_to_key(filter.values[0])
        else:
            raise NotImplementedError(f"Unsupported filter '{filter.field}'")

    async def remove_filtered(self, filters: List[Filter]) -> None:
        kwargs = {
            "Bucket": self.provider.bucket,
            "MaxKeys": AWS_LIMIT,
            "Prefix": self.filters_to_prefix(filters),
        }
        async with self.provider.client as client:
            while True:
                result = await client.list_objects_v2(**kwargs)
                contents = result.get("Contents", [])
                if contents:
                    await client.delete_objects(
                        Bucket=self.provider.bucket,
                        Delete={
                            "Objects": [{"Key": x["Key"]} for x in contents],
                            "Quiet": True,
                        },
                    )
                # fewer than AWS_LIMIT keys means this was the last page
                if len(contents) < AWS_LIMIT:
                    break
                kwargs["StartAfter"] = contents[-1]["Key"]
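

# --- Usage sketch (illustrative only, not part of the gateway) ---
# A minimal sketch of the typical call pattern, assuming the application has
# a configured S3BucketProvider (construction omitted; it exposes `.bucket`
# and an async `.client` context manager, as used above). The id and file
# path below are made-up values.


async def _example_usage(provider: S3BucketProvider) -> None:
    gateway = S3Gateway(provider_override=provider)
    await gateway.upload_file("some-id", Path("/tmp/upload.bin"))
    metadata = await gateway.get("some-id")
    assert metadata is not None  # has "id", "last_modified", "etag", "size"
    # hand a presigned url to the client instead of streaming the data yourself
    url = await gateway.create_download_url("some-id")
    logger.debug("download %s via %s", metadata["id"], url)
    await gateway.remove("some-id")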