Browse Source

Generic implementation of S3Gateway prefixes (#16)

Casper van der Wel 1 year ago
parent
commit
199c9e31f4

+ 3 - 1
CHANGES.md

@@ -4,7 +4,9 @@
 0.5.1 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Added `S3Gateway.remove_filtered`.
+
+- Added `clean_python.s3.KeyMapper`.
 
 
 0.5.0 (2023-09-12)

+ 3 - 1
clean_python/base/domain/pagination.py

@@ -7,6 +7,8 @@ from typing import TypeVar
 
 from pydantic import BaseModel
 
+from .types import Id
+
 __all__ = ["Page", "PageOptions"]
 
 T = TypeVar("T")
@@ -17,7 +19,7 @@ class PageOptions(BaseModel):
     offset: int = 0
     order_by: str = "id"
     ascending: bool = True
-    cursor: Optional[str] = None
+    cursor: Optional[Id] = None
 
 
 class Page(BaseModel, Generic[T]):

+ 1 - 0
clean_python/s3/__init__.py

@@ -1,2 +1,3 @@
+from .key_mapper import *  # NOQA
 from .s3_gateway import *  # NOQA
 from .s3_provider import *  # NOQA

+ 63 - 0
clean_python/s3/key_mapper.py

@@ -0,0 +1,63 @@
+import re
+from typing import Tuple
+
+from pydantic import field_validator
+
+from clean_python import DomainService
+from clean_python import Id
+
+__all__ = ["KeyMapper"]
+
+
+def _maybe_coerce_int(x: str) -> Id:
+    try:
+        return int(x)
+    except ValueError:
+        return x
+
+
+class KeyMapper(DomainService):
+    """Maps one or multiple ids to a string and vice versa.
+
+    The mapping is configured using a python formatting string with standard
+    {} placeholders. Additionally, the key can be prefixed with a tenant id
+    when multitenant=True.
+    """
+
+    pattern: str = "{}"
+
+    @field_validator("pattern")
+    @classmethod
+    def validate_pattern(cls, v):
+        if isinstance(v, str):
+            assert not v.startswith("/"), "pattern should not start with '/'"
+            assert v.endswith("{}"), "pattern cannot have a suffix"
+            try:
+                v.format(*((2,) * v.count("{}")))
+            except KeyError:
+                raise ValueError("invalid pattern")
+        return v
+
+    @property
+    def n_placeholders(self) -> int:
+        return self.pattern.count("{}")
+
+    def get_named_pattern(self, *names: str) -> str:
+        return self.pattern.format(*[f"{{{x}}}" for x in names])
+
+    @property
+    def regex(self) -> str:
+        return "^" + self.pattern.replace("{}", "(.+)") + "$"
+
+    def to_key(self, *args: Id) -> str:
+        assert len(args) == self.n_placeholders
+        return self.pattern.format(*args)
+
+    def to_key_prefix(self, *args: Id) -> str:
+        return self.to_key(*(args + ("",)))
+
+    def from_key(self, key: str) -> Tuple[Id, ...]:
+        match = re.fullmatch(self.regex, key)
+        if match is None:
+            raise ValueError("key does not match expected pattern")
+        return tuple(_maybe_coerce_int(x) for x in match.groups())

+ 35 - 7
clean_python/s3/s3_gateway.py

@@ -96,14 +96,8 @@ class S3Gateway(Gateway):
         kwargs = {
             "Bucket": self.provider.bucket,
             "MaxKeys": params.limit,
+            "Prefix": self.filters_to_prefix(filters),
         }
-        for filter in filters:
-            if filter.field == "prefix":
-                (kwargs["Prefix"],) = filter.values
-            else:
-                raise NotImplementedError(f"Unsupported filter field '{filter.field}'")
-        if self.multitenant:
-            kwargs["Prefix"] = self._id_to_key(kwargs.get("Prefix", ""))
         if params.cursor is not None:
             kwargs["StartAfter"] = self._id_to_key(params.cursor)
         async with self.provider.client as client:
@@ -193,3 +187,37 @@ class S3Gateway(Gateway):
                 Key=self._id_to_key(id),
                 Filename=str(file_path),
             )
+
+    def filters_to_prefix(self, filters: List[Filter]) -> str:
+        if len(filters) == 0:
+            return self._id_to_key("")
+        elif len(filters) > 1:
+            raise NotImplementedError("More than 1 filter is not supported")
+        (filter,) = filters
+        if filter.field == "prefix":
+            assert len(filter.values) == 1
+            return self._id_to_key(filter.values[0])
+        else:
+            raise NotImplementedError(f"Unsupported filter '{filter.field}'")
+
+    async def remove_filtered(self, filters: List[Filter]) -> None:
+        kwargs = {
+            "Bucket": self.provider.bucket,
+            "MaxKeys": AWS_LIMIT,
+            "Prefix": self.filters_to_prefix(filters),
+        }
+        async with self.provider.client as client:
+            while True:
+                result = await client.list_objects_v2(**kwargs)
+                contents = result.get("Contents", [])
+                if contents:
+                    await client.delete_objects(
+                        Bucket=self.provider.bucket,
+                        Delete={
+                            "Objects": [{"Key": x["Key"]} for x in contents],
+                            "Quiet": True,
+                        },
+                    )
+                if len(contents) < AWS_LIMIT:
+                    break
+                kwargs["StartAfter"] = contents[-1]["Key"]

+ 33 - 7
integration_tests/test_s3_gateway.py

@@ -3,6 +3,7 @@
 
 import io
 from datetime import datetime
+from unittest import mock
 
 import boto3
 import pytest
@@ -152,6 +153,30 @@ async def test_remove_multiple_empty_list(s3_gateway: S3Gateway, s3_bucket):
     await s3_gateway.remove_multiple([])
 
 
+async def test_remove_filtered_all(s3_gateway: S3Gateway, multiple_objects):
+    await s3_gateway.remove_filtered([])
+
+    for key in multiple_objects:
+        assert await s3_gateway.get(key) is None
+
+
+async def test_remove_filtered_prefix(s3_gateway: S3Gateway, multiple_objects):
+    await s3_gateway.remove_filtered([Filter(field="prefix", values=["raster-2/"])])
+
+    assert await s3_gateway.get(multiple_objects[0]) is not None
+    for key in multiple_objects[1:]:
+        assert await s3_gateway.get(key) is None
+
+
+@mock.patch("clean_python.s3.s3_gateway.AWS_LIMIT", new=1)
+async def test_remove_filtered_pagination(s3_gateway: S3Gateway, multiple_objects):
+    await s3_gateway.remove_filtered([Filter(field="prefix", values=["raster-2/"])])
+
+    assert await s3_gateway.get(multiple_objects[0]) is not None
+    for key in multiple_objects[1:]:
+        assert await s3_gateway.get(key) is None
+
+
 async def test_filter(s3_gateway: S3Gateway, multiple_objects):
     actual = await s3_gateway.filter([], params=PageOptions(limit=10))
     assert len(actual) == 4
@@ -166,13 +191,6 @@ async def test_filter_empty(s3_gateway: S3Gateway, s3_bucket):
     assert actual == []
 
 
-async def test_filter_with_prefix(s3_gateway: S3Gateway, multiple_objects):
-    actual = await s3_gateway.filter(
-        [Filter(field="prefix", values=["raster-2/"])], params=PageOptions(limit=10)
-    )
-    assert len(actual) == 3
-
-
 async def test_filter_with_limit(s3_gateway: S3Gateway, multiple_objects):
     actual = await s3_gateway.filter([], params=PageOptions(limit=2))
     assert len(actual) == 2
@@ -189,6 +207,14 @@ async def test_filter_with_cursor(s3_gateway: S3Gateway, multiple_objects):
     assert actual[1]["id"] == "raster-2/foo"
 
 
+async def test_filter_by_prefix(s3_gateway: S3Gateway, multiple_objects):
+    actual = await s3_gateway.filter([Filter(field="prefix", values=["raster-1/"])])
+    assert len(actual) == 1
+
+    actual = await s3_gateway.filter([Filter(field="prefix", values=["raster-2/"])])
+    assert len(actual) == 3
+
+
 async def test_get(s3_gateway: S3Gateway, object_in_s3):
     actual = await s3_gateway.get(object_in_s3)
     assert actual["id"] == "object-in-s3"

+ 24 - 0
integration_tests/test_s3_gateway_multitenant.py

@@ -180,3 +180,27 @@ async def test_get_multitenant(s3_gateway: S3Gateway, object_in_s3):
 async def test_get_other_tenant(s3_gateway: S3Gateway, object_in_s3_other_tenant):
     actual = await s3_gateway.get(object_in_s3_other_tenant)
     assert actual is None
+
+
+async def test_remove_filtered_all(s3_gateway: S3Gateway, multiple_objects):
+    await s3_gateway.remove_filtered([])
+
+    # tenant 22 is completely wiped
+    for i in (0, 2, 3):
+        assert await s3_gateway.get(multiple_objects[i]) is None
+
+    # object of tenant 222 is still there
+    ctx.tenant = Tenant(id=222, name="other")
+    await s3_gateway.get("raster-2/bla") is not None
+
+
+async def test_remove_filtered_prefix(s3_gateway: S3Gateway, multiple_objects):
+    await s3_gateway.remove_filtered([Filter(field="prefix", values=["raster-2/"])])
+
+    assert await s3_gateway.get("raster-1/bla") is not None
+    assert await s3_gateway.get("raster-2/foo") is None
+    assert await s3_gateway.get("raster-2/bz") is None
+
+    # object of tenant 222 is still there
+    ctx.tenant = Tenant(id=222, name="other")
+    await s3_gateway.get("raster-2/bla") is not None

+ 63 - 0
tests/s3/test_key_mapper.py

@@ -0,0 +1,63 @@
+import pytest
+from pydantic import ValidationError
+
+from clean_python.s3 import KeyMapper
+
+
+@pytest.mark.parametrize(
+    "pattern,ids,expected",
+    [
+        ("{}", ("foo",), "foo"),
+        ("{}", (25,), "25"),
+        ("bla/{}", ("foo",), "bla/foo"),
+        ("raster-{}/{}", (25, "foo"), "raster-25/foo"),
+    ],
+)
+def test_to_key(pattern, ids, expected):
+    mapper = KeyMapper(pattern=pattern)
+    assert mapper.to_key(*ids) == expected
+
+
+@pytest.mark.parametrize(
+    "pattern,ids,expected",
+    [
+        ("{}", (), ""),
+        ("bla/{}", (), "bla/"),
+        ("raster-{}/{}", (25,), "raster-25/"),
+    ],
+)
+def test_to_key_prefix(pattern, ids, expected):
+    mapper = KeyMapper(pattern=pattern)
+    assert mapper.to_key_prefix(*ids) == expected
+
+
+@pytest.mark.parametrize(
+    "pattern,expected,key",
+    [
+        ("{}", ("foo",), "foo"),
+        ("{}", (25,), "25"),
+        ("bla/{}", ("foo",), "bla/foo"),
+        ("raster-{}/{}", (25, "foo"), "raster-25/foo"),
+    ],
+)
+def test_from_key(pattern, expected, key):
+    mapper = KeyMapper(pattern=pattern)
+    assert mapper.from_key(key) == expected
+
+
+@pytest.mark.parametrize("pattern", ["", "/{}", "{}-bla", "{a}/{}"])
+def test_validate_pattern(pattern):
+    with pytest.raises(ValidationError):
+        KeyMapper(pattern=pattern)
+
+
+@pytest.mark.parametrize(
+    "pattern,names,expected",
+    [
+        ("{}", ("name",), "{name}"),
+        ("raster-{}/{}", ("id", "name"), "raster-{id}/{name}"),
+    ],
+)
+def test_get_named_pattern(pattern, names, expected):
+    mapper = KeyMapper(pattern=pattern)
+    assert mapper.get_named_pattern(*names) == expected