浏览代码

Code overhaul

Casper van der Wel 2 年之前
父节点
当前提交
29eb926641
共有 52 个文件被更改,包括 4206 次插入61 次删除
  1. 30 3
      clean_python/__init__.py
  2. 195 0
      clean_python/async_actor.py
  3. 7 0
      clean_python/attr_dict.py
  4. 80 0
      clean_python/celery_rmq_broker.py
  5. 0 0
      clean_python/child_entity.py
  6. 43 0
      clean_python/context.py
  7. 26 0
      clean_python/domain_event.py
  8. 7 0
      clean_python/domain_service.py
  9. 83 0
      clean_python/dramatiq_task_logger.py
  10. 82 0
      clean_python/error_responses.py
  11. 79 0
      clean_python/exceptions.py
  12. 83 0
      clean_python/fastapi_access_logger.py
  13. 21 0
      clean_python/fluentbit_gateway.py
  14. 125 0
      clean_python/gateway.py
  15. 55 0
      clean_python/internal_gateway.py
  16. 9 0
      clean_python/link.py
  17. 61 0
      clean_python/manage.py
  18. 6 0
      clean_python/now.py
  19. 113 0
      clean_python/oauth2.py
  20. 25 0
      clean_python/pagination.py
  21. 33 0
      clean_python/profilers.py
  22. 84 0
      clean_python/repository.py
  23. 51 0
      clean_python/request_query.py
  24. 205 0
      clean_python/resource.py
  25. 30 0
      clean_python/root_entity.py
  26. 0 47
      clean_python/scripts.py
  27. 177 0
      clean_python/service.py
  28. 30 0
      clean_python/sleep_task.py
  29. 315 0
      clean_python/sql_gateway.py
  30. 129 0
      clean_python/sql_provider.py
  31. 34 0
      clean_python/testing.py
  32. 108 0
      clean_python/tests/test_async_actor.py
  33. 34 0
      clean_python/tests/test_celery_rmq_broker.py
  34. 83 0
      clean_python/tests/test_dramatiq_task_logger.py
  35. 32 0
      clean_python/tests/test_exceptions.py
  36. 152 0
      clean_python/tests/test_fastapi_access_logger.py
  37. 173 0
      clean_python/tests/test_gateway.py
  38. 124 0
      clean_python/tests/test_internal_gateway.py
  39. 95 0
      clean_python/tests/test_manage.py
  40. 125 0
      clean_python/tests/test_oauth2.py
  41. 179 0
      clean_python/tests/test_repository.py
  42. 46 0
      clean_python/tests/test_request_query.py
  43. 132 0
      clean_python/tests/test_resource.py
  44. 73 0
      clean_python/tests/test_root_entity.py
  45. 0 10
      clean_python/tests/test_scripts.py
  46. 47 0
      clean_python/tests/test_service.py
  47. 453 0
      clean_python/tests/test_sql_gateway.py
  48. 65 0
      clean_python/tests/test_value_object.py
  49. 12 0
      clean_python/tmpdir_provider.py
  50. 8 0
      clean_python/value.py
  51. 46 0
      clean_python/value_object.py
  52. 1 1
      pyproject.toml

+ 30 - 3
clean_python/__init__.py

@@ -1,3 +1,30 @@
-# fmt: off
-__version__ = "0.1.dev0"
-# fmt: on
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from .async_actor import *  # NOQA
+from .attr_dict import AttrDict  # NOQA
+from .celery_rmq_broker import *  # NOQA
+from .context import *  # NOQA
+from .domain_event import *  # NOQA
+from .domain_service import DomainService  # NOQA
+from .dramatiq_task_logger import *  # NOQA
+from .exceptions import *  # NOQA
+from .fastapi_access_logger import *  # NOQA
+from .fluentbit_gateway import FluentbitGateway  # NOQA
+from .gateway import *  # NOQA
+from .internal_gateway import InternalGateway  # NOQA
+from .link import Link  # NOQA
+from .manage import Manage  # NOQA
+from .now import now  # NOQA
+from .oauth2 import *  # NOQA
+from .pagination import *  # NOQA
+from .repository import Repository  # NOQA
+from .request_query import *  # NOQA
+from .resource import *  # NOQA
+from .root_entity import RootEntity  # NOQA
+from .service import Service  # NOQA
+from .sql_gateway import SQLGateway  # NOQA
+from .sql_provider import *  # NOQA
+from .tmpdir_provider import *  # NOQA
+from .value import Value  # NOQA
+from .value_object import ValueObject, ValueObjectWithId  # NOQA

+ 195 - 0
clean_python/async_actor.py

@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+"""Dramatiq configuration"""
+
+import asyncio
+import logging
+import threading
+import time
+from concurrent.futures import TimeoutError
+from typing import Any, Awaitable, Dict, Optional, TypeVar
+
+import dramatiq
+from asgiref.sync import sync_to_async
+from dramatiq.brokers.stub import StubBroker
+from dramatiq.middleware import Interrupt, Middleware
+
+__all__ = ["AsyncActor", "AsyncMiddleware", "async_actor"]
+
+
+logger = logging.getLogger(__name__)
+
+# Default broker (for testing)
+broker = StubBroker()
+broker.run_coroutine = lambda coro: asyncio.run(coro)
+dramatiq.set_broker(broker)
+
+R = TypeVar("R")
+
+
+class EventLoopThread(threading.Thread):
+    """A thread that starts / stops an asyncio event loop.
+
+    The method 'run_coroutine' should be used to run coroutines from a
+    synchronous context.
+    """
+
+    EVENT_LOOP_START_TIMEOUT = 0.1  # seconds to wait for the event loop to start
+
+    loop: Optional[asyncio.AbstractEventLoop] = None
+
+    def __init__(self):
+        super().__init__(target=self._start_event_loop)
+
+    def _start_event_loop(self):
+        """This method should run in the thread"""
+        logger.info("Starting the event loop...")
+
+        self.loop = asyncio.new_event_loop()
+        try:
+            self.loop.run_forever()
+        finally:
+            self.loop.close()
+
+    def _stop_event_loop(self):
+        """This method should run outside of the thread"""
+        if self.loop is not None:
+            logger.info("Stopping the event loop...")
+            self.loop.call_soon_threadsafe(self.loop.stop)
+
+    def run_coroutine(self, coro: Awaitable[R]) -> R:
+        """To be called from outside the thread
+
+        Blocks until the coroutine is finished.
+        """
+        if self.loop is None or not self.loop.is_running():
+            raise RuntimeError("The event loop is not running")
+
+        done = threading.Event()
+
+        async def wrapped_coro() -> R:
+            try:
+                return await coro
+            finally:
+                done.set()
+
+        future = asyncio.run_coroutine_threadsafe(wrapped_coro(), self.loop)
+        try:
+            while True:
+                try:
+                    # Use a timeout to be able to catch asynchronously raised dramatiq
+                    # exceptions (Shutdown and TimeLimitExceeded).
+                    return future.result(timeout=1)
+                except TimeoutError:
+                    continue
+        except Interrupt:
+            self.loop.call_soon_threadsafe(future.cancel)
+            # The future will raise a CancelledError *before* the coro actually
+            # finished cleanup. Wait for the event instead.
+            done.wait()
+            raise
+
+    def start(self, *args, **kwargs):
+        super().start(*args, **kwargs)
+        time.sleep(self.EVENT_LOOP_START_TIMEOUT)
+        if self.loop is None or not self.loop.is_running():
+            logger.exception("The event loop failed to start")
+        logger.info("Event loop is running.")
+
+    def join(self, *args, **kwargs):
+        self._stop_event_loop()
+        return super().join(*args, **kwargs)
+
+
+class AsyncMiddleware(Middleware):
+    """This middleware enables coroutines to be ran as dramatiq a actors.
+
+    At its core, this middleware spins up a dedicated thread ('event_loop_thread'),
+    which may be used to schedule the coroutines on from the worker threads.
+    """
+
+    event_loop_thread: Optional[EventLoopThread] = None
+
+    def run_coroutine(self, coro: Awaitable[R]) -> R:
+        assert self.event_loop_thread is not None
+        return self.event_loop_thread.run_coroutine(coro)
+
+    def before_worker_boot(self, broker, worker):
+        self.event_loop_thread = EventLoopThread()
+        self.event_loop_thread.start()
+
+        broker.run_coroutine = self.run_coroutine
+
+    def after_worker_shutdown(self, broker, worker):
+        assert self.event_loop_thread is not None
+        self.event_loop_thread.join()
+        self.event_loop_thread = None
+
+        delattr(broker, "run_coroutine")
+
+
+class AsyncActor(dramatiq.Actor):
+    """To configure coroutines as a dramatiq actor.
+
+    Requires AsyncMiddleware to be active.
+
+    Example usage:
+
+    >>> @dramatiq.actor(..., actor_class=AsyncActor)
+    ... async def my_task(x):
+    ...     print(x)
+
+    Notes:
+
+    The async functions are scheduled on an event loop that is shared between
+    worker threads. See AsyncMiddleware.
+
+    This is compatible with ShutdownNotifications ("notify_shutdown") and
+    TimeLimit ("time_limit"). Both result in an asyncio.CancelledError raised inside
+    the async function. There is currently no way to tell the two apart.
+    """
+
+    def __init__(self, fn, *args, **kwargs):
+        super().__init__(
+            lambda *args, **kwargs: self.broker.run_coroutine(fn(*args, **kwargs)),
+            *args,
+            **kwargs,
+        )
+
+    @sync_to_async
+    def send_async(self, *args, **kwargs) -> dramatiq.Message[R]:
+        """See dramatiq.actor.Actor.send.
+
+        Sending a message to a broker is potentially blocking, so @sync_to_async is used.
+        """
+        return super().send(*args, **kwargs)
+
+    @sync_to_async
+    def send_async_with_options(
+        self,
+        *,
+        args: tuple = (),
+        kwargs: Optional[Dict[str, Any]] = None,
+        delay: Optional[int] = None,
+        **options,
+    ) -> dramatiq.Message[R]:
+        """See dramatiq.actor.Actor.send_with_options.
+
+        Sending a message to a broker is potentially blocking, so @sync_to_async is used.
+        """
+        return super().send_with_options(
+            args=args, kwargs=kwargs, delay=delay, **options
+        )
+
+
+def async_actor(awaitable=None, **kwargs):
+    kwargs.setdefault("max_retries", 0)
+    if awaitable:
+        return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
+    else:
+
+        def wrapper(awaitable):
+            return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
+
+        return wrapper

+ 7 - 0
clean_python/attr_dict.py

@@ -0,0 +1,7 @@
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__dict__ = self
+
+    def dict(self):
+        return self

+ 80 - 0
clean_python/celery_rmq_broker.py

@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import json
+import uuid
+from typing import Optional
+
+import pika
+from asgiref.sync import sync_to_async
+from pydantic import AnyUrl
+
+from .gateway import Gateway, Json
+from .value_object import ValueObject
+
+__all__ = ["CeleryRmqBroker"]
+
+
+class CeleryHeaders(ValueObject):
+    lang: str = "py"
+    task: str
+    id: uuid.UUID
+    root_id: uuid.UUID
+    parent_id: Optional[uuid.UUID] = None
+    group: Optional[uuid.UUID] = None
+    argsrepr: Optional[str] = None
+    kwargsrepr: Optional[str] = None
+    origin: Optional[str] = None
+
+    def json_dict(self):
+        return json.loads(self.json())
+
+
+class CeleryRmqBroker(Gateway):
+    def __init__(
+        self, broker_url: AnyUrl, queue: str, origin: str, declare_queue: bool = False
+    ):
+        self._parameters = pika.URLParameters(broker_url)
+        self._queue = queue
+        self._origin = origin
+        self._declare_queue = declare_queue
+
+    @sync_to_async
+    def add(self, item: Json) -> Json:
+        task = item["task"]
+        args = list(item.get("args") or [])
+        kwargs = dict(item.get("kwargs") or {})
+
+        task_id = uuid.uuid4()
+        header = CeleryHeaders(
+            task=task,
+            id=task_id,
+            root_id=task_id,
+            argsrepr=json.dumps(args),
+            kwargsrepr=json.dumps(kwargs),
+            origin=self._origin,
+        )
+        body = json.dumps((args, kwargs, None))
+
+        with pika.BlockingConnection(self._parameters) as connection:
+            channel = connection.channel()
+
+            if self._declare_queue:
+                channel.queue_declare(queue=self._queue)
+            else:
+                pass  # Configured by Lizard
+
+            properties = pika.BasicProperties(
+                correlation_id=str(task_id),
+                content_type="application/json",
+                content_encoding="utf-8",
+                headers=header.json_dict(),
+            )
+            channel.basic_publish(
+                exchange="",
+                routing_key=self._queue,
+                body=body,
+                properties=properties,
+            )
+
+        return item

+ 0 - 0
clean_python/tests/__init__.py → clean_python/child_entity.py


+ 43 - 0
clean_python/context.py

@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from contextvars import ContextVar
+
+from fastapi import Request
+
+__all__ = ["ctx", "RequestMiddleware"]
+
+
+class Context:
+    def __init__(self):
+        self._request_value: ContextVar[Request] = ContextVar("request_value")
+
+    @property
+    def request(self) -> Request:
+        return self._request_value.get()
+
+    @request.setter
+    def request(self, value: Request) -> None:
+        self._request_value.set(value)
+
+
+ctx = Context()
+
+
+class RequestMiddleware:
+    """Save the current request in a context variable.
+
+    We were experiencing database connections piling up until PostgreSQL's
+    max_connections was hit, which has to do with BaseHTTPMiddleware not
+    interacting properly with context variables. For more details, see:
+    https://github.com/tiangolo/fastapi/issues/4719. Writing this
+    middleware as generic ASGI middleware fixes the problem.
+    """
+
+    def __init__(self, app):
+        self.app = app
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] == "http":
+            ctx.request = Request(scope, receive)
+        await self.app(scope, receive, send)

+ 26 - 0
clean_python/domain_event.py

@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from abc import ABC
+from typing import Awaitable, Callable, TypeVar
+
+import blinker
+
+__all__ = ["DomainEvent"]
+
+
+TDomainEvent = TypeVar("TDomainEvent", bound="DomainEvent")
+TEventHandler = Callable[[TDomainEvent], Awaitable[None]]
+
+
+class DomainEvent(ABC):
+    @classmethod
+    def _signal(cls) -> blinker.Signal:
+        return blinker.signal(cls.__name__)
+
+    @classmethod
+    def register_handler(cls, receiver: TEventHandler) -> TEventHandler:
+        return cls._signal().connect(receiver)
+
+    async def send_async(self) -> None:
+        await self._signal().send_async(self)

+ 7 - 0
clean_python/domain_service.py

@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class DomainService(BaseModel):
+    class Config:
+        allow_mutation = False
+        arbitrary_types_allowed = True

+ 83 - 0
clean_python/dramatiq_task_logger.py

@@ -0,0 +1,83 @@
+import os
+import threading
+import time
+from typing import Optional
+
+import inject
+from dramatiq import get_encoder, Middleware
+from dramatiq.errors import RateLimitExceeded, Retry
+from dramatiq.message import Message
+from dramatiq.middleware import SkipMessage
+
+from .fluentbit_gateway import FluentbitGateway
+from .gateway import Gateway
+
+__all__ = ["AsyncLoggingMiddleware"]
+
+
+class AsyncLoggingMiddleware(Middleware):
+    def __init__(self, **kwargs):
+        self.logger = DramatiqTaskLogger(**kwargs)
+
+    def before_process_message(self, broker, message):
+        broker.run_coroutine(self.logger.start())
+
+    def after_skip_message(self, broker, message):
+        broker.run_coroutine(self.logger.stop(message, None, SkipMessage()))
+
+    def after_process_message(self, broker, message, *, result=None, exception=None):
+        broker.run_coroutine(self.logger.stop(message, result, exception))
+
+
+class DramatiqTaskLogger:
+    local = threading.local()
+
+    def __init__(
+        self,
+        hostname: str,
+        gateway_override: Optional[Gateway] = None,
+    ):
+        self.origin = f"{hostname}-{os.getpid()}"
+        self.gateway_override = gateway_override
+
+    @property
+    def gateway(self):
+        return self.gateway_override or inject.instance(FluentbitGateway)
+
+    @property
+    def encoder(self):
+        return get_encoder()
+
+    async def start(self):
+        self.local.start_time = time.time()
+
+    async def stop(self, message: Message, result=None, exception=None):
+        if exception is None:
+            state = "SUCCESS"
+        elif isinstance(exception, Retry):
+            state = "RETRY"
+        elif isinstance(exception, SkipMessage):
+            state = "EXPIRED"
+        elif isinstance(exception, RateLimitExceeded):
+            state = "TERMINATED"
+        else:
+            state = "FAILURE"
+
+        try:
+            duration = time.time() - self.local.start_time
+        except AttributeError:
+            duration = 0
+
+        log_dict = {
+            "tag_suffix": "task_log",
+            "task_id": message.message_id,
+            "name": message.actor_name,
+            "state": state,
+            "duration": duration,
+            "retries": message.options.get("retries", 0),
+            "origin": self.origin,
+            "argsrepr": self.encoder.encode(message.args),
+            "kwargsrepr": self.encoder.encode(message.kwargs),
+            "result": result,
+        }
+        return await self.gateway.add(log_dict)

+ 82 - 0
clean_python/error_responses.py

@@ -0,0 +1,82 @@
+from typing import List, Union
+
+from fastapi.encoders import jsonable_encoder
+from fastapi.requests import Request
+from fastapi.responses import JSONResponse
+from starlette import status
+
+from .exceptions import (
+    BadRequest,
+    Conflict,
+    DoesNotExist,
+    PermissionDenied,
+    Unauthorized,
+)
+from .value_object import ValueObject
+
+__all__ = [
+    "ValidationErrorResponse",
+    "DefaultErrorResponse",
+    "not_found_handler",
+    "conflict_handler",
+    "validation_error_handler",
+    "not_implemented_handler",
+    "permission_denied_handler",
+    "unauthorized_handler",
+]
+
+
+class ValidationErrorEntry(ValueObject):
+    loc: List[Union[str, int]]
+    msg: str
+    type: str
+
+
+class ValidationErrorResponse(ValueObject):
+    detail: List[ValidationErrorEntry]
+
+
+class DefaultErrorResponse(ValueObject):
+    message: str
+
+
+async def not_found_handler(request: Request, exc: DoesNotExist):
+    return JSONResponse(
+        status_code=status.HTTP_404_NOT_FOUND,
+        content={"message": f"Could not find {exc.name} with id={exc.id}"},
+    )
+
+
+async def conflict_handler(request: Request, exc: Conflict):
+    return JSONResponse(
+        status_code=status.HTTP_409_CONFLICT,
+        content={"message": str(exc)},
+    )
+
+
+async def validation_error_handler(request: Request, exc: BadRequest):
+    return JSONResponse(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        content=jsonable_encoder({"detail": exc.errors()}),
+    )
+
+
+async def not_implemented_handler(request: Request, exc: NotImplementedError):
+    return JSONResponse(
+        status_code=status.HTTP_501_NOT_IMPLEMENTED,
+        content={"message": str(exc)},
+    )
+
+
+async def unauthorized_handler(request: Request, exc: Unauthorized):
+    return JSONResponse(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        content={"message": "Unauthorized"},
+    )
+
+
+async def permission_denied_handler(request: Request, exc: PermissionDenied):
+    return JSONResponse(
+        status_code=status.HTTP_403_FORBIDDEN,
+        content={"message": "Permission denied"},
+    )

+ 79 - 0
clean_python/exceptions.py

@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from typing import Any, Optional, Union
+
+from pydantic import create_model, ValidationError
+from pydantic.error_wrappers import ErrorWrapper
+
+__all__ = [
+    "AlreadyExists",
+    "Conflict",
+    "DoesNotExist",
+    "PermissionDenied",
+    "PreconditionFailed",
+    "BadRequest",
+    "Unauthorized",
+    "BadRequest",
+]
+
+
+class DoesNotExist(Exception):
+    def __init__(self, name: str, id: Optional[int] = None):
+        super().__init__()
+        self.name = name
+        self.id = id
+
+    def __str__(self):
+        if self.id:
+            return f"does not exist: {self.name} with id={self.id}"
+        else:
+            return f"does not exist: {self.name}"
+
+
+class Conflict(Exception):
+    def __init__(self, msg: Optional[str] = None):
+        super().__init__(msg)
+
+
+class AlreadyExists(Conflict):
+    def __init__(self, id: Optional[int] = None):
+        super().__init__(f"record with id={id} already exists")
+
+
+class PreconditionFailed(Exception):
+    def __init__(self, msg: str = "precondition failed", obj: Any = None):
+        super().__init__(msg)
+        self.obj = obj
+
+
+# pydantic.ValidationError needs some model; for us it doesn't matter
+# We do it the same way as FastAPI does it.
+request_model = create_model("Request")
+
+
+class BadRequest(ValidationError):
+    def __init__(self, err_or_msg: Union[ValidationError, str]):
+        if isinstance(err_or_msg, ValidationError):
+            errors = err_or_msg.raw_errors
+        else:
+            errors = [ErrorWrapper(ValueError(err_or_msg), "*")]
+        super().__init__(errors, request_model)
+
+    def __str__(self) -> str:
+        errors = self.errors()
+        if len(errors) == 1:
+            error = errors[0]
+            loc = "'" + ",".join([str(x) for x in error["loc"]]) + "' "
+            if loc == "'*' ":
+                loc = ""
+            return f"validation error: {loc}{error['msg']}"
+        return super().__str__()
+
+
+class Unauthorized(Exception):
+    pass
+
+
+class PermissionDenied(Exception):
+    pass

+ 83 - 0
clean_python/fastapi_access_logger.py

@@ -0,0 +1,83 @@
+import os
+import time
+from datetime import datetime
+from typing import Awaitable, Callable, Optional
+
+import inject
+from starlette.background import BackgroundTasks
+from starlette.requests import Request
+from starlette.responses import Response
+
+from .fluentbit_gateway import FluentbitGateway
+from .gateway import Gateway
+
+__all__ = ["FastAPIAccessLogger"]
+
+
+class FastAPIAccessLogger:
+    def __init__(self, hostname: str, gateway_override: Optional[Gateway] = None):
+        self.origin = f"{hostname}-{os.getpid()}"
+        self.gateway_override = gateway_override
+
+    @property
+    def gateway(self) -> Gateway:
+        return self.gateway_override or inject.instance(FluentbitGateway)
+
+    async def __call__(
+        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        time_received = time.time()
+        response = await call_next(request)
+        request_time = time.time() - time_received
+
+        # Instead of logging directly, set it as background task so that it is
+        # executed after the response. See https://www.starlette.io/background/.
+        if response.background is None:
+            response.background = BackgroundTasks()
+        response.background.add_task(
+            log_access, self.gateway, request, response, time_received, request_time
+        )
+        return response
+
+
+def fmt_timestamp(timestamp: float) -> str:
+    return datetime.utcfromtimestamp(timestamp).isoformat() + "Z"
+
+
+async def log_access(
+    gateway: Gateway,
+    request: Request,
+    response: Response,
+    time_received: float,
+    request_time: float,
+) -> None:
+    """
+    Create a dictionary with logging data.
+    """
+    try:
+        content_length = int(response.headers.get("content-length"))
+    except (TypeError, ValueError):
+        content_length = None
+
+    try:
+        view_name = request.scope["route"].name
+    except KeyError:
+        view_name = None
+
+    item = {
+        "tag_suffix": "access_log",
+        "remote_address": getattr(request.client, "host", None),
+        "method": request.method,
+        "path": request.url.path,
+        "portal": request.url.netloc,
+        "referer": request.headers.get("referer"),
+        "user_agent": request.headers.get("user-agent"),
+        "query_params": request.url.query,
+        "view_name": view_name,
+        "status": response.status_code,
+        "content_type": response.headers.get("content-type"),
+        "content_length": content_length,
+        "time": fmt_timestamp(time_received),
+        "request_time": request_time,
+    }
+    await gateway.add(item)

+ 21 - 0
clean_python/fluentbit_gateway.py

@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from typing import Any, Dict
+
+from asgiref.sync import sync_to_async
+from fluent.sender import FluentSender
+
+from .gateway import Gateway
+
+Json = Dict[str, Any]
+
+
+class FluentbitGateway(Gateway):
+    def __init__(self, tag: str, host: str, port: int):
+        self._sender = FluentSender(tag, host=host, port=port)
+
+    @sync_to_async
+    def add(self, item: Json) -> Json:
+        self._sender.emit(item.pop("tag_suffix", ""), item)
+        return item

+ 125 - 0
clean_python/gateway.py

@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from copy import deepcopy
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional
+
+from .exceptions import AlreadyExists, Conflict, DoesNotExist
+from .pagination import PageOptions
+from .value_object import ValueObject
+
+__all__ = ["Gateway", "Json", "Filter", "InMemoryGateway"]
+Json = Dict[str, Any]
+
+
+class Filter(ValueObject):
+    field: str
+    values: List[Any]
+
+
+class Gateway:
+    async def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> List[Json]:
+        raise NotImplementedError()
+
+    async def count(self, filters: List[Filter]) -> int:
+        return len(await self.filter(filters, params=None))
+
+    async def exists(self, filters: List[Filter]) -> bool:
+        return len(await self.filter(filters, params=PageOptions(limit=1))) > 0
+
+    async def get(self, id: int) -> Optional[Json]:
+        result = await self.filter([Filter(field="id", values=[id])], params=None)
+        return result[0] if result else None
+
+    async def add(self, item: Json) -> Json:
+        raise NotImplementedError()
+
+    async def update(
+        self, item: Json, if_unmodified_since: Optional[datetime] = None
+    ) -> Json:
+        raise NotImplementedError()
+
+    async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
+        existing = await self.get(id)
+        if existing is None:
+            raise DoesNotExist("record", id)
+        return await self.update(
+            func(existing), if_unmodified_since=existing["updated_at"]
+        )
+
+    async def upsert(self, item: Json) -> Json:
+        try:
+            return await self.update(item)
+        except DoesNotExist:
+            return await self.add(item)
+
+    async def remove(self, id: int) -> bool:
+        raise NotImplementedError()
+
+
+class InMemoryGateway(Gateway):
+    """For testing purposes"""
+
+    def __init__(self, data: List[Json]):
+        self.data = {x["id"]: deepcopy(x) for x in data}
+
+    def _get_next_id(self) -> int:
+        if len(self.data) == 0:
+            return 1
+        else:
+            return max(self.data) + 1
+
+    def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:
+        objs = sorted(
+            objs,
+            key=lambda x: (x.get(params.order_by) is None, x.get(params.order_by)),
+            reverse=not params.ascending,
+        )
+        return objs[params.offset : params.offset + params.limit]
+
+    async def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> List[Json]:
+        result = []
+        for x in self.data.values():
+            for filter in filters:
+                if x.get(filter.field) not in filter.values:
+                    break
+            else:
+                result.append(deepcopy(x))
+        if params is not None:
+            result = self._paginate(result, params)
+        return result
+
+    async def add(self, item: Json) -> Json:
+        item = item.copy()
+        id_ = item.pop("id", None)
+        # autoincrement (like SQL does)
+        if id_ is None:
+            id_ = self._get_next_id()
+        elif id_ in self.data:
+            raise AlreadyExists(id_)
+
+        self.data[id_] = {"id": id_, **item}
+        return deepcopy(self.data[id_])
+
+    async def update(
+        self, item: Json, if_unmodified_since: Optional[datetime] = None
+    ) -> Json:
+        _id = item.get("id")
+        if _id is None or _id not in self.data:
+            raise DoesNotExist("item", _id)
+        existing = self.data[_id]
+        if if_unmodified_since and existing.get("updated_at") != if_unmodified_since:
+            raise Conflict()
+        existing.update(item)
+        return deepcopy(existing)
+
+    async def remove(self, id: int) -> bool:
+        if id not in self.data:
+            return False
+        del self.data[id]
+        return True

+ 55 - 0
clean_python/internal_gateway.py

@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+from abc import abstractmethod, abstractproperty
+from typing import Generic, List, Optional, TypeVar
+
+from .exceptions import BadRequest, DoesNotExist
+from .gateway import Filter
+from .manage import Manage
+from .pagination import PageOptions
+from .root_entity import RootEntity
+from .value_object import ValueObject
+
+E = TypeVar("E", bound=RootEntity)  # External
+T = TypeVar("T", bound=ValueObject)  # Internal
+
+
+# don't subclass Gateway; Gateway makes Json objects
+class InternalGateway(Generic[E, T]):
+    @abstractproperty
+    def manage(self) -> Manage[E]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _map(self, obj: E) -> T:
+        raise NotImplementedError()
+
+    async def get(self, id: int) -> Optional[T]:
+        try:
+            result = await self.manage.retrieve(id)
+        except DoesNotExist:
+            return None
+        else:
+            return self._map(result)
+
+    async def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> List[T]:
+        page = await self.manage.filter(filters, params)
+        return [self._map(x) for x in page.items]
+
+    async def add(self, item: T) -> T:
+        try:
+            created = await self.manage.create(item.dict())
+        except BadRequest as e:
+            raise ValueError(e)
+        return self._map(created)
+
+    async def remove(self, id) -> bool:
+        return await self.manage.destroy(id)
+
+    async def count(self, filters: List[Filter]) -> int:
+        return await self.manage.count(filters)
+
+    async def exists(self, filters: List[Filter]) -> bool:
+        return await self.manage.exists(filters)

+ 9 - 0
clean_python/link.py

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from pydantic import AnyHttpUrl
+from typing_extensions import TypedDict
+
+
+class Link(TypedDict):
+    href: AnyHttpUrl

+ 61 - 0
clean_python/manage.py

@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from typing import Any, Generic, List, Optional, Type, TypeVar
+
+from .gateway import Filter, Json
+from .pagination import Page, PageOptions
+from .repository import Repository
+from .root_entity import RootEntity
+
+T = TypeVar("T", bound=RootEntity)
+
+
+__all__ = ["Manage"]
+
+
+class Manage(Generic[T]):
+    repo: Repository[T]
+    entity: Type[T]
+
+    def __init__(self, repo: Optional[Repository[T]] = None):
+        assert repo is not None
+        self.repo = repo
+
+    def __init_subclass__(cls) -> None:
+        (base,) = cls.__orig_bases__  # type: ignore
+        (entity,) = base.__args__
+        assert issubclass(entity, RootEntity)
+        super().__init_subclass__()
+        cls.entity = entity
+
+    async def retrieve(self, id: int) -> T:
+        return await self.repo.get(id)
+
+    async def create(self, values: Json) -> T:
+        return await self.repo.add(values)
+
+    async def update(self, id: int, values: Json) -> T:
+        return await self.repo.update(id, values)
+
+    async def destroy(self, id: int) -> bool:
+        return await self.repo.remove(id)
+
+    async def list(self, params: Optional[PageOptions] = None) -> Page[T]:
+        return await self.repo.all(params)
+
+    async def by(
+        self, key: str, value: Any, params: Optional[PageOptions] = None
+    ) -> Page[T]:
+        return await self.repo.by(key, value, params=params)
+
+    async def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> Page[T]:
+        return await self.repo.filter(filters, params=params)
+
+    async def count(self, filters: List[Filter]) -> int:
+        return await self.repo.count(filters)
+
+    async def exists(self, filters: List[Filter]) -> bool:
+        return await self.repo.exists(filters)

+ 6 - 0
clean_python/now.py

@@ -0,0 +1,6 @@
+from datetime import datetime, timezone
+
+
+def now():
+    # this function is there so that we can mock it in tests
+    return datetime.now(timezone.utc)

+ 113 - 0
clean_python/oauth2.py

@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import logging
+from typing import Dict, List
+
+import jwt
+from jwt import PyJWKClient
+from jwt.exceptions import PyJWTError
+from pydantic import AnyHttpUrl, BaseModel
+
+from .exceptions import PermissionDenied, Unauthorized
+
+__all__ = ["OAuth2Settings", "OAuth2AccessTokenVerifier"]
+
+logger = logging.getLogger(__name__)
+
+
+class OAuth2Settings(BaseModel):
+    client_id: str
+    issuer: str
+    resource_server_id: str
+    token_url: AnyHttpUrl
+    authorization_url: AnyHttpUrl
+    algorithms: List[str] = ["RS256"]
+    admin_users: List[str]
+
+
+class OAuth2AccessTokenVerifier:
+    """A class for verifying OAuth2 Access Tokens from AWS Cognito
+
+    The verification steps followed are documented here:
+
+    https://docs.aws.amazon.com/cognito/latest/developerguide/amazon- ⏎
+    cognito-user-pools-using-tokens-verifying-a-jwt.html
+    """
+
+    # allow 2 minutes leeway for verifying token expiry:
+    LEEWAY = 120
+
+    def __init__(
+        self,
+        scope: str,
+        issuer: str,
+        resource_server_id: str,
+        algorithms: List[str],
+        admin_users: List[str],
+    ):
+        self.scope = scope
+        self.issuer = issuer
+        self.algorithms = algorithms
+        self.resource_server_id = resource_server_id
+        self.admin_users = admin_users
+        self.jwk_client = PyJWKClient(f"{issuer}/.well-known/jwks.json")
+
+    def __call__(self, token: str) -> Dict:
+        # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
+        # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
+        try:
+            key = self.get_key(token)  # JSON Web Key
+        except PyJWTError as e:
+            logger.info("Token is invalid: %s", e)
+            raise Unauthorized()
+        # Step 2: Validate the JWT signature and standard claims
+        try:
+            claims = jwt.decode(
+                token,
+                key.key,
+                algorithms=self.algorithms,
+                issuer=self.issuer,
+                leeway=self.LEEWAY,
+                options={
+                    "require": ["exp", "iss", "sub", "scope", "token_use"],
+                },
+            )
+        except PyJWTError as e:
+            logger.info("Token is invalid: %s", e)
+            raise Unauthorized()
+        # Step 3: Verify additional claims. At this point, we have passed
+        # verification, so unverified claims may be used safely.
+        self.verify_token_use(claims)
+        self.verify_scope(claims)
+        # Step 4: Authorization: we currently work with a hardcoded
+        # list of users ('sub' claims)
+        self.authorize(claims)
+        return claims
+
+    def get_key(self, token) -> jwt.PyJWK:
+        """Return the JSON Web KEY (JWK) corresponding to kid."""
+        return self.jwk_client.get_signing_key_from_jwt(token)
+
+    def verify_token_use(self, claims):
+        """Check the token_use claim."""
+        if claims["token_use"] != "access":
+            logger.info("Token has invalid token_use claim: %s", claims["token_use"])
+            raise Unauthorized()
+
+    def verify_scope(self, claims):
+        """Check scope claim.
+
+        Cognito includes the resource server id inside the scope, like this:
+
+           raster.lizard.net/*.readwrite
+        """
+        if f"{self.resource_server_id}{self.scope}" not in claims["scope"].split(" "):
+            logger.info("Token has invalid scope claim: %s", claims["scope"])
+            raise Unauthorized()
+
+    def authorize(self, claims):
+        """The subject (sub) claim should be in a hard-coded whitelist."""
+        if claims.get("sub") not in self.admin_users:
+            logger.info("User with sub %s is not authorized", claims.get("sub"))
+            raise PermissionDenied()

+ 25 - 0
clean_python/pagination.py

@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from typing import Generic, Optional, Sequence, TypeVar
+
+from pydantic import BaseModel
+from pydantic.generics import GenericModel
+
+__all__ = ["Page", "PageOptions"]
+
+T = TypeVar("T")
+
+
+class PageOptions(BaseModel):
+    limit: int
+    offset: int = 0
+    order_by: str = "id"
+    ascending: bool = True
+
+
+class Page(GenericModel, Generic[T]):
+    total: int
+    items: Sequence[T]
+    limit: Optional[int] = None
+    offset: Optional[int] = None

+ 33 - 0
clean_python/profilers.py

@@ -0,0 +1,33 @@
+from pathlib import Path
+
+import dramatiq
+import yappi
+
+__all__ = ["ProfilerMiddleware"]
+
+PROFILE_DIR = "var"
+
+
+class ProfilerMiddleware(dramatiq.Middleware):
+    """For usage with dramatiq (single-threaded only)"""
+
+    def __init__(self, profile_dir: Path):
+        profile_dir.mkdir(exist_ok=True)
+        self.profile_dir = profile_dir
+
+    def before_process_message(self, broker, message):
+        yappi.set_clock_type("wall")
+        yappi.start()
+
+    def after_process_message(
+        self, broker, message: dramatiq.Message, *, result=None, exception=None
+    ):
+        yappi.stop()
+
+        stats = yappi.convert2pstats(yappi.get_func_stats())
+
+        stats.dump_stats(
+            self.profile_dir / f"{message.actor_name}-{message.message_id}.pstats"
+        )
+
+        yappi.clear_stats()

+ 84 - 0
clean_python/repository.py

@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from typing import Any, Generic, List, Optional, Type, TypeVar, Union
+
+from .exceptions import DoesNotExist
+from .gateway import Filter, Gateway, Json
+from .pagination import Page, PageOptions
+from .root_entity import RootEntity
+
+T = TypeVar("T", bound=RootEntity)
+
+
+class Repository(Generic[T]):
+    entity: Type[T]
+
+    def __init__(self, gateway: Gateway):
+        self.gateway = gateway
+
+    def __init_subclass__(cls) -> None:
+        (base,) = cls.__orig_bases__  # type: ignore
+        (entity,) = base.__args__
+        assert issubclass(entity, RootEntity)
+        super().__init_subclass__()
+        cls.entity = entity
+
+    async def all(self, params: Optional[PageOptions] = None) -> Page[T]:
+        return await self.filter([], params=params)
+
+    async def by(
+        self, key: str, value: Any, params: Optional[PageOptions] = None
+    ) -> Page[T]:
+        return await self.filter([Filter(field=key, values=[value])], params=params)
+
+    async def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> Page[T]:
+        records = await self.gateway.filter(filters, params=params)
+        total = len(records)
+        # when using pagination, we may need to do a count in the db
+        # except in a typical 'first page' situation with few records
+        if params is not None and not (params.offset == 0 and total < params.limit):
+            total = await self.count(filters)
+        return Page(
+            total=total,
+            limit=params.limit if params else None,
+            offset=params.offset if params else None,
+            items=[self.entity(**x) for x in records],
+        )
+
+    async def get(self, id: int) -> T:
+        res = await self.gateway.get(id)
+        if res is None:
+            raise DoesNotExist("object", id)
+        else:
+            return self.entity(**res)
+
+    async def add(self, item: Union[T, Json]) -> T:
+        if isinstance(item, dict):
+            item = self.entity.create(**item)
+        created = await self.gateway.add(item.dict())
+        return self.entity(**created)
+
+    async def update(self, id: int, values: Json) -> T:
+        if not values:
+            return await self.get(id)
+        updated = await self.gateway.update_transactional(
+            id, lambda x: self.entity(**x).update(**values).dict()
+        )
+        return self.entity(**updated)
+
+    async def upsert(self, item: T) -> T:
+        values = item.dict()
+        upserted = await self.gateway.upsert(values)
+        return self.entity(**upserted)
+
+    async def remove(self, id: int) -> bool:
+        return await self.gateway.remove(id)
+
+    async def count(self, filters: List[Filter]) -> int:
+        return await self.gateway.count(filters)
+
+    async def exists(self, filters: List[Filter]) -> bool:
+        return await self.gateway.exists(filters)

+ 51 - 0
clean_python/request_query.py

@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from typing import List
+
+from fastapi import Query
+from pydantic import validator
+
+from .gateway import Filter
+from .pagination import PageOptions
+from .value_object import ValueObject
+
+__all__ = ["RequestQuery"]
+
+
+class RequestQuery(ValueObject):
+    limit: int = Query(50, ge=1, le=100, description="Page size limit")
+    offset: int = Query(0, ge=0, description="Page offset")
+    order_by: str = Query(
+        default="id", enum=["id", "-id"], description="Field to order by"
+    )
+
+    @validator("order_by")
+    def validate_order_by_enum(cls, v):
+        # the 'enum' parameter doesn't actually do anthing in validation
+        # See: https://github.com/tiangolo/fastapi/issues/2910
+        allowed = cls.__fields__["order_by"].field_info.extra["enum"]
+        if v not in allowed:
+            raise ValueError(f"'order_by' must be one of {allowed}")
+        return v
+
+    def as_page_options(self) -> PageOptions:
+        if self.order_by.startswith("-"):
+            order_by = self.order_by[1:]
+            ascending = False
+        else:
+            order_by = self.order_by
+            ascending = True
+        return PageOptions(
+            limit=self.limit, offset=self.offset, order_by=order_by, ascending=ascending
+        )
+
+    def filters(self) -> List[Filter]:
+        result = []
+        for name in self.__fields__:
+            if name in {"limit", "offset", "order_by"}:
+                continue
+            value = getattr(self, name)
+            if value is not None:
+                result.append(Filter(field=name, values=[value]))
+        return result

+ 205 - 0
clean_python/resource.py

@@ -0,0 +1,205 @@
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type
+
+from fastapi.routing import APIRouter
+
+from .value_object import ValueObject
+
+__all__ = [
+    "Resource",
+    "get",
+    "post",
+    "put",
+    "patch",
+    "delete",
+    "APIVersion",
+    "Stability",
+    "v",
+    "clean_resources",
+]
+
+
+class Stability(str, Enum):
+    STABLE = "stable"
+    BETA = "beta"
+    ALPHA = "alpha"
+
+    @property
+    def description(self) -> str:
+        return DESCRIPTIONS[self]
+
+    def decrease(self) -> "Stability":
+        index = STABILITY_ORDER.index(self)
+        if index == 0:
+            raise ValueError(f"Cannot decrease stability of {self}")
+        return STABILITY_ORDER[index - 1]
+
+
+STABILITY_ORDER = [Stability.ALPHA, Stability.BETA, Stability.STABLE]
+DESCRIPTIONS = {
+    Stability.STABLE: "The stable API version.",
+    Stability.BETA: "Backwards incompatible changes will be announced beforehand.",
+    Stability.ALPHA: "May get backwards incompatible changes without warning.",
+}
+
+
+class APIVersion(ValueObject):
+    version: int
+    stability: Stability
+
+    @property
+    def prefix(self) -> str:
+        result = f"v{self.version}"
+        if self.stability is not Stability.STABLE:
+            result += f"-{self.stability.value}"
+        return result
+
+    @property
+    def description(self) -> str:
+        return self.stability.description
+
+    def decrease_stability(self) -> "APIVersion":
+        return APIVersion(version=self.version, stability=self.stability.decrease())
+
+
+def http_method(path: str, **route_options):
+    def wrapper(unbound_method: Callable[..., Any]):
+        setattr(
+            unbound_method,
+            "http_method",
+            (path, route_options),
+        )
+        return unbound_method
+
+    return wrapper
+
+
+def v(version: int, stability: str = "stable") -> APIVersion:
+    return APIVersion(version=version, stability=Stability(stability))
+
+
+get = partial(http_method, methods=["GET"])
+post = partial(http_method, methods=["POST"])
+put = partial(http_method, methods=["PUT"])
+patch = partial(http_method, methods=["PATCH"])
+delete = partial(http_method, methods=["DELETE"])
+
+
+class OpenApiTag(ValueObject):
+    name: str
+    description: Optional[str]
+
+
+class Resource:
+    version: APIVersion
+    name: str
+
+    def __init_subclass__(cls, version: APIVersion, name: str = ""):
+        cls.version = version
+        cls.name = name
+        super().__init_subclass__()
+
+    @classmethod
+    def with_version(cls, version: APIVersion) -> Type["Resource"]:
+        class DynamicResource(cls, version=version, name=cls.name):  # type: ignore
+            pass
+
+        DynamicResource.__doc__ = cls.__doc__
+
+        return DynamicResource
+
+    def get_less_stable(self, resources: Dict[APIVersion, "Resource"]) -> "Resource":
+        """Fetch a less stable version of this resource from 'resources'
+
+        If it doesn't exist, create it dynamically.
+        """
+        less_stable_version = self.version.decrease_stability()
+
+        # Fetch the less stable resource; generate it if it does not exist
+        try:
+            less_stable_resource = resources[less_stable_version]
+        except KeyError:
+            less_stable_resource = self.__class__.with_version(less_stable_version)()
+
+        # Validate the less stable version
+        if less_stable_resource.__class__.__bases__ != (self.__class__,):
+            raise RuntimeError(
+                f"{less_stable_resource} should be a direct subclass of {self}"
+            )
+
+        return less_stable_resource
+
+    def _endpoints(self):
+        for attr_name in dir(self):
+            if attr_name.startswith("_"):
+                continue
+            endpoint = getattr(self, attr_name)
+            if not hasattr(endpoint, "http_method"):
+                continue
+            yield endpoint
+
+    def get_openapi_tag(self) -> OpenApiTag:
+        return OpenApiTag(
+            name=self.name,
+            description=self.__class__.__doc__,
+        )
+
+    def get_router(
+        self, version: APIVersion, responses: Optional[Dict[str, Dict[str, Any]]] = None
+    ) -> APIRouter:
+        assert version == self.version
+        router = APIRouter()
+        operation_ids = set()
+        for endpoint in self._endpoints():
+            path, route_options = endpoint.http_method
+            operation_id = endpoint.__name__
+            if operation_id in operation_ids:
+                raise RuntimeError(
+                    "Multiple operations {operation_id} configured in {self}"
+                )
+            operation_ids.add(operation_id)
+            # The 'name' is used for reverse lookups (request.path_for): include the
+            # version prefix so that we can uniquely refer to an operation.
+            name = version.prefix + "/" + endpoint.__name__
+            router.add_api_route(
+                path,
+                endpoint,
+                tags=[self.name],
+                operation_id=endpoint.__name__,
+                name=name,
+                responses=responses,
+                **route_options,
+            )
+        return router
+
+
+def clean_resources_same_name(resources: List[Resource]) -> List[Resource]:
+    dct = {x.version: x for x in resources}
+    if len(dct) != len(resources):
+        raise RuntimeError(
+            f"Resource with name {resources[0].name} "
+            f"is defined multiple times with the same version."
+        )
+    for stability in [Stability.STABLE, Stability.BETA]:
+        tmp_resources = {k: v for (k, v) in dct.items() if k.stability is stability}
+        for version, resource in tmp_resources.items():
+            dct[version.decrease_stability()] = resource.get_less_stable(dct)
+    return list(dct.values())
+
+
+def clean_resources(resources: Sequence[Resource]) -> List[Resource]:
+    """Ensure that resources are consistent:
+
+    - ordered by name
+    - (tag, version) combinations should be unique
+    - for stable resources, beta & alpha are autocreated if needed
+    - for beta resources, alpha is autocreated if needed
+    """
+    result = []
+    names = {x.name for x in resources}
+    for name in sorted(names):
+        result.extend(
+            clean_resources_same_name([x for x in resources if x.name == name])
+        )
+    return result

+ 30 - 0
clean_python/root_entity.py

@@ -0,0 +1,30 @@
+from datetime import datetime
+from typing import Optional, Type, TypeVar
+
+from .exceptions import BadRequest
+from .now import now
+from .value_object import ValueObject
+
+T = TypeVar("T", bound="RootEntity")
+
+
+class RootEntity(ValueObject):
+    id: Optional[int] = None
+    created_at: datetime
+    updated_at: datetime
+
+    @classmethod
+    def create(cls: Type[T], **values) -> T:
+        values.setdefault("created_at", now())
+        values.setdefault("updated_at", values["created_at"])
+        return super(RootEntity, cls).create(**values)
+
+    def update(self: T, **values) -> T:
+        if "id" in values and self.id is not None and values["id"] != self.id:
+            raise BadRequest("Cannot change the id of an entity")
+        values.setdefault("updated_at", now())
+        return super().update(**values)
+
+    def __hash__(self):
+        assert self.id is not None
+        return hash(self.__class__) + hash(self.id)

+ 0 - 47
clean_python/scripts.py

@@ -1,47 +0,0 @@
-"""TODO Docstring, used in the command line help text."""
-import argparse
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-def get_parser():
-    """Return argument parser."""
-    parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument(
-        "-v",
-        "--verbose",
-        action="store_true",
-        dest="verbose",
-        default=False,
-        help="Verbose output",
-    )
-    # add arguments here
-    # parser.add_argument(
-    #     'path',
-    #     metavar='FILE',
-    # )
-    return parser
-
-
-def main():  # pragma: no cover
-    """Call main command with args from parser.
-
-    This method is called when you run 'bin/run-clean-python',
-    this is configured in 'setup.py'. Adjust when needed. You can have multiple
-    main scripts.
-
-    """
-    options = get_parser().parse_args()
-    if options.verbose:
-        log_level = logging.DEBUG
-    else:
-        log_level = logging.INFO
-    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")
-
-    try:
-        print("Call some function from another file here")
-        # ^^^ TODO: pass in options.xyz where needed.
-    except:  # noqa: E722
-        logger.exception("An exception has occurred.")
-        return 1

+ 177 - 0
clean_python/service.py

@@ -0,0 +1,177 @@
+import logging
+from typing import Any, Callable, Dict, List, Optional, Set
+
+from asgiref.sync import sync_to_async
+from fastapi import Depends, FastAPI, Request
+from fastapi.exceptions import HTTPException, RequestValidationError
+from fastapi.security import OAuth2AuthorizationCodeBearer
+from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
+from starlette.types import ASGIApp
+
+from .context import RequestMiddleware
+from .error_responses import (
+    BadRequest,
+    conflict_handler,
+    DefaultErrorResponse,
+    not_found_handler,
+    not_implemented_handler,
+    permission_denied_handler,
+    unauthorized_handler,
+    validation_error_handler,
+    ValidationErrorResponse,
+)
+from .exceptions import Conflict, DoesNotExist, PermissionDenied, Unauthorized
+from .fastapi_access_logger import FastAPIAccessLogger
+from .gateway import Gateway
+from .oauth2 import OAuth2AccessTokenVerifier, OAuth2Settings
+from .resource import APIVersion, clean_resources, Resource
+
+logger = logging.getLogger(__name__)
+
+
+class OAuth2Dependable(OAuth2AuthorizationCodeBearer):
+    """A fastapi 'dependable' configuring OAuth2.
+
+    This does two things:
+    - Verify the token in each request
+    - (through FastAPI magic) add the scheme to the OpenAPI spec
+    """
+
+    def __init__(self, scope, settings: OAuth2Settings):
+        self.verifier = sync_to_async(
+            OAuth2AccessTokenVerifier(
+                scope,
+                issuer=settings.issuer,
+                resource_server_id=settings.resource_server_id,
+                algorithms=settings.algorithms,
+                admin_users=settings.admin_users,
+            ),
+            thread_sensitive=False,
+        )
+        super().__init__(
+            authorizationUrl=settings.authorization_url,
+            tokenUrl=settings.token_url,
+            scopes={
+                f"{settings.resource_server_id}*:readwrite": "Full read/write access"
+            },
+        )
+
+    async def __call__(self, request: Request) -> None:
+        token = await super().__call__(request)
+        try:
+            await self.verifier(token)
+        except Unauthorized:
+            raise HTTPException(status_code=HTTP_401_UNAUTHORIZED)
+        except PermissionDenied:
+            raise HTTPException(status_code=HTTP_403_FORBIDDEN)
+
+
+def fastapi_oauth_kwargs(auth: Optional[OAuth2Settings]) -> Dict:
+    if auth is None:
+        return {}
+    return {
+        "dependencies": [Depends(OAuth2Dependable(scope="*:readwrite", settings=auth))],
+        "swagger_ui_init_oauth": {
+            "clientId": auth.client_id,
+            "usePkceWithAuthorizationCodeGrant": True,
+        },
+    }
+
+
+async def health_check():
+    """Simple health check route"""
+    return {"health": "OK"}
+
+
+class Service:
+    resources: List[Resource]
+
+    def __init__(self, *args: Resource):
+        self.resources = clean_resources(args)
+
+    @property
+    def versions(self) -> Set[APIVersion]:
+        return set([x.version for x in self.resources])
+
+    def _create_root_app(
+        self,
+        title: str,
+        description: str,
+        hostname: str,
+        on_startup: Optional[List[Callable[[], Any]]] = None,
+        access_logger_gateway: Optional[Gateway] = None,
+    ) -> FastAPI:
+        app = FastAPI(
+            title=title,
+            description=description,
+            on_startup=on_startup,
+            servers=[
+                {"url": f"{x.prefix}", "description": x.description}
+                for x in self.versions
+            ],
+            root_path_in_servers=False,
+        )
+        app.middleware("http")(
+            FastAPIAccessLogger(
+                hostname=hostname, gateway_override=access_logger_gateway
+            )
+        )
+        app.add_middleware(RequestMiddleware)
+        app.get("/health", include_in_schema=False)(health_check)
+        return app
+
+    def _create_versioned_app(self, version: APIVersion, **kwargs) -> FastAPI:
+        resources = [x for x in self.resources if x.version == version]
+        app = FastAPI(
+            version=version.prefix,
+            tags=sorted(
+                [x.get_openapi_tag().dict() for x in resources], key=lambda x: x["name"]
+            ),
+            **kwargs,
+        )
+        for resource in resources:
+            app.include_router(
+                resource.get_router(
+                    version,
+                    responses={
+                        "400": {"model": ValidationErrorResponse},
+                        "default": {"model": DefaultErrorResponse},
+                    },
+                )
+            )
+        app.add_exception_handler(DoesNotExist, not_found_handler)
+        app.add_exception_handler(Conflict, conflict_handler)
+        app.add_exception_handler(RequestValidationError, validation_error_handler)
+        app.add_exception_handler(BadRequest, validation_error_handler)
+        app.add_exception_handler(NotImplementedError, not_implemented_handler)
+        app.add_exception_handler(PermissionDenied, permission_denied_handler)
+        app.add_exception_handler(Unauthorized, unauthorized_handler)
+        return app
+
+    def create_app(
+        self,
+        title: str,
+        description: str,
+        hostname: str,
+        auth: Optional[OAuth2Settings] = None,
+        on_startup: Optional[List[Callable[[], Any]]] = None,
+        access_logger_gateway: Optional[Gateway] = None,
+    ) -> ASGIApp:
+        app = self._create_root_app(
+            title=title,
+            description=description,
+            hostname=hostname,
+            on_startup=on_startup,
+            access_logger_gateway=access_logger_gateway,
+        )
+        kwargs = {
+            "title": title,
+            "description": description,
+            **fastapi_oauth_kwargs(auth),
+        }
+        versioned_apps = {
+            v: self._create_versioned_app(v, **kwargs) for v in self.versions
+        }
+        for v, versioned_app in versioned_apps.items():
+            app.mount("/" + v.prefix, versioned_app)
+        return app

+ 30 - 0
clean_python/sleep_task.py

@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import asyncio
+
+from dramatiq.middleware import SkipMessage
+
+from .async_actor import async_actor
+
+
+@async_actor(
+    retry_when=lambda x, y: isinstance(y, KeyError),
+    max_retries=1,
+)
+async def sleep_task(seconds: int, return_value=None, event="success"):
+    event = event.lower()
+    if event == "success":
+        await asyncio.sleep(int(seconds))
+    elif event == "crash":
+        import ctypes
+
+        ctypes.string_at(0)  # segfault
+    elif event == "skip":
+        raise SkipMessage("skipping")
+    elif event == "retry":
+        raise KeyError("will-retry")
+    else:
+        raise ValueError(f"Unknown event '{event}'")
+
+    return return_value

+ 315 - 0
clean_python/sql_gateway.py

@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+from contextlib import asynccontextmanager
+from datetime import datetime
+from typing import AsyncIterator, Callable, List, Optional, TypeVar
+
+import inject
+from sqlalchemy import asc, delete, desc, func, select, Table, true, update
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.sql import Executable
+from sqlalchemy.sql.expression import ColumnElement, false
+
+from .exceptions import AlreadyExists, Conflict, DoesNotExist
+from .gateway import Filter, Gateway, Json
+from .pagination import PageOptions
+from .sql_provider import SQLDatabase, SQLProvider
+
+
+def _is_unique_violation_error_id(e: IntegrityError, id: int):
+    # sqlalchemy wraps the asyncpg error
+    msg = e.orig.args[0]
+    return ("duplicate key value violates unique constraint" in msg) and (
+        f"Key (id)=({id}) already exists." in msg
+    )
+
+
+T = TypeVar("T", bound="SQLGateway")
+
+
+class SQLGateway(Gateway):
+    table: Table
+    nested: bool
+
+    def __init__(
+        self, provider_override: Optional[SQLProvider] = None, nested: bool = False
+    ):
+        self.provider_override = provider_override
+        self.nested = nested
+
+    @property
+    def provider(self):
+        return self.provider_override or inject.instance(SQLDatabase)
+
+    def __init_subclass__(cls, table: Table) -> None:
+        cls.table = table
+        super().__init_subclass__()
+
+    def rows_to_dict(self, rows: List[Json]) -> List[Json]:
+        return rows
+
+    def dict_to_row(self, obj: Json) -> Json:
+        known = {c.key for c in self.table.c}
+        result = {k: obj[k] for k in obj.keys() if k in known}
+        if "id" in result and result["id"] is None:
+            del result["id"]
+        return result
+
+    @asynccontextmanager
+    async def transaction(self: T) -> AsyncIterator[T]:
+        if self.nested:
+            yield self
+        else:
+            async with self.provider.transaction() as provider:
+                yield self.__class__(provider, nested=True)
+
+    async def get_related(self, items: List[Json]) -> None:
+        pass
+
+    async def set_related(self, item: Json, result: Json) -> None:
+        pass
+
+    async def execute(self, query: Executable) -> List[Json]:
+        assert self.nested
+        return self.rows_to_dict(await self.provider.execute(query))
+
+    async def add(self, item: Json) -> Json:
+        query = (
+            insert(self.table).values(**self.dict_to_row(item)).returning(self.table)
+        )
+        async with self.transaction() as transaction:
+            try:
+                (result,) = await transaction.execute(query)
+            except IntegrityError as e:
+                id_ = item.get("id")
+                if id_ is not None and _is_unique_violation_error_id(e, id_):
+                    raise AlreadyExists(id_)
+                raise
+            await transaction.set_related(item, result)
+        return result
+
+    async def update(
+        self, item: Json, if_unmodified_since: Optional[datetime] = None
+    ) -> Json:
+        id_ = item.get("id")
+        if id_ is None:
+            raise DoesNotExist("record", id_)
+        q = self.table.c.id == id_
+        if if_unmodified_since is not None:
+            q &= self.table.c.updated_at == if_unmodified_since
+        query = (
+            update(self.table)
+            .where(q)
+            .values(**self.dict_to_row(item))
+            .returning(self.table)
+        )
+        async with self.transaction() as transaction:
+            result = await transaction.execute(query)
+            if not result:
+                if if_unmodified_since is not None:
+                    # note: the get() is to maybe raise DoesNotExist
+                    if await self.get(id_):
+                        raise Conflict()
+                raise DoesNotExist("record", id_)
+            await transaction.set_related(item, result[0])
+        return result[0]
+
+    async def _select_for_update(self, id: int) -> Json:
+        async with self.transaction() as transaction:
+            result = await transaction.execute(
+                select(self.table).with_for_update().where(self.table.c.id == id),
+            )
+            if not result:
+                raise DoesNotExist("record", id)
+            await transaction.get_related(result)
+        return result[0]
+
+    async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
+        async with self.transaction() as transaction:
+            existing = await transaction._select_for_update(id)
+            updated = func(existing)
+            return await transaction.update(updated)
+
+    async def upsert(self, item: Json) -> Json:
+        if item.get("id") is None:
+            return await self.add(item)
+        values = self.dict_to_row(item)
+        query = (
+            insert(self.table)
+            .values(**values)
+            .on_conflict_do_update(index_elements=["id"], set_=values)
+            .returning(self.table)
+        )
+        async with self.transaction() as transaction:
+            result = await transaction.execute(query)
+            await transaction.set_related(item, result[0])
+        return result[0]
+
+    async def remove(self, id) -> bool:
+        query = (
+            delete(self.table).where(self.table.c.id == id).returning(self.table.c.id)
+        )
+        async with self.transaction() as transaction:
+            result = await transaction.execute(query)
+        return bool(result)
+
+    def _to_sqlalchemy_expression(self, filter: Filter) -> ColumnElement:
+        try:
+            column = getattr(self.table.c, filter.field)
+        except AttributeError:
+            return false()
+        if len(filter.values) == 0:
+            return false()
+        elif len(filter.values) == 1:
+            return column == filter.values[0]
+        else:
+            return column.in_(filter.values)
+
+    async def filter(
+        self, filters: List[Filter], params: Optional[PageOptions] = None
+    ) -> List[Json]:
+        query = select(self.table).where(
+            *[self._to_sqlalchemy_expression(x) for x in filters]
+        )
+        if params is not None:
+            sort = asc(params.order_by) if params.ascending else desc(params.order_by)
+            query = query.order_by(sort).limit(params.limit).offset(params.offset)
+        async with self.transaction() as transaction:
+            result = await transaction.execute(query)
+            await transaction.get_related(result)
+        return result
+
+    async def count(self, filters: List[Filter]) -> int:
+        query = (
+            select(func.count().label("count"))
+            .select_from(self.table)
+            .where(*[self._to_sqlalchemy_expression(x) for x in filters])
+        )
+        async with self.transaction() as transaction:
+            return (await transaction.execute(query))[0]["count"]
+
+    async def exists(self, filters: List[Filter]) -> bool:
+        query = (
+            select(true().label("exists"))
+            .select_from(self.table)
+            .where(*[self._to_sqlalchemy_expression(x) for x in filters])
+            .limit(1)
+        )
+        async with self.transaction() as transaction:
+            return len(await transaction.execute(query)) > 0
+
+    async def _get_related_one_to_many(
+        self,
+        items: List[Json],
+        field_name: str,
+        fk_name: str,
+    ) -> None:
+        """Fetch related objects for `items` and add them inplace.
+
+        The result is `items` having an additional field containing a list of related
+        objects which were retrieved from self in 1 SELECT query.
+
+        Args:
+            items: The items for which to fetch related objects. Changed inplace.
+            field_name: The key in item to put the fetched related objects into.
+            fk_name: The column name on the related object that refers to item["id"]
+
+        Example:
+            Writer has a one-to-many relation to books.
+
+            >>> writers = [{"id": 2, "name": "John Doe"}]
+            >>> _get_related_one_to_many(
+                items=writers,
+                related_gateway=BookSQLGateway,
+                field_name="books",
+                fk_name="writer_id",
+            )
+            >>> writers[0]
+            {
+                "id": 2,
+                "name": "John Doe",
+                "books": [
+                    {
+                        "id": 1",
+                        "title": "How to write an ORM",
+                        "writer_id": 2
+                    }
+                ]
+            }
+        """
+        for x in items:
+            x[field_name] = []
+        item_lut = {x["id"]: x for x in items}
+        related_objs = await self.filter(
+            [Filter(field=fk_name, values=list(item_lut.keys()))]
+        )
+        for related_obj in related_objs:
+            item_lut[related_obj[fk_name]][field_name].append(related_obj)
+
+    async def _set_related_one_to_many(
+        self,
+        item: Json,
+        result: Json,
+        field_name: str,
+        fk_name: str,
+    ) -> None:
+        """Set related objects for `item`
+
+        This method first fetches the current situation and then adds / updates / removes
+        where appropriate.
+
+        Args:
+            item: The item for which to set related objects.
+            result: The dictionary to put the resulting (added / updated) objects into
+            field_name: The key in result to put the (added / updated) related objects into.
+            fk_name: The column name on the related object that refers to item["id"]
+
+        Example:
+            Writer has a one-to-many relation to books.
+
+            >>> writer = {"id": 2, "name": "John Doe", "books": {"title": "Foo"}}
+            >>> _set_related_one_to_many(
+                item=writer,
+                result=writer,
+                related_gateway=BookSQLGateway,
+                field_name="books",
+                fk_name="writer_id",
+            )
+            >>> result
+            {
+                "id": 2,
+                "name": "John Doe",
+                "books": [
+                    {
+                        "id": 1",
+                        "title": "Foo",
+                        "writer_id": 2
+                    }
+                ]
+            }
+        """
+
+        # list existing related objects
+        existing_lut = {
+            x["id"]: x
+            for x in await self.filter([Filter(field=fk_name, values=[result["id"]])])
+        }
+
+        # add / update them where necessary
+        returned = []
+        for new_value in item.get(field_name, []):
+            new_value = {fk_name: result["id"], **new_value}
+            existing = existing_lut.pop(new_value.get("id"), None)
+            if existing is None:
+                returned.append(await self.add(new_value))
+            elif new_value == existing:
+                returned.append(existing)
+            else:
+                returned.append(await self.update(new_value))
+
+        result[field_name] = returned
+
+        # remove remaining
+        for to_remove in existing_lut:
+            assert await self.remove(to_remove)

+ 129 - 0
clean_python/sql_provider.py

@@ -0,0 +1,129 @@
+from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
+from typing import AsyncIterator, List
+from unittest import mock
+
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.exc import DBAPIError
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
+from sqlalchemy.sql import Executable
+
+from .exceptions import Conflict
+from .gateway import Json
+
+__all__ = ["SQLProvider", "SQLDatabase", "FakeSQLDatabase", "assert_query_equal"]
+
+
+def is_serialization_error(e: DBAPIError) -> bool:
+    return e.orig.args[0].startswith("<class 'asyncpg.exceptions.SerializationError'>")
+
+
+class SQLProvider(ABC):
+    @abstractmethod
+    async def execute(self, query: Executable) -> List[Json]:
+        pass
+
+    @asynccontextmanager
+    async def transaction(self) -> AsyncIterator["SQLProvider"]:
+        raise NotImplementedError()
+        yield
+
+
+class SQLDatabase(SQLProvider):
+    engine: AsyncEngine
+
+    def __init__(self, url: str, **kwargs):
+        kwargs.setdefault("isolation_level", "READ COMMITTED")
+        self.engine = create_async_engine(url, **kwargs)
+
+    async def dispose(self) -> None:
+        await self.engine.dispose()
+
+    def dispose_sync(self) -> None:
+        self.engine.sync_engine.dispose()
+
+    async def execute(self, query: Executable) -> List[Json]:
+        async with self.transaction() as transaction:
+            return await transaction.execute(query)
+
+    @asynccontextmanager
+    async def transaction(self) -> AsyncIterator[SQLProvider]:
+        async with self.engine.connect() as connection:
+            async with connection.begin():
+                yield SQLTransaction(connection)
+
+    @asynccontextmanager
+    async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
+        async with self.engine.connect() as connection:
+            async with connection.begin() as transaction:
+                yield SQLTestTransaction(connection)
+                await transaction.rollback()
+
+
+class SQLTransaction(SQLProvider):
+    def __init__(self, connection: AsyncConnection):
+        self.connection = connection
+
+    async def execute(self, query: Executable) -> List[Json]:
+        try:
+            result = await self.connection.execute(query)
+        except DBAPIError as e:
+            if is_serialization_error(e):
+                raise Conflict(str(e))
+            else:
+                raise e
+        # _asdict() is a documented method of a NamedTuple
+        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._asdict
+        return [x._asdict() for x in result.fetchall()]
+
+
+class SQLTestTransaction(SQLTransaction):
+    @asynccontextmanager
+    async def transaction(self) -> AsyncIterator[SQLProvider]:
+        async with self.connection.begin_nested():
+            yield self
+
+
+class FakeSQLDatabase(SQLProvider):
+    def __init__(self):
+        self.queries: List[List[Executable]] = []
+        self.result = mock.Mock(return_value=[])
+
+    async def execute(self, query: Executable) -> List[Json]:
+        self.queries.append([query])
+        return self.result()
+
+    @asynccontextmanager
+    async def transaction(self) -> AsyncIterator["SQLProvider"]:
+        x = FakeSQLTransaction(result=self.result)
+        self.queries.append(x.queries)
+        yield x
+
+
+class FakeSQLTransaction(SQLProvider):
+    def __init__(self, result: mock.Mock):
+        self.queries: List[Executable] = []
+        self.result = result
+
+    async def execute(self, query: Executable) -> List[Json]:
+        self.queries.append(query)
+        return self.result()
+
+
+def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):
+    """There are two ways of 'binding' parameters (for testing!):
+
+    literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)
+    literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.
+    """
+    assert isinstance(q, Executable)
+    compiled = q.compile(
+        compile_kwargs={"literal_binds": literal_binds},
+        dialect=postgresql.dialect(),
+    )
+    if not literal_binds:
+        actual = str(compiled) % compiled.params
+    else:
+        actual = str(compiled)
+    actual = actual.replace("\n", "").replace("  ", " ")
+    assert actual == expected

+ 34 - 0
clean_python/testing.py

@@ -0,0 +1,34 @@
+from contextlib import contextmanager
+from typing import Type
+from unittest import mock
+
+from .manage import Manage
+
+
+@contextmanager
+def mock_manage(manage_cls: Type[Manage], skip=()):
+    """Mock all 'manage_' properties of a Manage class"""
+    manager = manage_cls()
+
+    mocks = {}
+    for attr_name in dir(manage_cls):
+        if not attr_name.startswith("manage_") or attr_name in skip:
+            continue
+        other_manager = getattr(manager, attr_name)
+        if not isinstance(other_manager, Manage):
+            continue
+        mocks[attr_name] = mock.MagicMock(other_manager)
+
+    patchers = [
+        mock.patch.object(
+            manage_cls,
+            name,
+            new_callable=mock.PropertyMock(return_value=x),
+        )
+        for name, x in mocks.items()
+    ]
+    for p in patchers:
+        p.start()
+    yield
+    for p in patchers:
+        p.stop()

+ 108 - 0
clean_python/tests/test_async_actor.py

@@ -0,0 +1,108 @@
+import threading
+from asyncio import BaseEventLoop
+from unittest import mock
+
+import pytest
+
+from base_lib.async_actor import (
+    async_actor,
+    AsyncActor,
+    AsyncMiddleware,
+    EventLoopThread,
+)
+
+
+@pytest.fixture
+def started_thread():
+    thread = EventLoopThread()
+    thread.start()
+    yield thread
+    thread.join()
+
+
+def test_event_loop_thread_start():
+    try:
+        thread = EventLoopThread()
+        thread.start()
+        assert isinstance(thread.loop, BaseEventLoop)
+        assert thread.loop.is_running()
+    finally:
+        thread.join()
+
+
+def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread):
+    result = {}
+
+    async def get_thread_id():
+        result["thread_id"] = threading.get_ident()
+
+    started_thread.run_coroutine(get_thread_id())
+
+    # the coroutine executed in the event loop thread
+    assert result["thread_id"] == started_thread.ident
+
+
+def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread):
+    async def raise_error():
+        raise TypeError("bla")
+
+    coro = raise_error()
+
+    with pytest.raises(TypeError, match="bla"):
+        started_thread.run_coroutine(coro)
+
+
+@mock.patch.object(EventLoopThread, "start")
+@mock.patch.object(EventLoopThread, "run_coroutine")
+def test_async_middleware_before_worker_boot(
+    EventLoopThread_run_coroutine, EventLoopThread_start
+):
+    broker = mock.Mock()
+    worker = mock.Mock()
+    middleware = AsyncMiddleware()
+
+    middleware.before_worker_boot(broker, worker)
+
+    assert isinstance(middleware.event_loop_thread, EventLoopThread)
+
+    EventLoopThread_start.assert_called_once()
+
+    middleware.run_coroutine("foo")
+    EventLoopThread_run_coroutine.assert_called_once_with("foo")
+
+    # broker was patched with run_coroutine
+    broker.run_coroutine("bar")
+    EventLoopThread_run_coroutine.assert_called_with("bar")
+
+
+def test_async_middleware_after_worker_shutdown():
+    broker = mock.Mock()
+    broker.run_coroutine = lambda x: x
+    worker = mock.Mock()
+    event_loop_thread = mock.Mock()
+
+    middleware = AsyncMiddleware()
+    middleware.event_loop_thread = event_loop_thread
+    middleware.after_worker_shutdown(broker, worker)
+
+    event_loop_thread.join.assert_called_once()
+    assert middleware.event_loop_thread is None
+    assert not hasattr(broker, "run_coroutine")
+
+
+def test_async_actor():
+    broker = mock.Mock()
+    broker.actor_options = {"max_retries"}
+
+    @async_actor(broker=broker)
+    async def foo(*args, **kwargs):
+        pass
+
+    assert isinstance(foo, AsyncActor)
+
+    foo(2, a="b")
+
+    broker.run_coroutine.assert_called_once()
+
+    # no recursion errors here:
+    repr(foo)

+ 34 - 0
clean_python/tests/test_celery_rmq_broker.py

@@ -0,0 +1,34 @@
+from unittest import mock
+
+import pytest
+
+from base_lib import CeleryRmqBroker
+
+
+@pytest.fixture
+def celery_rmq_broker():
+    return CeleryRmqBroker("amqp://rmq:1234//", "some_queue", "host", False)
+
+
+@mock.patch("base_lib.celery_rmq_broker.pika.BlockingConnection")
+async def test_celery_rmq_broker(connection, celery_rmq_broker):
+    await celery_rmq_broker.add({"task": "some.task", "args": ["foo", 15]})
+
+    channel = connection().__enter__().channel()
+
+    _, call_kwargs = channel.basic_publish.call_args
+
+    assert call_kwargs["exchange"] == ""
+    assert call_kwargs["routing_key"] == "some_queue"
+    assert call_kwargs["body"] == '[["foo", 15], {}, null]'
+    task_id = call_kwargs["properties"].correlation_id
+
+    assert call_kwargs["properties"].headers["id"] == task_id
+    assert call_kwargs["properties"].headers["root_id"] == task_id
+    assert call_kwargs["properties"].headers["parent_id"] is None
+    assert call_kwargs["properties"].headers["group"] is None
+    assert call_kwargs["properties"].headers["lang"] == "py"
+    assert call_kwargs["properties"].headers["task"] == "some.task"
+    assert call_kwargs["properties"].headers["origin"] == "host"
+    assert call_kwargs["properties"].headers["argsrepr"] == '["foo", 15]'
+    assert call_kwargs["properties"].headers["kwargsrepr"] == "{}"

+ 83 - 0
clean_python/tests/test_dramatiq_task_logger.py

@@ -0,0 +1,83 @@
+import os
+from unittest import mock
+
+import pytest
+from dramatiq.errors import Retry
+from dramatiq.message import Message
+
+from base_lib.dramatiq_task_logger import DramatiqTaskLogger
+from base_lib.gateway import InMemoryGateway
+
+
+@pytest.fixture
+def in_memory_gateway():
+    return InMemoryGateway(data=[])
+
+
+@pytest.fixture
+def task_logger(in_memory_gateway):
+    return DramatiqTaskLogger(
+        hostname="host",
+        gateway_override=in_memory_gateway,
+    )
+
+
+@pytest.fixture
+def message():
+    return Message(
+        queue_name="default",
+        actor_name="my_task",
+        args=(1, 2),
+        kwargs={"foo": "bar"},
+        options={},
+        message_id="abc123",
+        message_timestamp=None,
+    )
+
+
+@pytest.fixture
+def expected():
+    return {
+        "id": 1,
+        "tag_suffix": "task_log",
+        "task_id": "abc123",
+        "name": "my_task",
+        "state": "SUCCESS",
+        "duration": 0,
+        "retries": 0,
+        "origin": f"host-{os.getpid()}",
+        "argsrepr": b"[1,2]",
+        "kwargsrepr": b'{"foo":"bar"}',
+        "result": None,
+    }
+
+
+@mock.patch("time.time", return_value=123)
+async def test_log_success(time, task_logger, in_memory_gateway, message, expected):
+    await task_logger.start()
+    await task_logger.stop(message)
+
+    assert in_memory_gateway.data[1] == expected
+
+
+@mock.patch("time.time", new=mock.Mock(return_value=123))
+async def test_log_fail(task_logger, in_memory_gateway, message, expected):
+    await task_logger.start()
+    await task_logger.stop(message, exception=ValueError("test"))
+
+    assert in_memory_gateway.data[1] == {
+        **expected,
+        "state": "FAILURE",
+        "result": None,
+    }
+
+
+@mock.patch("time.time", return_value=123)
+async def test_log_retry(time, task_logger, in_memory_gateway, message, expected):
+    await task_logger.start()
+    await task_logger.stop(message, exception=Retry("test"))
+
+    assert in_memory_gateway.data[1] == {
+        **expected,
+        "state": "RETRY",
+    }

+ 32 - 0
clean_python/tests/test_exceptions.py

@@ -0,0 +1,32 @@
+from pydantic import ValidationError
+
+from base_lib import ValueObject
+from base_lib.exceptions import BadRequest, DoesNotExist
+
+
+def test_bad_request_short_str():
+    e = BadRequest("bla bla bla")
+    assert str(e) == "validation error: bla bla bla"
+
+
+def test_does_not_exist_str():
+    e = DoesNotExist("raster", id=12)
+    assert str(e) == "does not exist: raster with id=12"
+
+
+def test_does_not_exist_no_id_str():
+    e = DoesNotExist("raster")
+    assert str(e) == "does not exist: raster"
+
+
+class Book(ValueObject):
+    title: str
+
+
+def test_bad_request_from_validation_error():
+    try:
+        Book()
+    except ValidationError as e:
+        err = BadRequest(e)
+
+    assert str(err) == "validation error: 'title' field required"

+ 152 - 0
clean_python/tests/test_fastapi_access_logger.py

@@ -0,0 +1,152 @@
+from unittest import mock
+
+import pytest
+from fastapi.routing import APIRoute
+from starlette.requests import Request
+from starlette.responses import JSONResponse, StreamingResponse
+
+from base_lib import FastAPIAccessLogger, InMemoryGateway
+
+
+@pytest.fixture
+def fastapi_access_logger():
+    return FastAPIAccessLogger(hostname="myhost", gateway_override=InMemoryGateway([]))
+
+
+@pytest.fixture
+def req():
+    # a copy-paste from a local session, with some values removed / shortened
+    scope = {
+        "type": "http",
+        "asgi": {"version": "3.0", "spec_version": "2.3"},
+        "http_version": "1.1",
+        "server": ("172.20.0.6", 80),
+        "client": ("172.20.0.1", 45584),
+        "scheme": "http",
+        "root_path": "/v1-beta",
+        "headers": [
+            (b"host", b"localhost:8000"),
+            (b"connection", b"keep-alive"),
+            (b"accept", b"application/json"),
+            (b"authorization", b"..."),
+            (b"user-agent", b"Mozilla/5.0 ..."),
+            (b"referer", b"http://localhost:8000/v1-beta/docs"),
+            (b"accept-encoding", b"gzip, deflate, br"),
+            (b"accept-language", b"en-US,en;q=0.9"),
+            (b"cookie", b"..."),
+        ],
+        "state": {},
+        "method": "GET",
+        "path": "/rasters",
+        "raw_path": b"/v1-beta/rasters",
+        "query_string": b"limit=50&offset=0&order_by=id",
+        "path_params": {},
+        "app_root_path": "",
+        "route": APIRoute(
+            endpoint=lambda x: x,
+            path="/rasters",
+            name="v1-beta/raster_list",
+            methods=["GET"],
+        ),
+    }
+    return Request(scope)
+
+
+@pytest.fixture
+def response():
+    return JSONResponse({"foo": "bar"})
+
+
+@pytest.fixture
+def call_next(response):
+    async def func(request):
+        return response
+
+    return func
+
+
+@mock.patch("time.time", return_value=0.0)
+async def test_logging(time, fastapi_access_logger, req, response, call_next):
+    await fastapi_access_logger(req, call_next)
+    assert len(fastapi_access_logger.gateway.data) == 0
+    await response.background()
+    (actual,) = fastapi_access_logger.gateway.data.values()
+    actual.pop("id")
+    assert actual == {
+        "tag_suffix": "access_log",
+        "remote_address": "172.20.0.1",
+        "method": "GET",
+        "path": "/v1-beta/rasters",
+        "portal": "localhost:8000",
+        "referer": "http://localhost:8000/v1-beta/docs",
+        "user_agent": "Mozilla/5.0 ...",
+        "query_params": "limit=50&offset=0&order_by=id",
+        "view_name": "v1-beta/raster_list",
+        "status": 200,
+        "content_type": "application/json",
+        "content_length": 13,
+        "time": "1970-01-01T00:00:00Z",
+        "request_time": 0.0,
+    }
+
+
+@pytest.fixture
+def req_minimal():
+    # https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
+    scope = {
+        "type": "http",
+        "asgi": {"version": "3.0"},
+        "http_version": "1.1",
+        "method": "GET",
+        "scheme": "http",
+        "path": "/",
+        "query_string": "",
+        "headers": [],
+    }
+    return Request(scope)
+
+
+@pytest.fixture
+def streaming_response():
+    async def numbers(minimum, maximum):
+        yield ("<html><body><ul>")
+        for number in range(minimum, maximum + 1):
+            yield "<li>%d</li>" % number
+        yield ("</ul></body></html>")
+
+    return StreamingResponse(numbers(1, 3), media_type="text/html")
+
+
+@pytest.fixture
+def call_next_streaming(streaming_response):
+    async def func(request):
+        return streaming_response
+
+    return func
+
+
+@mock.patch("time.time", return_value=0.0)
+async def test_logging_minimal(
+    time, fastapi_access_logger, req_minimal, streaming_response, call_next_streaming
+):
+    await fastapi_access_logger(req_minimal, call_next_streaming)
+    assert len(fastapi_access_logger.gateway.data) == 0
+    await streaming_response.background()
+    (actual,) = fastapi_access_logger.gateway.data.values()
+    actual.pop("id")
+    assert actual == {
+        "tag_suffix": "access_log",
+        "remote_address": None,
+        "method": "GET",
+        "path": "/",
+        "portal": "",
+        "referer": None,
+        "user_agent": None,
+        "query_params": "",
+        "view_name": None,
+        "status": 200,
+        "content_type": "text/html; charset=utf-8",
+        "content_length": None,
+        "time": "1970-01-01T00:00:00Z",
+        "request_time": 0.0,
+    }

+ 173 - 0
clean_python/tests/test_gateway.py

@@ -0,0 +1,173 @@
+from datetime import datetime, timezone
+from unittest import mock
+
+import pytest
+
+from base_lib import (
+    AlreadyExists,
+    Conflict,
+    DoesNotExist,
+    Filter,
+    InMemoryGateway,
+    PageOptions,
+)
+
+
+@pytest.fixture
+def in_memory_gateway():
+    return InMemoryGateway(
+        data=[
+            {"id": 1, "name": "a"},
+            {"id": 2, "name": "b"},
+            {"id": 3, "name": "c"},
+        ]
+    )
+
+
+async def test_get(in_memory_gateway):
+    actual = await in_memory_gateway.get(1)
+    assert actual == in_memory_gateway.data[1]
+
+
+async def test_get_none(in_memory_gateway):
+    actual = await in_memory_gateway.get(4)
+    assert actual is None
+
+
+async def test_add(in_memory_gateway):
+    record = {"id": 5, "name": "d"}
+    await in_memory_gateway.add(record)
+    assert in_memory_gateway.data[5] == record
+
+
+async def test_add_id_autoincrement(in_memory_gateway):
+    record = {"name": "d"}
+    await in_memory_gateway.add(record)
+    assert in_memory_gateway.data[4] == {"id": 4, "name": "d"}
+
+
+async def test_add_id_exists(in_memory_gateway):
+    with pytest.raises(AlreadyExists):
+        await in_memory_gateway.add({"id": 3})
+
+
+async def test_update(in_memory_gateway):
+    record = {"id": 3, "name": "d"}
+    await in_memory_gateway.update(record)
+    assert in_memory_gateway.data[3] == record
+
+
+async def test_update_no_id(in_memory_gateway):
+    with pytest.raises(DoesNotExist):
+        await in_memory_gateway.update({"no": "id"})
+
+
+async def test_update_does_not_exist(in_memory_gateway):
+    with pytest.raises(DoesNotExist):
+        await in_memory_gateway.update({"id": 4})
+
+
+async def test_upsert(in_memory_gateway):
+    record = {"id": 3, "name": "d"}
+    await in_memory_gateway.upsert(record)
+    assert in_memory_gateway.data[3] == record
+
+
+async def test_upsert_no_id(in_memory_gateway):
+    await in_memory_gateway.upsert({"name": "x"})
+    assert in_memory_gateway.data[4] == {"id": 4, "name": "x"}
+
+
+async def test_upsert_does_add(in_memory_gateway):
+    await in_memory_gateway.upsert({"id": 4, "name": "x"})
+    assert in_memory_gateway.data[4] == {"id": 4, "name": "x"}
+
+
+async def test_remove(in_memory_gateway):
+    assert await in_memory_gateway.remove(1)
+    assert 1 not in in_memory_gateway.data
+    assert len(in_memory_gateway.data) == 2
+
+
+async def test_remove_not_existing(in_memory_gateway):
+    assert not await in_memory_gateway.remove(4)
+    assert len(in_memory_gateway.data) == 3
+
+
+async def test_updated_if_unmodified_since(in_memory_gateway):
+    existing = {"id": 4, "name": "e", "updated_at": datetime.now(timezone.utc)}
+    new = {"id": 4, "name": "f", "updated_at": datetime.now(timezone.utc)}
+
+    await in_memory_gateway.add(existing)
+
+    await in_memory_gateway.update(new, if_unmodified_since=existing["updated_at"])
+    assert in_memory_gateway.data[4]["name"] == "f"
+
+
+@pytest.mark.parametrize(
+    "if_unmodified_since", [datetime.now(timezone.utc), datetime(2010, 1, 1)]
+)
+async def test_update_if_unmodified_since_not_ok(
+    in_memory_gateway, if_unmodified_since
+):
+    existing = {"id": 4, "name": "e", "updated_at": datetime.now(timezone.utc)}
+    new = {"id": 4, "name": "f", "updated_at": datetime.now(timezone.utc)}
+
+    await in_memory_gateway.add(existing)
+    with pytest.raises(Conflict):
+        await in_memory_gateway.update(new, if_unmodified_since=if_unmodified_since)
+
+
+async def test_filter_all(in_memory_gateway):
+    actual = await in_memory_gateway.filter([])
+    assert actual == sorted(in_memory_gateway.data.values(), key=lambda x: x["id"])
+
+
+async def test_filter_all_with_params(in_memory_gateway):
+    actual = await in_memory_gateway.filter(
+        [], params=PageOptions(limit=2, offset=1, order_by="id", ascending=False)
+    )
+    assert [x["id"] for x in actual] == [2, 1]
+
+
+async def test_filter(in_memory_gateway):
+    actual = await in_memory_gateway.filter([Filter(field="name", values=["b"])])
+    assert actual == [in_memory_gateway.data[2]]
+
+
+async def test_count_all(in_memory_gateway):
+    actual = await in_memory_gateway.count([])
+    assert actual == 3
+
+
+async def test_count_with_filter(in_memory_gateway):
+    actual = await in_memory_gateway.count([Filter(field="name", values=["b"])])
+    assert actual == 1
+
+
+@mock.patch.object(InMemoryGateway, "update")
+async def test_update_transactional(update):
+    record = {"id": 3, "name": "d", "updated_at": datetime(2010, 1, 1)}
+    gateway = InMemoryGateway([record])
+    await gateway.update_transactional(3, lambda x: {"name": x["name"] + "x"})
+
+    update.assert_awaited_once_with(
+        {"name": "dx"}, if_unmodified_since=datetime(2010, 1, 1)
+    )
+
+
+async def test_update_transactional_does_not_exist(in_memory_gateway):
+    with pytest.raises(DoesNotExist):
+        await in_memory_gateway.update_transactional(5, lambda x: x)
+
+
+async def test_exists_all(in_memory_gateway):
+    assert await in_memory_gateway.exists([])
+
+
+async def test_exists_with_filter(in_memory_gateway):
+    assert await in_memory_gateway.exists([Filter(field="name", values=["b"])])
+
+
+async def test_exists_with_filter_not(in_memory_gateway):
+    assert not await in_memory_gateway.exists([Filter(field="name", values=["bb"])])

+ 124 - 0
clean_python/tests/test_internal_gateway.py

@@ -0,0 +1,124 @@
+from typing import cast
+
+import pytest
+from pydantic import Field
+
+from base_lib import (
+    Filter,
+    InMemoryGateway,
+    InternalGateway,
+    Manage,
+    Repository,
+    RootEntity,
+    ValueObject,
+)
+
+
+# domain - other module
+class User(RootEntity):
+    name: str = Field(min_length=1)
+
+
+class UserRepository(Repository[User]):
+    pass
+
+
+# application - other module
+class ManageUser(Manage[User]):
+    def __init__(self):
+        self.repo = UserRepository(gateway=InMemoryGateway([]))
+
+
+# domain - this module
+class UserObj(ValueObject):
+    id: int
+    name: str
+
+
+# infrastructure - this module
+
+
+class UserGateway(InternalGateway[User, UserObj]):
+    def __init__(self, manage: ManageUser):
+        self._manage = manage
+
+    @property
+    def manage(self) -> ManageUser:
+        return self._manage
+
+    def _map(self, obj: User) -> UserObj:
+        return UserObj(id=cast(int, obj.id), name=obj.name)
+
+
+@pytest.fixture
+def internal_gateway():
+    return UserGateway(manage=ManageUser())
+
+
+async def test_get_not_existing(internal_gateway: UserGateway):
+    assert await internal_gateway.get(1) is None
+
+
+async def test_add(internal_gateway: UserGateway):
+    actual = await internal_gateway.add(UserObj(id=12, name="foo"))
+
+    assert actual == UserObj(id=12, name="foo")
+
+
+@pytest.fixture
+async def internal_gateway_with_record(internal_gateway):
+    await internal_gateway.add(UserObj(id=12, name="foo"))
+    return internal_gateway
+
+
+async def test_get(internal_gateway_with_record):
+    assert await internal_gateway_with_record.get(12) == UserObj(id=12, name="foo")
+
+
+async def test_filter(internal_gateway_with_record: UserGateway):
+    assert await internal_gateway_with_record.filter([]) == [UserObj(id=12, name="foo")]
+
+
+async def test_filter_2(internal_gateway_with_record: UserGateway):
+    assert (
+        await internal_gateway_with_record.filter([Filter(field="id", values=[1])])
+        == []
+    )
+
+
+async def test_remove(internal_gateway_with_record: UserGateway):
+    assert await internal_gateway_with_record.remove(12)
+
+    assert internal_gateway_with_record.manage.repo.gateway.data == {}
+
+
+async def test_remove_does_not_exist(internal_gateway: UserGateway):
+    assert not await internal_gateway.remove(12)
+
+
+async def test_add_bad_request(internal_gateway: UserGateway):
+    # a 'bad request' should be reraised as a ValueError; errors in gateways
+    # are an internal affair.
+    with pytest.raises(ValueError):
+        await internal_gateway.add(UserObj(id=12, name=""))
+
+
+async def test_count(internal_gateway_with_record: UserGateway):
+    assert await internal_gateway_with_record.count([]) == 1
+
+
+async def test_count_2(internal_gateway_with_record: UserGateway):
+    assert (
+        await internal_gateway_with_record.count([Filter(field="id", values=[1])]) == 0
+    )
+
+
+async def test_exists(internal_gateway_with_record: UserGateway):
+    assert await internal_gateway_with_record.exists([]) is True
+
+
+async def test_exists_2(internal_gateway_with_record: UserGateway):
+    assert (
+        await internal_gateway_with_record.exists([Filter(field="id", values=[1])])
+        is False
+    )

+ 95 - 0
clean_python/tests/test_manage.py

@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+from unittest import mock
+
+import pytest
+
+from base_lib import Filter, Manage, RootEntity
+
+
+class User(RootEntity):
+    name: str
+
+
+class ManageUser(Manage[User]):
+    def __init__(self):
+        self.repo = mock.AsyncMock()
+
+
+@pytest.fixture
+def manage_user():
+    return ManageUser()
+
+
+async def test_retrieve(manage_user):
+    result = await manage_user.retrieve(2)
+
+    manage_user.repo.get.assert_awaited_with(2)
+    assert result is manage_user.repo.get.return_value
+
+
+async def test_create(manage_user):
+    result = await manage_user.create({"name": "piet"})
+
+    manage_user.repo.add.assert_awaited_once_with({"name": "piet"})
+
+    assert result is manage_user.repo.add.return_value
+
+
+async def test_update(manage_user):
+    manage_user.repo.get.return_value = User.create(id=2, name="piet")
+
+    result = await manage_user.update(2, {"name": "jan"})
+
+    manage_user.repo.update.assert_awaited_once_with(2, {"name": "jan"})
+
+    assert result is manage_user.repo.update.return_value
+
+
+async def test_destroy(manage_user):
+    result = await manage_user.destroy(2)
+
+    manage_user.repo.remove.assert_awaited_with(2)
+    assert result is manage_user.repo.remove.return_value
+
+
+async def test_list(manage_user):
+    result = await manage_user.list()
+
+    manage_user.repo.all.assert_awaited_once()
+    assert result is manage_user.repo.all.return_value
+
+
+async def test_by(manage_user):
+    result = await manage_user.by("name", "piet")
+
+    manage_user.repo.by.assert_awaited_with("name", "piet", params=None)
+    assert result is manage_user.repo.by.return_value
+
+
+async def test_filter(manage_user):
+    filters = [Filter(field="x", values=[1])]
+    result = await manage_user.filter(filters)
+
+    manage_user.repo.filter.assert_awaited_once_with(filters, params=None)
+
+    assert result is manage_user.repo.filter.return_value
+
+
+async def test_count(manage_user):
+    filters = [Filter(field="x", values=[1])]
+    result = await manage_user.count(filters)
+
+    manage_user.repo.count.assert_awaited_once_with(filters)
+
+    assert result is manage_user.repo.count.return_value
+
+
+async def test_exists(manage_user):
+    filters = [Filter(field="x", values=[1])]
+    result = await manage_user.exists(filters)
+
+    manage_user.repo.exists.assert_awaited_once_with(filters)
+
+    assert result is manage_user.repo.exists.return_value

+ 125 - 0
clean_python/tests/test_oauth2.py

@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+
+import time
+from unittest import mock
+
+import jwt
+import pytest
+
+from base_lib import OAuth2AccessTokenVerifier, PermissionDenied, Unauthorized
+
+
+@pytest.fixture
+def settings():
+    return {
+        "issuer": "https://cognito-idp.region.amazonaws.com/region_abc123",
+        "resource_server_id": "localhost/",
+        "algorithms": ["RS256"],
+        "admin_users": ["foo"],
+    }
+
+
+@pytest.fixture
+def private_key():
+    # this key was generated especially for this test suite; is has no other applications
+    return {
+        "p": "_PgJBxrGEy8I5KvY_nDRT9loaBPqHHn0AUiTa92zBrAX0qA8ZhV66pUkX2JehU3efduel4FOK2xx-W31p7kCLoaGsMtfKAPYC33KptCH9YXkeMQHq1jWfcRgAVXpdXc7M4pQxO8Dh2BU8qhtAzhpbP4tUPoLIGcTUGd-1ieDkqE",  # NOQA
+        "kty": "RSA",
+        "q": "hT0USPCNN4o2PauND53ubh2G5uOHzY9mfPuEXZ1fiRihCe5Bng0K8Pzx5QpSAjUY2-FhHa8jK8ITERmwT3MQKJpmlm_1R8GnaNVPOj8BpAhDlMzgkVikEGj0Pd7x_wdSko7KscyG-ZVsMw_KiCZpC6hMiI60w9GG14MtXhRVWhM",  # NOQA
+        "d": "BNwTHorPcAMiDglxt5Ylz1jqQ67rYcnA0okvZxz0QPbLovuTM1WIaPIeGlqXNzB9NxXtZhHXtnhoSwPf2LxMmYWWgJLqhPQWRlqZhLhww0nGGUgk_b1gNnMQuuh2weLfPNUksddhDJHzW1pBiDQrhP0t064Pz_P8WtGUkBka5-Pb3pItaF_w4xDIhhTJS48kv5H-BrwK8Vlz-EofkmPgxXBvCwhVoXZihxEUVzc6X59e1UiymXr-3lbNeL-76Yb9JHJFjXh2o52v5eZDVT6ir-iUp7bBXTiZsFaBCUCfCjx3MiQkHNBNEV7Cr9DKvfGdK3r9IbkSAC1tiD4Y1oyZwQ",  # NOQA
+        "e": "AQAB",
+        "use": "sig",
+        "kid": "_Lfex-skFCKBZd0xMN5dZSAX7uoG6LMx3i2qHReqU0c",
+        "qi": "GNhYuNdxd4NyRhzreW72PWXzj2oIkm0rIHrcNW9bpqK1fxrsbiVUEVUly-cqpD_-AjFOyCWcKWQxHG7J8LeP2vW3_U4TLx_jKD9cc7S65gb37El1ihOwNWbapRxToOhP2sZa0g3y9P-M_8hQcfKr1OFMQMnD9wj-sVNw9yJf3I4",  # NOQA
+        "dp": "xTs6BrEISEK-w1N9Dvy1JXWToroMKQGojiugzVQAVjGLkWvfS5RpzmZUAo52taZ911EZOHTXlqGpx1jFVGy5176JW2RlH5THqEX-b8tchcBL3yCv_hd4vHwUglYSfMRmgwvPZ4wXC0C_WqaYwA8Gm7UdbepWLIBRHbpjuOL8AaE",  # NOQA
+        "alg": "RS256",
+        "dq": "C4_UTcwKBRLKSCm10PAce5O2XBzMcQsLkrbkspbwbl4jw0_Yg9WP6H-aogx2N1jSMmppWgETpT1vGCHJietrMIrNcip-914Xn-I6wMws4UYSTzxEFHjDq-TfpOrOxxmkkbEwZ6Ne5xOPUxMAuTXUEb3l_keb6g4pjFQGwM405d8",  # NOQA
+        "n": "g6k31kvFdTaCSxXhazC5JaVekYi836F0H_YLrDioQlwiegsGjUDYk5TM7z8iXwDIm0QZZgtoEBlEny8vXrt1WGMO8GGwnVNq0_ZAD3JYp-a_c0X7VM7I2Dze32zcy8mC4QhPedEbMVDzi1XrusGjNHWObkMKsLZ7RRlwdkgR4nRpzncou_2ZJLvc50C8tjd3juCpUMWXNsvDjoAenxoXs68SDK4h9QSjvaWaSHNRGYiYkGUvcL5rv3htbrHIUVAcBC9r0j5Ued1hBR9ND1KPxVJWnn8oRAxFrYIcQdaDFWnWdb5BY9pJQls9fHlt0PF9vXUm-GufWk0U8D4Lc8V78w",  # NOQA
+    }
+
+
+@pytest.fixture
+def public_key(private_key):
+    keys = ("alg", "e", "kid", "kty", "n", "use")
+    return {k: private_key[k] for k in keys}
+
+
+@pytest.fixture
+def patched_verifier(public_key, settings):
+    verifier = OAuth2AccessTokenVerifier(scope="all", **settings)
+    with mock.patch.object(verifier, "jwk_client") as jwk_client:
+        jwk_client.get_signing_key_from_jwt.return_value = jwt.PyJWK.from_dict(
+            public_key
+        )
+        yield verifier
+
+
+@pytest.fixture
+def token_generator(private_key, settings):
+    default_claims = {
+        "sub": "foo",
+        "iss": settings["issuer"],
+        "scope": f"{settings['resource_server_id']}all",
+        "token_use": "access",
+        "exp": int(time.time()) + 3600,
+        "iat": int(time.time()) - 3600,
+        "nbf": int(time.time()) - 3600,
+    }
+
+    def generate_token(**claim_overrides):
+        claims = {**default_claims, **claim_overrides}
+        claims = {k: v for (k, v) in claims.items() if v is not None}
+        return jwt.encode(
+            claims,
+            key=jwt.PyJWK.from_dict(private_key).key,
+            algorithm=private_key["alg"],
+            headers={"kid": private_key["kid"]},
+        )
+
+    return generate_token
+
+
+def test_verifier_ok(patched_verifier, token_generator):
+    token = token_generator()
+    verified_claims = patched_verifier(token)
+    assert verified_claims == jwt.decode(token, options={"verify_signature": False})
+
+    patched_verifier.jwk_client.get_signing_key_from_jwt.assert_called_once_with(token)
+
+
+def test_verifier_exp_leeway(patched_verifier, token_generator):
+    token = token_generator(exp=int(time.time()) - 60)
+    patched_verifier(token)
+
+
+def test_verifier_multiple_scopes(patched_verifier, token_generator, settings):
+    token = token_generator(scope=f"scope1 {settings['resource_server_id']}all scope3")
+    patched_verifier(token)
+
+
+@pytest.mark.parametrize(
+    "claim_overrides",
+    [
+        {"iss": "https://authserver"},
+        {"iss": None},
+        {"scope": "nothing"},
+        {"scope": None},
+        {"exp": int(time.time()) - 3600},
+        {"exp": None},
+        {"nbf": int(time.time()) + 3600},
+        {"token_use": "id"},
+        {"token_use": None},
+        {"sub": None},
+    ],
+)
+def test_verifier_bad(patched_verifier, token_generator, claim_overrides):
+    token = token_generator(**claim_overrides)
+    with pytest.raises(Unauthorized):
+        patched_verifier(token)
+
+
+def test_verifier_authorize(patched_verifier, token_generator):
+    token = token_generator(sub="bar")
+    with pytest.raises(PermissionDenied):
+        patched_verifier(token)

+ 179 - 0
clean_python/tests/test_repository.py

@@ -0,0 +1,179 @@
+from unittest import mock
+
+import pytest
+
+from base_lib import (
+    BadRequest,
+    DoesNotExist,
+    Filter,
+    InMemoryGateway,
+    Page,
+    PageOptions,
+    Repository,
+    RootEntity,
+)
+
+
+class User(RootEntity):
+    name: str
+
+
+@pytest.fixture
+def users():
+    return [
+        User.create(id=1, name="a"),
+        User.create(id=2, name="b"),
+        User.create(id=3, name="c"),
+    ]
+
+
+class UserRepository(Repository[User]):
+    pass
+
+
+@pytest.fixture
+def user_repository(users):
+    return UserRepository(gateway=InMemoryGateway(data=[x.dict() for x in users]))
+
+
+@pytest.fixture
+def page_options():
+    return PageOptions(limit=10, offset=0, order_by="id")
+
+
+def test_entity_attr(user_repository):
+    assert user_repository.entity is User
+
+
+async def test_get(user_repository):
+    actual = await user_repository.get(1)
+    assert actual.name == "a"
+
+
+async def test_get_does_not_exist(user_repository):
+    with pytest.raises(DoesNotExist):
+        await user_repository.get(4)
+
+
+@mock.patch.object(Repository, "filter")
+async def test_all(filter_m, user_repository, page_options):
+    filter_m.return_value = Page(total=0, items=[])
+    assert await user_repository.all(page_options) is filter_m.return_value
+
+    filter_m.assert_awaited_once_with([], params=page_options)
+
+
+async def test_add(user_repository):
+    actual = await user_repository.add(User.create(name="d"))
+    assert actual.name == "d"
+    assert user_repository.gateway.data[4] == actual.dict()
+
+
+async def test_add_json(user_repository):
+    actual = await user_repository.add({"name": "d"})
+    assert actual.name == "d"
+    assert user_repository.gateway.data[4] == actual.dict()
+
+
+async def test_add_json_validates(user_repository):
+    with pytest.raises(BadRequest):
+        await user_repository.add({"id": "d"})
+
+
+async def test_update(user_repository):
+    actual = await user_repository.update(id=2, values={"name": "d"})
+    assert actual.name == "d"
+    assert user_repository.gateway.data[2] == actual.dict()
+
+
+async def test_update_does_not_exist(user_repository):
+    with pytest.raises(DoesNotExist):
+        await user_repository.update(id=4, values={"name": "d"})
+
+
+async def test_update_validates(user_repository):
+    with pytest.raises(BadRequest):
+        await user_repository.update(id=2, values={"id": 6})
+
+
+async def test_remove(user_repository):
+    assert await user_repository.remove(2)
+    assert 2 not in user_repository.gateway.data
+
+
+async def test_remove_does_not_exist(user_repository):
+    assert not await user_repository.remove(4)
+
+
+async def test_upsert_updates(user_repository):
+    actual = await user_repository.upsert(User.create(id=2, name="d"))
+    assert actual.name == "d"
+    assert user_repository.gateway.data[2] == actual.dict()
+
+
+async def test_upsert_adds(user_repository):
+    actual = await user_repository.upsert(User.create(id=4, name="d"))
+    assert actual.name == "d"
+    assert user_repository.gateway.data[4] == actual.dict()
+
+
+@mock.patch.object(InMemoryGateway, "count")
+async def test_filter(count_m, user_repository, users):
+    actual = await user_repository.filter([Filter(field="name", values=["b"])])
+    assert actual == Page(total=1, items=[users[1]], limit=None, offest=None)
+    assert not count_m.called
+
+
+@mock.patch.object(InMemoryGateway, "count")
+async def test_filter_with_pagination(count_m, user_repository, users, page_options):
+    actual = await user_repository.filter(
+        [Filter(field="name", values=["b"])], page_options
+    )
+    assert actual == Page(
+        total=1, items=[users[1]], limit=page_options.limit, offset=page_options.offset
+    )
+    assert not count_m.called
+
+
+@pytest.mark.parametrize(
+    "page_options",
+    [
+        PageOptions(limit=3, offset=0, order_by="id"),
+        PageOptions(limit=10, offset=1, order_by="id"),
+    ],
+)
+@mock.patch.object(InMemoryGateway, "count")
+async def test_filter_with_pagination_calls_count(
+    count_m, user_repository, users, page_options
+):
+    count_m.return_value = 123
+    actual = await user_repository.filter([], page_options)
+    assert actual == Page(
+        total=count_m.return_value,
+        items=users[page_options.offset :],
+        limit=page_options.limit,
+        offset=page_options.offset,
+    )
+    assert count_m.called
+
+
+@mock.patch.object(Repository, "filter")
+async def test_by(filter_m, user_repository, page_options):
+    filter_m.return_value = Page(total=0, items=[])
+    assert await user_repository.by("name", "b", page_options) is filter_m.return_value
+
+    filter_m.assert_awaited_once_with(
+        [Filter(field="name", values=["b"])], params=page_options
+    )
+
+
+@mock.patch.object(InMemoryGateway, "count")
+async def test_count(gateway_count, user_repository):
+    assert await user_repository.count("foo") is gateway_count.return_value
+    gateway_count.assert_awaited_once_with("foo")
+
+
+@mock.patch.object(InMemoryGateway, "exists")
+async def test_exists(gateway_exists, user_repository):
+    assert await user_repository.exists("foo") is gateway_exists.return_value
+    gateway_exists.assert_awaited_once_with("foo")

+ 46 - 0
clean_python/tests/test_request_query.py

@@ -0,0 +1,46 @@
+from typing import Optional
+
+import pytest
+from pydantic import ValidationError
+
+from base_lib.gateway import Filter
+from base_lib.pagination import PageOptions
+from base_lib.request_query import RequestQuery
+
+
+class SomeQuery(RequestQuery):
+    foo: Optional[int] = None
+
+
+@pytest.mark.parametrize(
+    "query,expected",
+    [
+        (
+            RequestQuery(),
+            PageOptions(limit=50, offset=0, order_by="id", ascending=True),
+        ),
+        (
+            RequestQuery(limit=10, offset=20, order_by="-id"),
+            PageOptions(limit=10, offset=20, order_by="id", ascending=False),
+        ),
+    ],
+)
+def test_as_page_options(query, expected):
+    assert query.as_page_options() == expected
+
+
+@pytest.mark.parametrize(
+    "query,expected",
+    [
+        (SomeQuery(), []),
+        (SomeQuery(foo=None), []),
+        (SomeQuery(foo=3), [Filter(field="foo", values=[3])]),
+    ],
+)
+def test_filters(query, expected):
+    assert query.filters() == expected
+
+
+def test_validate_order_by():
+    with pytest.raises(ValidationError):
+        RequestQuery(limit=10, offset=0, order_by="foo")

+ 132 - 0
clean_python/tests/test_resource.py

@@ -0,0 +1,132 @@
+import pytest
+from fastapi.routing import APIRouter
+
+from base_lib.resource import APIVersion, get, Resource, Stability, v
+
+
+def test_subclass():
+    class Cls(Resource, version=v(1), name="foo"):
+        pass
+
+    for obj in (Cls, Cls()):
+        assert obj.name == "foo"
+        assert obj.version == v(1)
+
+
+def test_get_router_no_endpoints():
+    class Cls(Resource, version=v(1)):
+        pass
+
+    router = Cls().get_router(v(1))
+    assert isinstance(router, APIRouter)
+    assert len(router.routes) == 0
+
+
+def test_get_router_other_version():
+    class TestResource(Resource, version=v(1), name="testing"):
+        @get("/foo/{id}")
+        def get_test(self, id: int):
+            return "ok"
+
+    with pytest.raises(AssertionError):
+        TestResource().get_router(v(2))
+
+
+def test_get_router():
+    class TestResource(Resource, version=v(1), name="testing"):
+        @get("/foo/{id}")
+        def get_test(self, id: int):
+            return "ok"
+
+    resource = TestResource()
+
+    router = resource.get_router(v(1))
+
+    assert len(router.routes) == 1
+
+    route = router.routes[0]
+    assert route.path == "/foo/{id}"
+    assert route.operation_id == "get_test"
+    assert route.name == "v1/get_test"
+    assert route.tags == ["testing"]
+    assert route.methods == {"GET"}
+    # 'self' is missing from the parameters
+    assert list(route.param_convertors.keys()) == ["id"]
+
+
+def test_get_openapi_tag():
+    class Cls(Resource, version=v(1), name="foo"):
+        """Docstring"""
+
+    actual = Cls().get_openapi_tag()
+
+    assert actual.name == "foo"
+    assert actual.description == "Docstring"
+
+
+def test_v():
+    assert v(1, "alpha") == APIVersion(version=1, stability=Stability.ALPHA)
+
+
+@pytest.mark.parametrize(
+    "version,expected",
+    [(v(1, "beta"), "v1-beta"), (v(2), "v2"), (v(3, "alpha"), "v3-alpha")],
+)
+def test_api_version_prefix(version, expected):
+    assert version.prefix == expected
+
+
+def test_url_path_for():
+    class TestResource(Resource, version=v(1), name="testing"):
+        @get("/foo/{id}")
+        def get_test(self, id: int):
+            return "ok"
+
+    resource = TestResource()
+    router = resource.get_router(v(1))
+
+    assert router.url_path_for("v1/get_test", id=2) == "/foo/2"
+
+
+def test_with_version():
+    class TestResource(Resource, version=v(1), name="testing"):
+        """Foo"""
+
+        @get("/foo/{id}")
+        def get_test(self, id: int):
+            return "ok"
+
+    resource_cls = TestResource.with_version(v(1, "beta"))
+
+    assert resource_cls.version == v(1, "beta")
+    assert resource_cls.name == "testing"
+    assert resource_cls.__doc__ == "Foo"
+    assert resource_cls.__bases__ == (TestResource,)
+
+
+def test_get_less_stable():
+    class V1(Resource, version=v(1), name="testing"):
+        pass
+
+    class V1Beta(V1, version=v(1, "beta"), name="testing"):
+        pass
+
+    resources = {x.version: x() for x in [V1, V1Beta]}
+    assert resources[v(1)].get_less_stable(resources) is resources[v(1, "beta")]
+
+    v1_alpha = resources[v(1, "beta")].get_less_stable(resources)
+    assert v1_alpha.version == v(1, "alpha")
+    assert v1_alpha.__class__.__bases__ == (V1Beta,)
+
+
+def test_get_less_stable_no_subclass():
+    class V1(Resource, version=v(1), name="testing"):
+        pass
+
+    class V1Beta(Resource, version=v(1, "beta"), name="testing"):
+        pass
+
+    resources = {x.version: x() for x in [V1, V1Beta]}
+
+    with pytest.raises(RuntimeError):
+        resources[v(1)].get_less_stable(resources)

+ 73 - 0
clean_python/tests/test_root_entity.py

@@ -0,0 +1,73 @@
+from datetime import datetime, timezone
+from unittest import mock
+
+import pytest
+
+from base_lib import RootEntity
+
+SOME_DATETIME = datetime(2023, 1, 1, tzinfo=timezone.utc)
+
+
+class User(RootEntity):
+    name: str
+
+
+@pytest.fixture
+def user():
+    return User(
+        id=4,
+        name="jan",
+        created_at=datetime(2010, 1, 1, tzinfo=timezone.utc),
+        updated_at=datetime(2020, 1, 1, tzinfo=timezone.utc),
+    )
+
+
+@pytest.fixture
+def patched_now():
+    with mock.patch("base_lib.root_entity.now", return_value=SOME_DATETIME):
+        yield
+
+
+def test_create(patched_now):
+    obj = User.create(name="piet")
+
+    assert obj.id is None
+    assert obj.name == "piet"
+    assert obj.created_at == SOME_DATETIME
+    assert obj.updated_at == SOME_DATETIME
+
+
+def test_create_with_id():
+    obj = User.create(id=42, name="piet")
+
+    assert obj.id == 42
+
+
+def test_update(user, patched_now):
+    actual = user.update(name="piet")
+
+    assert actual is not user
+    assert actual.name == "piet"
+    assert actual.updated_at == SOME_DATETIME
+    assert actual.created_at == datetime(2010, 1, 1, tzinfo=timezone.utc)
+
+
+def test_update_including_id(user):
+    actual = user.update(id=4, name="piet")
+
+    assert actual is not user
+    assert actual.name == "piet"
+
+
+@pytest.mark.parametrize("new_id", [None, 42, "foo"])
+def test_update_with_wrong_id(user, new_id):
+    with pytest.raises(ValueError):
+        user.update(id=new_id, name="piet")
+
+
+@pytest.mark.parametrize("new_id", [None, 42])
+def test_update_give_id(new_id):
+    user_without_id = User.create(name="jan")
+    actual = user_without_id.update(id=new_id, name="piet")
+
+    assert actual.id == new_id

+ 0 - 10
clean_python/tests/test_scripts.py

@@ -1,10 +0,0 @@
-# -*- coding: utf-8 -*-
-"""Tests for script.py"""
-from clean_python import scripts
-
-
-def test_get_parser():
-    parser = scripts.get_parser()
-    # As a test, we just check one option. That's enough.
-    options = parser.parse_args()
-    assert options.verbose is False

+ 47 - 0
clean_python/tests/test_service.py

@@ -0,0 +1,47 @@
+import pytest
+
+from base_lib import Resource, Service, v
+
+
+class V1Foo(Resource, version=v(1), name="foo"):
+    pass
+
+
+class V1BetaFoo(V1Foo, version=v(1, "beta"), name="foo"):
+    pass
+
+
+class V1AlphaFoo(V1BetaFoo, version=v(1, "alpha"), name="foo"):
+    pass
+
+
+class V2AlphaFoo(Resource, version=v(2, "alpha"), name="foo"):
+    pass
+
+
+@pytest.mark.parametrize(
+    "resource_classes",
+    [
+        (V1AlphaFoo,),
+        (V1BetaFoo, V1AlphaFoo),
+        (V1Foo, V1BetaFoo, V1AlphaFoo),
+        (V1AlphaFoo, V2AlphaFoo),
+    ],
+)
+def test_service_init(resource_classes):
+    resources = [cls() for cls in resource_classes]
+    service = Service(*resources)
+    assert set(service.resources) == set(resources)
+
+
+@pytest.mark.parametrize(
+    "resource_classes,expected_versions",
+    [
+        ((V1BetaFoo,), {v(1, "beta"), v(1, "alpha")}),
+        ((V1Foo,), {v(1), v(1, "beta"), v(1, "alpha")}),
+    ],
+)
+def test_service_init_dynamic_gen(resource_classes, expected_versions):
+    resources = [cls() for cls in resource_classes]
+    service = Service(*resources)
+    assert set(x.version for x in service.resources) == expected_versions

+ 453 - 0
clean_python/tests/test_sql_gateway.py

@@ -0,0 +1,453 @@
+from datetime import datetime, timezone
+from unittest import mock
+
+import pytest
+from sqlalchemy import Column, DateTime, ForeignKey, Integer, MetaData, Table, Text
+
+from base_lib import (
+    assert_query_equal,
+    Conflict,
+    DoesNotExist,
+    FakeSQLDatabase,
+    Filter,
+    PageOptions,
+    SQLGateway,
+)
+
+writer = Table(
+    "writer",
+    MetaData(),
+    Column("id", Integer, primary_key=True, autoincrement=True),
+    Column("value", Text, nullable=False),
+    Column("updated_at", DateTime(timezone=True), nullable=False),
+)
+
+
+book = Table(
+    "book",
+    MetaData(),
+    Column("id", Integer, primary_key=True, autoincrement=True),
+    Column("title", Text, nullable=False),
+    Column(
+        "writer_id",
+        Integer,
+        ForeignKey("writer.id", ondelete="CASCADE", name="book_writer_id_fkey"),
+        nullable=False,
+    ),
+)
+
+
+ALL_FIELDS = "writer.id, writer.value, writer.updated_at"
+BOOK_FIELDS = "book.id, book.title, book.writer_id"
+
+
+class TstSQLGateway(SQLGateway, table=writer):
+    pass
+
+
+class TstRelatedSQLGateway(SQLGateway, table=book):
+    pass
+
+
+@pytest.fixture
+def sql_gateway():
+    return TstSQLGateway(FakeSQLDatabase())
+
+
+@pytest.fixture
+def related_sql_gateway():
+    return TstRelatedSQLGateway(FakeSQLDatabase())
+
+
+@pytest.mark.parametrize(
+    "filters,sql",
+    [
+        ([], ""),
+        ([Filter(field="value", values=[])], " WHERE false"),
+        ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
+        (
+            [Filter(field="value", values=["foo", "bar"])],
+            " WHERE writer.value IN ('foo', 'bar')",
+        ),
+        ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
+        (
+            [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
+            " WHERE writer.id = 1 AND writer.value = 'foo'",
+        ),
+    ],
+)
+async def test_filter(sql_gateway, filters, sql):
+    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
+    assert await sql_gateway.filter(filters) == [{"id": 2, "value": "foo"}]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT {ALL_FIELDS} FROM writer{sql}",
+    )
+
+
+@pytest.mark.parametrize(
+    "page_options,sql",
+    [
+        (None, ""),
+        (
+            PageOptions(limit=5, order_by="id"),
+            " ORDER BY writer.id ASC LIMIT 5 OFFSET 0",
+        ),
+        (
+            PageOptions(limit=5, offset=2, order_by="id", ascending=False),
+            " ORDER BY writer.id DESC LIMIT 5 OFFSET 2",
+        ),
+    ],
+)
+async def test_filter_with_pagination(sql_gateway, page_options, sql):
+    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
+    assert await sql_gateway.filter([], params=page_options) == [
+        {"id": 2, "value": "foo"}
+    ]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT {ALL_FIELDS} FROM writer{sql}",
+    )
+
+
+async def test_filter_with_pagination_and_filter(sql_gateway):
+    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
+    assert await sql_gateway.filter(
+        [Filter(field="value", values=["foo"])],
+        params=PageOptions(limit=5, order_by="id"),
+    ) == [{"id": 2, "value": "foo"}]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"SELECT {ALL_FIELDS} FROM writer "
+            f"WHERE writer.value = 'foo' "
+            f"ORDER BY writer.id ASC LIMIT 5 OFFSET 0"
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    "filters,sql",
+    [
+        ([], ""),
+        ([Filter(field="value", values=[])], " WHERE false"),
+        ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
+        (
+            [Filter(field="value", values=["foo", "bar"])],
+            " WHERE writer.value IN ('foo', 'bar')",
+        ),
+        ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
+        (
+            [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
+            " WHERE writer.id = 1 AND writer.value = 'foo'",
+        ),
+    ],
+)
+async def test_count(sql_gateway, filters, sql):
+    sql_gateway.provider.result.return_value = [{"count": 4}]
+    assert await sql_gateway.count(filters) == 4
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT count(*) AS count FROM writer{sql}",
+    )
+
+
+@mock.patch.object(SQLGateway, "filter")
+async def test_get(filter_m, sql_gateway):
+    filter_m.return_value = [{"id": 2, "value": "foo"}]
+    assert await sql_gateway.get(2) == filter_m.return_value[0]
+    assert len(sql_gateway.provider.queries) == 0
+    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
+
+
+@mock.patch.object(SQLGateway, "filter")
+async def test_get_does_not_exist(filter_m, sql_gateway):
+    filter_m.return_value = []
+    assert await sql_gateway.get(2) is None
+    assert len(sql_gateway.provider.queries) == 0
+    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
+
+
+@pytest.mark.parametrize(
+    "record,sql",
+    [
+        ({}, "DEFAULT VALUES"),
+        ({"value": "foo"}, "(value) VALUES ('foo')"),
+        ({"id": None, "value": "foo"}, "(value) VALUES ('foo')"),
+        ({"id": 2, "value": "foo"}, "(id, value) VALUES (2, 'foo')"),
+        ({"value": "foo", "nonexisting": 2}, "(value) VALUES ('foo')"),
+    ],
+)
+async def test_add(sql_gateway, record, sql):
+    records = [{"id": 2, "value": "foo"}]
+    sql_gateway.provider.result.return_value = records
+    assert await sql_gateway.add(record) == records[0]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (f"INSERT INTO writer {sql} RETURNING {ALL_FIELDS}"),
+    )
+
+
+@pytest.mark.parametrize(
+    "record,if_unmodified_since,sql",
+    [
+        (
+            {"id": 2, "value": "foo"},
+            None,
+            "SET id=2, value='foo' WHERE writer.id = 2",
+        ),
+        ({"id": 2, "other": "foo"}, None, "SET id=2 WHERE writer.id = 2"),
+        (
+            {"id": 2, "value": "foo"},
+            datetime(2010, 1, 1, tzinfo=timezone.utc),
+            (
+                "SET id=2, value='foo' WHERE writer.id = 2 "
+                "AND writer.updated_at = '2010-01-01 00:00:00+00:00'"
+            ),
+        ),
+    ],
+)
+async def test_update(sql_gateway, record, if_unmodified_since, sql):
+    records = [{"id": 2, "value": "foo"}]
+    sql_gateway.provider.result.return_value = records
+    assert await sql_gateway.update(record, if_unmodified_since) == records[0]
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (f"UPDATE writer {sql} RETURNING {ALL_FIELDS}"),
+    )
+
+
+async def test_update_does_not_exist(sql_gateway):
+    sql_gateway.provider.result.return_value = []
+    with pytest.raises(DoesNotExist):
+        await sql_gateway.update({"id": 2})
+    assert len(sql_gateway.provider.queries) == 1
+
+
+@mock.patch.object(SQLGateway, "get")
+async def test_update_if_unmodified_since_does_not_exist(get_m, sql_gateway):
+    get_m.return_value = None
+    sql_gateway.provider.result.return_value = []
+    with pytest.raises(DoesNotExist):
+        await sql_gateway.update(
+            {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
+        )
+    assert len(sql_gateway.provider.queries) == 1
+    get_m.assert_awaited_once_with(2)
+
+
+@mock.patch.object(SQLGateway, "get")
+async def test_update_if_unmodified_since_conflict(get_m, sql_gateway):
+    get_m.return_value = {"id": 2, "value": "foo"}
+    sql_gateway.provider.result.return_value = []
+    with pytest.raises(Conflict):
+        await sql_gateway.update(
+            {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
+        )
+    assert len(sql_gateway.provider.queries) == 1
+    get_m.assert_awaited_once_with(2)
+
+
+async def test_remove(sql_gateway):
+    sql_gateway.provider.result.return_value = [{"id": 2}]
+    assert (await sql_gateway.remove(2)) is True
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
+    )
+
+
+async def test_remove_does_not_exist(sql_gateway):
+    sql_gateway.provider.result.return_value = []
+    assert (await sql_gateway.remove(2)) is False
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
+    )
+
+
+async def test_upsert(sql_gateway):
+    record = {"id": 2, "value": "foo"}
+    sql_gateway.provider.result.return_value = [record]
+    assert await sql_gateway.upsert(record) == record
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        (
+            f"INSERT INTO writer (id, value) VALUES (2, 'foo') "
+            f"ON CONFLICT (id) DO UPDATE SET "
+            f"id = %(param_1)s, value = %(param_2)s "
+            f"RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+@mock.patch.object(SQLGateway, "add")
+async def test_upsert_no_id(add_m, sql_gateway):
+    add_m.return_value = {"id": 5, "value": "foo"}
+    assert await sql_gateway.upsert({"value": "foo"}) == add_m.return_value
+
+    add_m.assert_awaited_once_with({"value": "foo"})
+    assert len(sql_gateway.provider.queries) == 0
+
+
+async def test_get_related_one_to_many(related_sql_gateway: SQLGateway):
+    writers = [{"id": 2}, {"id": 3}]
+    books = [
+        {"id": 3, "title": "x", "writer_id": 2},
+        {"id": 4, "title": "y", "writer_id": 2},
+    ]
+    related_sql_gateway.provider.result.return_value = books
+    await related_sql_gateway._get_related_one_to_many(
+        items=writers,
+        field_name="books",
+        fk_name="writer_id",
+    )
+
+    assert writers == [{"id": 2, "books": books}, {"id": 3, "books": []}]
+    assert len(related_sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        related_sql_gateway.provider.queries[0][0],
+        (
+            "SELECT book.id, book.title, book.writer_id FROM book WHERE book.writer_id IN (2, 3)"
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    "books,current_books,expected_queries,query_results",
+    [
+        # no change
+        (
+            [{"id": 3, "title": "x", "writer_id": 2}],
+            [{"id": 3, "title": "x", "writer_id": 2}],
+            [],
+            [],
+        ),
+        # added a book (without an id)
+        (
+            [{"title": "x", "writer_id": 2}],
+            [],
+            [
+                f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}"
+            ],
+            [[{"id": 3, "title": "x", "writer_id": 2}]],
+        ),
+        # added a book (with an id)
+        (
+            [{"id": 3, "title": "x", "writer_id": 2}],
+            [],
+            [
+                f"INSERT INTO book (id, title, writer_id) VALUES (3, 'x', 2) RETURNING {BOOK_FIELDS}"
+            ],
+            [[{"id": 3, "title": "x", "writer_id": 2}]],
+        ),
+        # updated a book
+        (
+            [{"id": 3, "title": "x", "writer_id": 2}],
+            [{"id": 3, "title": "a", "writer_id": 2}],
+            [
+                f"UPDATE book SET id=3, title='x', writer_id=2 WHERE book.id = 3 RETURNING {BOOK_FIELDS}"
+            ],
+            [[{"id": 3, "title": "x", "writer_id": 2}]],
+        ),
+        # replaced a book with a new one
+        (
+            [{"title": "x", "writer_id": 2}],
+            [{"id": 15, "title": "a", "writer_id": 2}],
+            [
+                f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}",
+                "DELETE FROM book WHERE book.id = 15 RETURNING book.id",
+            ],
+            [[{"id": 3, "title": "x", "writer_id": 2}], [{"id": 15}]],
+        ),
+    ],
+)
+async def test_set_related_one_to_many(
+    related_sql_gateway: SQLGateway,
+    books,
+    current_books,
+    expected_queries,
+    query_results,
+):
+    writer = {"id": 2, "books": books}
+    related_sql_gateway.provider.result.side_effect = [current_books] + query_results
+    result = writer.copy()
+    await related_sql_gateway._set_related_one_to_many(
+        item=writer,
+        result=result,
+        field_name="books",
+        fk_name="writer_id",
+    )
+
+    assert result == {
+        "id": 2,
+        "books": [{"id": 3, "title": "x", "writer_id": 2}],
+    }
+    assert len(related_sql_gateway.provider.queries) == len(expected_queries) + 1
+    assert_query_equal(
+        related_sql_gateway.provider.queries[0][0],
+        f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id = 2",
+    )
+    for (actual_query,), expected_query in zip(
+        related_sql_gateway.provider.queries[1:], expected_queries
+    ):
+        assert_query_equal(actual_query, expected_query)
+
+
+async def test_update_transactional(sql_gateway):
+    existing = {"id": 2, "value": "foo"}
+    expected = {"id": 2, "value": "bar"}
+    sql_gateway.provider.result.side_effect = ([existing], [expected])
+    actual = await sql_gateway.update_transactional(
+        2, lambda x: {"id": x["id"], "value": "bar"}
+    )
+    assert actual == expected
+
+    (queries,) = sql_gateway.provider.queries
+    assert len(queries) == 2
+    assert_query_equal(
+        queries[0],
+        f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 FOR UPDATE",
+    )
+    assert_query_equal(
+        queries[1],
+        (
+            f"UPDATE writer SET id=2, value='bar' WHERE writer.id = 2 RETURNING {ALL_FIELDS}"
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    "filters,sql",
+    [
+        ([], ""),
+        ([Filter(field="value", values=[])], " WHERE false"),
+        ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
+        (
+            [Filter(field="value", values=["foo", "bar"])],
+            " WHERE writer.value IN ('foo', 'bar')",
+        ),
+        ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
+        (
+            [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
+            " WHERE writer.id = 1 AND writer.value = 'foo'",
+        ),
+    ],
+)
+async def test_exists(sql_gateway, filters, sql):
+    sql_gateway.provider.result.return_value = [{"exists": True}]
+    assert await sql_gateway.exists(filters) is True
+    assert len(sql_gateway.provider.queries) == 1
+    assert_query_equal(
+        sql_gateway.provider.queries[0][0],
+        f"SELECT true AS exists FROM writer{sql} LIMIT 1",
+    )

+ 65 - 0
clean_python/tests/test_value_object.py

@@ -0,0 +1,65 @@
+import pytest
+from pydantic import ValidationError, validator
+
+from base_lib import BadRequest, ValueObject
+
+
+class Color(ValueObject):
+    name: str
+
+    @validator("name")
+    def name_not_empty(cls, v):
+        assert v != ""
+        return v
+
+
+@pytest.fixture
+def color():
+    return Color(name="green")
+
+
+def test_validator():
+    with pytest.raises(ValidationError) as e:
+        Color(name="")
+
+    assert e.type is ValidationError  # not BadRequest
+
+
+def test_create_err():
+    with pytest.raises(BadRequest):
+        Color.create(name="")
+
+
+def test_update(color):
+    updated = color.update(name="red")
+
+    assert color.name == "green"
+    assert updated.name == "red"
+
+
+def test_update_validates(color):
+    with pytest.raises(BadRequest):
+        color.update(name="")
+
+
+def test_run_validation(color):
+    assert color.run_validation() == color
+
+
+def test_run_validation_err():
+    color = Color.construct(name="")
+
+    with pytest.raises(BadRequest):
+        color.run_validation()
+
+
+def test_hashable(color):
+    assert len({color, color}) == 1
+
+
+def test_eq(color):
+    assert color == color
+
+
+def test_neq(color):
+    assert color != Color(name="red")

+ 12 - 0
clean_python/tmpdir_provider.py

@@ -0,0 +1,12 @@
+from tempfile import TemporaryDirectory
+from typing import Optional
+
+__all__ = ["TmpDirProvider"]
+
+
+class TmpDirProvider:
+    def __init__(self, dir: Optional[str] = None):
+        self.dir = dir
+
+    def __call__(self) -> TemporaryDirectory:
+        return TemporaryDirectory(dir=self.dir)

+ 8 - 0
clean_python/value.py

@@ -0,0 +1,8 @@
+class Value:
+    @classmethod
+    def __get_validators__(cls):
+        yield cls.validate
+
+    @classmethod
+    def validate(cls, v):
+        return cls(v)  # type: ignore

+ 46 - 0
clean_python/value_object.py

@@ -0,0 +1,46 @@
+from typing import Optional, Type, TypeVar
+
+from pydantic import BaseModel, ValidationError
+
+from .exceptions import BadRequest
+
+T = TypeVar("T", bound="ValueObject")
+
+
+class ValueObject(BaseModel):
+    class Config:
+        allow_mutation = False
+
+    def run_validation(self: T) -> T:
+        try:
+            return self.__class__(**self.dict())
+        except ValidationError as e:
+            raise BadRequest(e)
+
+    @classmethod
+    def create(cls: Type[T], **values) -> T:
+        try:
+            return cls(**values)
+        except ValidationError as e:
+            raise BadRequest(e)
+
+    def update(self: T, **values) -> T:
+        try:
+            return self.__class__(**{**self.dict(), **values})
+        except ValidationError as e:
+            raise BadRequest(e)
+
+    def __hash__(self):
+        return hash(self.__class__) + hash(tuple(self.__dict__.values()))
+
+
+K = TypeVar("K", bound="ValueObjectWithId")
+
+
+class ValueObjectWithId(ValueObject):
+    id: Optional[int] = None
+
+    def update(self: K, **values) -> K:
+        if "id" in values and self.id is not None and values["id"] != self.id:
+            raise ValueError("Cannot change the id")
+        return super().update(**values)

+ 1 - 1
pyproject.toml

@@ -10,7 +10,7 @@ license = {text = "MIT"}
 classifiers = ["Programming Language :: Python"]
 keywords = []
 requires-python = ">=3.7"
-dependencies = []
+dependencies = ["pydantic==1.*"]
 dynamic = ["version"]
 
 [project.optional-dependencies]