Browse source

Implement S3Gateway with multitenancy (#8)

Casper van der Wel 1 year ago
parent
commit
05fc605f6b

+ 11 - 1
.github/workflows/test.yml

@@ -36,6 +36,15 @@ jobs:
         ports:
           - 5432:5432
 
+      s3:
+        image: minio/minio:edge-cicd
+        env:
+          MINIO_DOMAIN: localhost  # virtual hosted-style access
+          MINIO_ROOT_USER: cleanpython
+          MINIO_ROOT_PASSWORD: cleanpython
+        ports:
+          - 9000:9000
+
     steps:
       - uses: actions/checkout@v3
 
@@ -47,7 +56,7 @@ jobs:
       - name: Install python dependencies
         run: |
           pip install --disable-pip-version-check --upgrade pip setuptools
-          pip install -e .[dramatiq,fastapi,auth,celery,fluentbit,sql,test] ${{ matrix.pins }}
+          pip install -e .[dramatiq,fastapi,auth,celery,fluentbit,sql,s3,test] ${{ matrix.pins }}
           pip list
 
       - name: Run tests
@@ -63,3 +72,4 @@ jobs:
         run: pytest integration_tests
         env:
           POSTGRES_URL: 'postgres:postgres@localhost:5432'
+          S3_URL: 'http://localhost:9000'

+ 4 - 0
CHANGES.md

@@ -6,6 +6,10 @@
 
 - Don't use environment variables in setup_debugger.
 
+- Added the `Id` type (replacing `int`); an id may now also be a string.
+
+- Added S3Gateway.
+
 
 ## 0.3.4 (2023-08-28)
 ---------------------

+ 4 - 3
clean_python/base/application/manage.py

@@ -8,6 +8,7 @@ from typing import Type
 from typing import TypeVar
 
 from clean_python.base.domain import Filter
+from clean_python.base.domain import Id
 from clean_python.base.domain import Json
 from clean_python.base.domain import Page
 from clean_python.base.domain import PageOptions
@@ -35,16 +36,16 @@ class Manage(Generic[T]):
         super().__init_subclass__()
         cls.entity = entity
 
-    async def retrieve(self, id: int) -> T:
+    async def retrieve(self, id: Id) -> T:
         return await self.repo.get(id)
 
     async def create(self, values: Json) -> T:
         return await self.repo.add(values)
 
-    async def update(self, id: int, values: Json) -> T:
+    async def update(self, id: Id, values: Json) -> T:
         return await self.repo.update(id, values)
 
-    async def destroy(self, id: int) -> bool:
+    async def destroy(self, id: Id) -> bool:
         return await self.repo.remove(id)
 
     async def list(self, params: Optional[PageOptions] = None) -> Page[T]:

+ 1 - 1
clean_python/base/domain/__init__.py

@@ -4,9 +4,9 @@ from .domain_service import *  # NOQA
 from .exceptions import *  # NOQA
 from .filter import *  # NOQA
 from .gateway import *  # NOQA
-from .json import *  # NOQA
 from .pagination import *  # NOQA
 from .repository import *  # NOQA
 from .root_entity import *  # NOQA
+from .types import *  # NOQA
 from .value import *  # NOQA
 from .value_object import *  # NOQA

+ 3 - 1
clean_python/base/domain/exceptions.py

@@ -9,6 +9,8 @@ from typing import Union
 from pydantic import create_model
 from pydantic import ValidationError
 
+from .types import Id
+
 __all__ = [
     "AlreadyExists",
     "Conflict",
@@ -22,7 +24,7 @@ __all__ = [
 
 
 class DoesNotExist(Exception):
-    def __init__(self, name: str, id: Optional[int] = None):
+    def __init__(self, name: str, id: Optional[Id] = None):
         super().__init__()
         self.name = name
         self.id = id

+ 5 - 4
clean_python/base/domain/gateway.py

@@ -8,8 +8,9 @@ from typing import Optional
 
 from .exceptions import DoesNotExist
 from .filter import Filter
-from .json import Json
 from .pagination import PageOptions
+from .types import Id
+from .types import Json
 
 __all__ = ["Gateway"]
 
@@ -26,7 +27,7 @@ class Gateway(ABC):
     async def exists(self, filters: List[Filter]) -> bool:
         return len(await self.filter(filters, params=PageOptions(limit=1))) > 0
 
-    async def get(self, id: int) -> Optional[Json]:
+    async def get(self, id: Id) -> Optional[Json]:
         result = await self.filter([Filter(field="id", values=[id])], params=None)
         return result[0] if result else None
 
@@ -38,7 +39,7 @@ class Gateway(ABC):
     ) -> Json:
         raise NotImplementedError()
 
-    async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
+    async def update_transactional(self, id: Id, func: Callable[[Json], Json]) -> Json:
         existing = await self.get(id)
         if existing is None:
             raise DoesNotExist("record", id)
@@ -52,5 +53,5 @@ class Gateway(ABC):
         except DoesNotExist:
             return await self.add(item)
 
-    async def remove(self, id: int) -> bool:
+    async def remove(self, id: Id) -> bool:
         raise NotImplementedError()

+ 1 - 0
clean_python/base/domain/pagination.py

@@ -17,6 +17,7 @@ class PageOptions(BaseModel):
     offset: int = 0
     order_by: str = "id"
     ascending: bool = True
+    cursor: Optional[str] = None
 
 
 class Page(BaseModel, Generic[T]):

+ 5 - 4
clean_python/base/domain/repository.py

@@ -11,10 +11,11 @@ from typing import Union
 from .exceptions import DoesNotExist
 from .filter import Filter
 from .gateway import Gateway
-from .json import Json
 from .pagination import Page
 from .pagination import PageOptions
 from .root_entity import RootEntity
+from .types import Id
+from .types import Json
 
 __all__ = ["Repository"]
 
@@ -58,7 +59,7 @@ class Repository(Generic[T]):
             items=[self.entity(**x) for x in records],
         )
 
-    async def get(self, id: int) -> T:
+    async def get(self, id: Id) -> T:
         res = await self.gateway.get(id)
         if res is None:
             raise DoesNotExist("object", id)
@@ -71,7 +72,7 @@ class Repository(Generic[T]):
         created = await self.gateway.add(item.model_dump())
         return self.entity(**created)
 
-    async def update(self, id: int, values: Json) -> T:
+    async def update(self, id: Id, values: Json) -> T:
         if not values:
             return await self.get(id)
         updated = await self.gateway.update_transactional(
@@ -84,7 +85,7 @@ class Repository(Generic[T]):
         upserted = await self.gateway.upsert(values)
         return self.entity(**upserted)
 
-    async def remove(self, id: int) -> bool:
+    async def remove(self, id: Id) -> bool:
         return await self.gateway.remove(id)
 
     async def count(self, filters: List[Filter]) -> int:

+ 2 - 1
clean_python/base/domain/root_entity.py

@@ -7,6 +7,7 @@ from typing import Type
 from typing import TypeVar
 
 from .exceptions import BadRequest
+from .types import Id
 from .value_object import ValueObject
 
 __all__ = ["RootEntity", "now"]
@@ -21,7 +22,7 @@ T = TypeVar("T", bound="RootEntity")
 
 
 class RootEntity(ValueObject):
-    id: Optional[int] = None
+    id: Optional[Id] = None
     created_at: datetime
     updated_at: datetime
 

+ 3 - 1
clean_python/base/domain/json.py → clean_python/base/domain/types.py

@@ -2,8 +2,10 @@
 
 from typing import Any
 from typing import Dict
+from typing import Union
 
-__all__ = ["Json"]
+__all__ = ["Json", "Id"]
 
 
 Json = Dict[str, Any]
+Id = Union[int, str]

+ 2 - 1
clean_python/base/domain/value_object.py

@@ -9,6 +9,7 @@ from pydantic import ConfigDict
 from pydantic import ValidationError
 
 from .exceptions import BadRequest
+from .types import Id
 
 __all__ = ["ValueObject", "ValueObjectWithId"]
 
@@ -46,7 +47,7 @@ K = TypeVar("K", bound="ValueObjectWithId")
 
 
 class ValueObjectWithId(ValueObject):
-    id: Optional[int] = None
+    id: Optional[Id] = None
 
     def update(self: K, **values) -> K:
         if "id" in values and self.id is not None and values["id"] != self.id:

+ 2 - 1
clean_python/base/infrastructure/in_memory_gateway.py

@@ -10,6 +10,7 @@ from clean_python.base.domain import Conflict
 from clean_python.base.domain import DoesNotExist
 from clean_python.base.domain import Filter
 from clean_python.base.domain import Gateway
+from clean_python.base.domain import Id
 from clean_python.base.domain import Json
 from clean_python.base.domain import PageOptions
 
@@ -74,7 +75,7 @@ class InMemoryGateway(Gateway):
         existing.update(item)
         return deepcopy(existing)
 
-    async def remove(self, id: int) -> bool:
+    async def remove(self, id: Id) -> bool:
         if id not in self.data:
             return False
         del self.data[id]

+ 2 - 0
clean_python/s3/__init__.py

@@ -0,0 +1,2 @@
+from .s3_gateway import *  # NOQA
+from .s3_provider import *  # NOQA

+ 195 - 0
clean_python/s3/s3_gateway.py

@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import logging
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import inject
+from botocore.exceptions import ClientError
+from pydantic import AnyHttpUrl
+
+from clean_python import ctx
+from clean_python import DoesNotExist
+from clean_python import Filter
+from clean_python import Gateway
+from clean_python import Id
+from clean_python import Json
+from clean_python import PageOptions
+
+from .s3_provider import S3BucketProvider
+
+DEFAULT_EXPIRY = 3600  # in seconds
+DEFAULT_TIMEOUT = 1.0
+AWS_LIMIT = 1000  # max s3 keys per request
+
+
+__all__ = ["S3Gateway"]
+
+logger = logging.getLogger(__name__)
+
+
+class S3Gateway(Gateway):
+    """The interface to S3 Buckets.
+
+    The standard Gateway interface is only partially implemented:
+
+    - get() and filter() return metadata
+    - add(), update(), upsert() are not implemented
+    - remove() works as expected
+
+    For actually getting the object data either use the download_file()
+    or upload_file() or create a presigned url and hand that over to
+    the client.
+    """
+
+    def __init__(
+        self,
+        provider_override: Optional[S3BucketProvider] = None,
+        multitenant: bool = False,
+    ):
+        self.provider_override = provider_override
+        self.multitenant = multitenant
+
+    @property
+    def provider(self):
+        return self.provider_override or inject.instance(S3BucketProvider)
+
+    def _id_to_key(self, id: Id) -> str:
+        if not self.multitenant:
+            return str(id)
+        if ctx.tenant is None:
+            raise RuntimeError(f"{self.__class__} requires a tenant in the context")
+        return f"tenant-{ctx.tenant.id}/{id}"
+
+    def _key_to_id(self, key: str) -> Id:
+        return key.split("/", 1)[1] if self.multitenant else key
+
+    async def get(self, id: Id) -> Optional[Json]:
+        async with self.provider.client as client:
+            try:
+                result = await client.head_object(
+                    Bucket=self.provider.bucket, Key=self._id_to_key(id)
+                )
+            except ClientError as e:
+                if e.response["Error"]["Code"] == "404":
+                    return None
+                else:
+                    raise e
+        return {
+            "id": str(id),
+            "last_modified": result["LastModified"],
+            "etag": result["ETag"].strip('"'),
+            "size": result["ContentLength"],
+        }
+
+    async def filter(
+        self,
+        filters: List[Filter],
+        params: Optional[PageOptions] = PageOptions(limit=AWS_LIMIT),
+    ) -> List[Json]:
+        assert params is not None, "pagination is required for S3Gateway"
+        assert params.limit <= AWS_LIMIT, f"max {AWS_LIMIT} keys for S3Gateway"
+        assert params.offset == 0, "no 'offset' pagination for S3Gateway"
+        assert params.order_by == "id", "can order by 'id' only for S3Gateway"
+        kwargs = {
+            "Bucket": self.provider.bucket,
+            "MaxKeys": params.limit,
+        }
+        for filter in filters:
+            if filter.field == "prefix":
+                (kwargs["Prefix"],) = filter.values
+            else:
+                raise NotImplementedError(f"Unsupported filter field '{filter.field}'")
+        if self.multitenant:
+            kwargs["Prefix"] = self._id_to_key(kwargs.get("Prefix", ""))
+        if params.cursor is not None:
+            kwargs["StartAfter"] = self._id_to_key(params.cursor)
+        async with self.provider.client as client:
+            result = await client.list_objects_v2(**kwargs)
+        # Example response:
+        #     {
+        #         'Key': 'object-in-s3',
+        #         'LastModified': datetime.datetime(..., tzinfo=utc),
+        #         'ETag': '"acbd18db4cc2f85cedef654fccc4a4d8"',
+        #         'Size': 3, 'StorageClass':
+        #         'STANDARD',
+        #         'Owner': {...}
+        #     }
+        return [
+            {
+                "id": self._key_to_id(x["Key"]),
+                "last_modified": x["LastModified"],
+                "etag": x["ETag"].strip('"'),
+                "size": x["Size"],
+            }
+            for x in result.get("Contents", [])
+        ]
+
+    async def remove(self, id: Id) -> bool:
+        async with self.provider.client as client:
+            await client.delete_object(
+                Bucket=self.provider.bucket,
+                Key=self._id_to_key(id),
+            )
+        # S3 doesn't tell us if the object was there in the first place
+        return True
+
+    async def remove_multiple(self, ids: List[Id]) -> None:
+        if len(ids) == 0:
+            return
+        assert len(ids) <= AWS_LIMIT, f"max {AWS_LIMIT} keys for S3Gateway"
+        async with self.provider.client as client:
+            await client.delete_objects(
+                Bucket=self.provider.bucket,
+                Delete={
+                    "Objects": [{"Key": self._id_to_key(x)} for x in ids],
+                    "Quiet": True,
+                },
+            )
+
+    async def _create_presigned_url(
+        self,
+        id: Id,
+        client_method: str,
+    ) -> AnyHttpUrl:
+        async with self.provider.client as client:
+            return await client.generate_presigned_url(
+                client_method,
+                Params={"Bucket": self.provider.bucket, "Key": self._id_to_key(id)},
+                ExpiresIn=DEFAULT_EXPIRY,
+            )
+
+    async def create_download_url(self, id: Id) -> AnyHttpUrl:
+        return await self._create_presigned_url(id, "get_object")
+
+    async def create_upload_url(self, id: Id) -> AnyHttpUrl:
+        return await self._create_presigned_url(id, "put_object")
+
+    async def download_file(self, id: Id, file_path: Path) -> None:
+        if file_path.exists():
+            raise FileExistsError()
+        try:
+            async with self.provider.client as client:
+                await client.download_file(
+                    Bucket=self.provider.bucket,
+                    Key=self._id_to_key(id),
+                    Filename=str(file_path),
+                )
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "404":
+                file_path.unlink(missing_ok=True)
+                raise DoesNotExist("object")
+            else:
+                raise e
+
+    async def upload_file(self, id: Id, file_path: Path) -> None:
+        if not file_path.is_file():
+            raise FileNotFoundError()
+        async with self.provider.client as client:
+            await client.upload_file(
+                Bucket=self.provider.bucket,
+                Key=self._id_to_key(id),
+                Filename=str(file_path),
+            )

+ 55 - 0
clean_python/s3/s3_provider.py

@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import logging
+from typing import Optional
+from typing import TYPE_CHECKING
+
+import aioboto3
+from botocore.client import Config
+
+from clean_python import ValueObject
+
+if TYPE_CHECKING:
+    from types_aiobotocore_s3.client import S3Client
+
+__all__ = ["S3BucketOptions", "S3BucketProvider"]
+
+logger = logging.getLogger(__name__)
+
+
+class S3BucketOptions(ValueObject):
+    url: str
+    access_key: str
+    secret_key: str
+    bucket: str
+    region: Optional[str] = None
+
+
+class S3BucketProvider:
+    def __init__(self, options: S3BucketOptions):
+        self.options = options
+
+    @property
+    def bucket(self) -> str:
+        return self.options.bucket
+
+    @property
+    def client(self) -> "S3Client":
+        session = aioboto3.Session()
+        return session.client(
+            "s3",
+            endpoint_url=self.options.url,
+            aws_access_key_id=self.options.access_key,
+            aws_secret_access_key=self.options.secret_key,
+            region_name=self.options.region,
+            config=Config(
+                s3={"addressing_style": "virtual"},  # "path" will become deprecated
+                signature_version="s3v4",  # for minio
+                retries={
+                    "max_attempts": 4,  # 1 try and up to 3 retries
+                    "mode": "adaptive",
+                },
+            ),
+            use_ssl=self.options.url.startswith("https"),
+        )

+ 5 - 4
clean_python/sql/sql_gateway.py

@@ -29,6 +29,7 @@ from clean_python import ctx
 from clean_python import DoesNotExist
 from clean_python import Filter
 from clean_python import Gateway
+from clean_python import Id
 from clean_python import Json
 from clean_python import PageOptions
 
@@ -152,7 +153,7 @@ class SQLGateway(Gateway):
             await transaction.set_related(item, result[0])
         return result[0]
 
-    async def _select_for_update(self, id: int) -> Json:
+    async def _select_for_update(self, id: Id) -> Json:
         async with self.transaction() as transaction:
             result = await transaction.execute(
                 select(self.table).with_for_update().where(self._id_filter_to_sql(id)),
@@ -162,7 +163,7 @@ class SQLGateway(Gateway):
             await transaction.get_related(result)
         return result[0]
 
-    async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
+    async def update_transactional(self, id: Id, func: Callable[[Json], Json]) -> Json:
         async with self.transaction() as transaction:
             existing = await transaction._select_for_update(id)
             updated = func(existing)
@@ -186,7 +187,7 @@ class SQLGateway(Gateway):
             await transaction.set_related(item, result[0])
         return result[0]
 
-    async def remove(self, id) -> bool:
+    async def remove(self, id: Id) -> bool:
         query = (
             delete(self.table)
             .where(self._id_filter_to_sql(id))
@@ -214,7 +215,7 @@ class SQLGateway(Gateway):
             qs.append(self.table.c.tenant == self.current_tenant)
         return and_(*qs)
 
-    def _id_filter_to_sql(self, id: int) -> ColumnElement:
+    def _id_filter_to_sql(self, id: Id) -> ColumnElement:
         return self._filters_to_sql([Filter(field="id", values=[id])])
 
     async def filter(

+ 10 - 1
docker-compose.yaml

@@ -2,10 +2,19 @@ version: "3.8"
 
 services:
 
-  db:
+  postgres:
     image: postgres:14-alpine
     environment:
       POSTGRES_PASSWORD: "postgres"
     # command: ["postgres", "-c", "log_connections=all", "-c", "log_disconnections=all", "-c", "log_statement=all", "-c", "log_destination=stderr"]
     ports:
       - "5432:5432"
+
+  s3:
+    image: minio/minio:edge-cicd
+    environment:
+      MINIO_DOMAIN: localhost # virtual hosted-style access
+      MINIO_ROOT_USER: cleanpython
+      MINIO_ROOT_PASSWORD: cleanpython
+    ports:
+      - "9000:9000"

+ 5 - 0
integration_tests/conftest.py

@@ -34,3 +34,8 @@ def event_loop(request):
 @pytest.fixture(scope="session")
 async def postgres_url():
     return os.environ.get("POSTGRES_URL", "postgres:postgres@localhost:5432")
+
+
+@pytest.fixture(scope="session")
+async def s3_url():
+    return os.environ.get("S3_URL", "http://localhost:9000")

+ 202 - 0
integration_tests/test_s3_gateway.py

@@ -0,0 +1,202 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import io
+from datetime import datetime
+
+import boto3
+import pytest
+from botocore.exceptions import ClientError
+
+from clean_python import DoesNotExist
+from clean_python import Filter
+from clean_python import PageOptions
+from clean_python.s3 import S3BucketOptions
+from clean_python.s3 import S3BucketProvider
+from clean_python.s3 import S3Gateway
+
+
+@pytest.fixture(scope="session")
+def s3_settings(s3_url):
+    minio_settings = {
+        "url": s3_url,
+        "access_key": "cleanpython",
+        "secret_key": "cleanpython",
+        "bucket": "cleanpython-test",
+        "region": None,
+    }
+    if not minio_settings["bucket"].endswith("-test"):  # type: ignore
+        pytest.exit("Not running against a test minio bucket?! 😱")
+    return minio_settings.copy()
+
+
+@pytest.fixture(scope="session")
+def s3_bucket(s3_settings):
+    s3 = boto3.resource(
+        "s3",
+        endpoint_url=s3_settings["url"],
+        aws_access_key_id=s3_settings["access_key"],
+        aws_secret_access_key=s3_settings["secret_key"],
+    )
+    bucket = s3.Bucket(s3_settings["bucket"])
+
+    # ensure existence
+    try:
+        bucket.create()
+    except ClientError as e:
+        if "BucketAlreadyOwnedByYou" in str(e):
+            pass
+    return bucket
+
+
+@pytest.fixture
+def s3_provider(s3_bucket, s3_settings):
+    # wipe contents before each test
+    s3_bucket.objects.all().delete()
+    return S3BucketProvider(S3BucketOptions(**s3_settings))
+
+
+@pytest.fixture
+def s3_gateway(s3_provider):
+    return S3Gateway(s3_provider)
+
+
+@pytest.fixture
+def object_in_s3(s3_bucket):
+    s3_bucket.upload_fileobj(io.BytesIO(b"foo"), "object-in-s3")
+    return "object-in-s3"
+
+
+@pytest.fixture
+def local_file(tmp_path):
+    path = tmp_path / "test-upload.txt"
+    path.write_bytes(b"foo")
+    return path
+
+
+async def test_upload_file(s3_gateway: S3Gateway, local_file):
+    object_name = "test-upload-file"
+
+    await s3_gateway.upload_file(object_name, local_file)
+
+    assert (await s3_gateway.get(object_name))["size"] == 3
+
+
+async def test_upload_file_does_not_exist(s3_gateway: S3Gateway, tmp_path):
+    path = tmp_path / "test-upload.txt"
+    object_name = "test-upload-file"
+
+    with pytest.raises(FileNotFoundError):
+        await s3_gateway.upload_file(object_name, path)
+
+
+async def test_download_file(s3_gateway: S3Gateway, object_in_s3, tmp_path):
+    path = tmp_path / "test-download.txt"
+
+    await s3_gateway.download_file(object_in_s3, path)
+
+    assert path.read_bytes() == b"foo"
+
+
+async def test_download_file_path_already_exists(
+    s3_gateway: S3Gateway, object_in_s3, tmp_path
+):
+    path = tmp_path / "test-download.txt"
+    path.write_bytes(b"bar")
+
+    with pytest.raises(FileExistsError):
+        await s3_gateway.download_file(object_in_s3, path)
+
+    assert path.read_bytes() == b"bar"
+
+
+async def test_download_file_does_not_exist(s3_gateway: S3Gateway, s3_bucket, tmp_path):
+    path = tmp_path / "test-download-does-not-exist.txt"
+
+    with pytest.raises(DoesNotExist):
+        await s3_gateway.download_file("some-nonexisting", path)
+
+    assert not path.exists()
+
+
+async def test_remove(s3_gateway: S3Gateway, s3_bucket, object_in_s3):
+    await s3_gateway.remove(object_in_s3)
+
+    assert await s3_gateway.get(object_in_s3) is None
+
+
+async def test_remove_does_not_exist(s3_gateway: S3Gateway, s3_bucket):
+    await s3_gateway.remove("non-existing")
+
+
+@pytest.fixture
+def multiple_objects(s3_bucket):
+    s3_bucket.upload_fileobj(io.BytesIO(b"a"), "raster-1/bla")
+    s3_bucket.upload_fileobj(io.BytesIO(b"ab"), "raster-2/bla")
+    s3_bucket.upload_fileobj(io.BytesIO(b"abc"), "raster-2/foo")
+    s3_bucket.upload_fileobj(io.BytesIO(b"abcde"), "raster-2/bz")
+    return ["raster-1/bla", "raster-2/bla", "raster-2/foo", "raster-2/bz"]
+
+
+async def test_remove_multiple(s3_gateway: S3Gateway, multiple_objects):
+    await s3_gateway.remove_multiple(multiple_objects[:2])
+
+    for key in multiple_objects[:2]:
+        assert await s3_gateway.get(key) is None
+
+    for key in multiple_objects[2:]:
+        assert await s3_gateway.get(key) is not None
+
+
+async def test_remove_multiple_empty_list(s3_gateway: S3Gateway, s3_bucket):
+    await s3_gateway.remove_multiple([])
+
+
+async def test_filter(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter([], params=PageOptions(limit=10))
+    assert len(actual) == 4
+    assert actual[0]["id"] == "raster-1/bla"
+    assert isinstance(actual[0]["last_modified"], datetime)
+    assert actual[0]["etag"] == "0cc175b9c0f1b6a831c399e269772661"
+    assert actual[0]["size"] == 1
+
+
+async def test_filter_empty(s3_gateway: S3Gateway, s3_bucket):
+    actual = await s3_gateway.filter([], params=PageOptions(limit=10))
+    assert actual == []
+
+
+async def test_filter_with_prefix(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter(
+        [Filter(field="prefix", values=["raster-2/"])], params=PageOptions(limit=10)
+    )
+    assert len(actual) == 3
+
+
+async def test_filter_with_limit(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter([], params=PageOptions(limit=2))
+    assert len(actual) == 2
+    assert actual[0]["id"] == "raster-1/bla"
+    assert actual[1]["id"] == "raster-2/bla"
+
+
+async def test_filter_with_cursor(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter(
+        [], params=PageOptions(limit=3, cursor="raster-2/bla")
+    )
+    assert len(actual) == 2
+    assert actual[0]["id"] == "raster-2/bz"
+    assert actual[1]["id"] == "raster-2/foo"
+
+
+async def test_get(s3_gateway: S3Gateway, object_in_s3):
+    actual = await s3_gateway.get(object_in_s3)
+    assert actual["id"] == "object-in-s3"
+    assert isinstance(actual["last_modified"], datetime)
+    assert actual["etag"] == "acbd18db4cc2f85cedef654fccc4a4d8"
+    assert actual["size"] == 3
+
+
+async def test_get_does_not_exist(s3_gateway: S3Gateway):
+    actual = await s3_gateway.get("non-existing")
+    assert actual is None

+ 182 - 0
integration_tests/test_s3_gateway_multitenant.py

@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import io
+from datetime import datetime
+
+import boto3
+import pytest
+from botocore.exceptions import ClientError
+
+from clean_python import ctx
+from clean_python import DoesNotExist
+from clean_python import Filter
+from clean_python import PageOptions
+from clean_python import Tenant
+from clean_python.s3 import S3BucketOptions
+from clean_python.s3 import S3BucketProvider
+from clean_python.s3 import S3Gateway
+
+
+@pytest.fixture(scope="session")
+def s3_settings(s3_url):
+    minio_settings = {
+        "url": s3_url,
+        "access_key": "cleanpython",
+        "secret_key": "cleanpython",
+        "bucket": "cleanpython-test",
+        "region": None,
+    }
+    if not minio_settings["bucket"].endswith("-test"):  # type: ignore
+        pytest.exit("Not running against a test minio bucket?! 😱")
+    return minio_settings.copy()
+
+
+@pytest.fixture(scope="session")
+def s3_bucket(s3_settings):
+    s3 = boto3.resource(
+        "s3",
+        endpoint_url=s3_settings["url"],
+        aws_access_key_id=s3_settings["access_key"],
+        aws_secret_access_key=s3_settings["secret_key"],
+    )
+    bucket = s3.Bucket(s3_settings["bucket"])
+
+    # ensure existence
+    try:
+        bucket.create()
+    except ClientError as e:
+        if "BucketAlreadyOwnedByYou" in str(e):
+            pass
+    return bucket
+
+
+@pytest.fixture
+def s3_provider(s3_bucket, s3_settings):
+    # wipe contents before each test
+    s3_bucket.objects.all().delete()
+    # set up a tenant
+    ctx.tenant = Tenant(id=22, name="foo")
+    return S3BucketProvider(S3BucketOptions(**s3_settings))
+
+
+@pytest.fixture
+def s3_gateway(s3_provider):
+    return S3Gateway(s3_provider, multitenant=True)
+
+
+@pytest.fixture
+def object_in_s3(s3_bucket):
+    s3_bucket.upload_fileobj(io.BytesIO(b"foo"), "tenant-22/object-in-s3")
+    return "object-in-s3"
+
+
+@pytest.fixture
+def object_in_s3_other_tenant(s3_bucket):
+    s3_bucket.upload_fileobj(io.BytesIO(b"foo"), "tenant-222/object-in-s3")
+    return "object-in-s3"
+
+
+@pytest.fixture
+def local_file(tmp_path):
+    path = tmp_path / "test-upload.txt"
+    path.write_bytes(b"foo")
+    return path
+
+
+async def test_upload_file_uses_tenant(s3_gateway: S3Gateway, local_file, s3_bucket):
+    object_name = "test-upload-file"
+
+    await s3_gateway.upload_file(object_name, local_file)
+
+    assert s3_bucket.Object("tenant-22/test-upload-file").content_length == 3
+
+
+async def test_download_file_uses_tenant(s3_gateway: S3Gateway, object_in_s3, tmp_path):
+    path = tmp_path / "test-download.txt"
+
+    await s3_gateway.download_file(object_in_s3, path)
+
+    assert path.read_bytes() == b"foo"
+
+
+async def test_download_file_different_tenant(
+    s3_gateway: S3Gateway, s3_bucket, tmp_path, object_in_s3_other_tenant
+):
+    path = tmp_path / "test-download.txt"
+
+    with pytest.raises(DoesNotExist):
+        await s3_gateway.download_file("object-in-s3", path)
+
+    assert not path.exists()
+
+
+async def test_remove_uses_tenant(s3_gateway: S3Gateway, s3_bucket, object_in_s3):
+    await s3_gateway.remove(object_in_s3)
+
+    assert await s3_gateway.get(object_in_s3) is None
+
+
+async def test_remove_other_tenant(
+    s3_gateway: S3Gateway, s3_bucket, object_in_s3_other_tenant
+):
+    await s3_gateway.remove(object_in_s3_other_tenant)
+
+    # it is still there
+    assert s3_bucket.Object("tenant-222/object-in-s3").content_length == 3
+
+
+@pytest.fixture
+def multiple_objects(s3_bucket):
+    s3_bucket.upload_fileobj(io.BytesIO(b"a"), "tenant-22/raster-1/bla")
+    s3_bucket.upload_fileobj(io.BytesIO(b"ab"), "tenant-222/raster-2/bla")
+    s3_bucket.upload_fileobj(io.BytesIO(b"abc"), "tenant-22/raster-2/foo")
+    s3_bucket.upload_fileobj(io.BytesIO(b"abcde"), "tenant-22/raster-2/bz")
+    return ["raster-1/bla", "raster-2/bla", "raster-2/foo", "raster-2/bz"]
+
+
+async def test_remove_multiple_multitenant(
+    s3_gateway: S3Gateway, multiple_objects, s3_bucket
+):
+    await s3_gateway.remove_multiple(multiple_objects[:2])
+
+    assert await s3_gateway.get(multiple_objects[0]) is None
+
+    # the other-tenant object is still there
+    assert s3_bucket.Object("tenant-222/raster-2/bla").content_length == 2
+
+
+async def test_filter_multitenant(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter([], params=PageOptions(limit=10))
+    assert len(actual) == 3
+    assert actual[0]["id"] == "raster-1/bla"
+
+
+async def test_filter_with_prefix_multitenant(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter(
+        [Filter(field="prefix", values=["raster-2/"])], params=PageOptions(limit=10)
+    )
+    assert len(actual) == 2
+    assert actual[0]["id"] == "raster-2/bz"
+    assert actual[1]["id"] == "raster-2/foo"
+
+
+async def test_filter_with_cursor_multitenant(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter(
+        [], params=PageOptions(limit=3, cursor="raster-2/bz")
+    )
+    assert len(actual) == 1
+    assert actual[0]["id"] == "raster-2/foo"
+
+
+async def test_get_multitenant(s3_gateway: S3Gateway, object_in_s3):
+    actual = await s3_gateway.get(object_in_s3)
+    assert actual["id"] == object_in_s3
+    assert isinstance(actual["last_modified"], datetime)
+    assert actual["etag"] == "acbd18db4cc2f85cedef654fccc4a4d8"
+    assert actual["size"] == 3
+
+
+async def test_get_other_tenant(s3_gateway: S3Gateway, object_in_s3_other_tenant):
+    actual = await s3_gateway.get(object_in_s3_other_tenant)
+    assert actual is None

+ 1 - 0
pyproject.toml

@@ -27,6 +27,7 @@ auth = ["pyjwt[crypto]==2.6.0"]
 celery = ["pika"]
 fluentbit = ["fluent-logger"]
 sql = ["sqlalchemy==2.*", "asyncpg"]
+s3 = ["aioboto3", "boto3"]
 
 [project.urls]
 homepage = "https://github.com/nens/clean-python"