
Add celery logger (#32)

Casper van der Wel · 1 year ago
commit 3053a7030e

+ 1 - 1
.github/workflows/test.yml

@@ -56,7 +56,7 @@ jobs:
       - name: Install python dependencies
         run: |
           pip install --disable-pip-version-check --upgrade pip setuptools
-          pip install -e .[dramatiq,fastapi,auth,celery,fluentbit,sql,s3,api_client,test] ${{ matrix.pins }}
+          pip install -e .[dramatiq,fastapi,auth,celery,fluentbit,sql,s3,api_client,amqp,test] ${{ matrix.pins }}
           pip list
 
       - name: Run tests

+ 2 - 1
CHANGES.md

@@ -4,7 +4,8 @@
 0.7.2 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Renamed clean_python.celery to clean_python.amqp; clean_python.celery now contains
+  actual Celery abstractions.
 
 
 0.7.1 (2023-11-01)
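
For downstream users the rename boils down to an import change. A minimal before/after sketch, assuming code that used the pika-based broker:

    # before (clean-python <= 0.7.1)
    from clean_python.celery import CeleryRmqBroker

    # after (clean-python 0.7.2, unreleased)
    from clean_python.amqp import CeleryRmqBroker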

+ 1 - 0
clean_python/amqp/__init__.py

@@ -0,0 +1 @@
+from .celery_rmq_broker import *  # NOQA

+ 0 - 0
clean_python/celery/celery_rmq_broker.py → clean_python/amqp/celery_rmq_broker.py


+ 3 - 1
clean_python/celery/__init__.py

@@ -1 +1,3 @@
-from .celery_rmq_broker import *  # NOQA
+from .base_task import *  # NOQA
+from .celery_task_logger import *  # NOQA
+from .kubernetes import *  # NOQA

+ 49 - 0
clean_python/celery/base_task.py

@@ -0,0 +1,49 @@
+from contextvars import copy_context
+from typing import Optional
+from uuid import UUID
+from uuid import uuid4
+
+from celery import Task
+
+from clean_python import ctx
+from clean_python import Tenant
+from clean_python import ValueObject
+
+__all__ = ["BaseTask"]
+
+
+HEADER_FIELD = "clean_python_context"
+
+
+class TaskHeaders(ValueObject):
+    tenant: Optional[Tenant]
+    correlation_id: Optional[UUID]
+
+    @classmethod
+    def from_celery_request(cls, request) -> "TaskHeaders":
+        if request.headers and HEADER_FIELD in request.headers:
+            return TaskHeaders(**request.headers[HEADER_FIELD])
+        else:
+            return TaskHeaders(tenant=None, correlation_id=None)
+
+
+class BaseTask(Task):
+    def apply_async(self, args=None, kwargs=None, **options):
+        # include correlation_id and tenant in the headers
+        if options.get("headers") is not None:
+            headers = options.pop("headers").copy()  # pop: avoid duplicate kwarg below
+        else:
+            headers = {}
+        headers[HEADER_FIELD] = TaskHeaders(
+            tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
+        ).model_dump(mode="json")
+        return super().apply_async(args, kwargs, headers=headers, **options)
+
+    def __call__(self, *args, **kwargs):
+        return copy_context().run(self._call_with_context, *args, **kwargs)
+
+    def _call_with_context(self, *args, **kwargs):
+        headers = TaskHeaders.from_celery_request(self.request)
+        ctx.tenant = headers.tenant
+        ctx.correlation_id = headers.correlation_id
+        return super().__call__(*args, **kwargs)
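
To illustrate what this buys a producer and a worker, here is a hedged usage sketch; the app and task names are hypothetical, but the `task_cls` wiring mirrors the integration tests below:

    from celery import Celery

    from clean_python import ctx
    from clean_python import Tenant
    from clean_python.celery import BaseTask

    app = Celery("example", task_cls=BaseTask)

    @app.task(bind=True, base=BaseTask, name="whoami")
    def whoami(self):
        # inside the worker, BaseTask.__call__ has restored ctx from the
        # "clean_python_context" message header
        return {"tenant": ctx.tenant.id if ctx.tenant else None}

    # on the producer side, the current context travels along automatically:
    ctx.tenant = Tenant(id=7, name="acme")
    whoami.delay()  # apply_async adds the context to the message headers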

+ 141 - 0
clean_python/celery/celery_task_logger.py

@@ -0,0 +1,141 @@
+# (c) Nelen & Schuurmans
+
+import json
+import threading
+import time
+from typing import Any
+from typing import Optional
+
+import inject
+from billiard.einfo import ExceptionInfo
+from celery import Task
+from celery.signals import task_failure
+from celery.signals import task_postrun
+from celery.signals import task_prerun
+from celery.signals import task_revoked
+from celery.signals import task_success
+from celery.states import FAILURE
+from celery.states import RETRY
+from celery.states import REVOKED
+from celery.states import SUCCESS
+
+from clean_python import SyncGateway
+from clean_python.fluentbit import SyncFluentbitGateway
+
+from .base_task import TaskHeaders
+
+__all__ = ["CeleryTaskLogger", "set_task_logger"]
+
+
+class CeleryTaskLogger:
+    local = threading.local()
+
+    def __init__(self, gateway_override: Optional[SyncGateway] = None):
+        self.gateway_override = gateway_override
+
+    @property
+    def gateway(self) -> SyncGateway:
+        return self.gateway_override or inject.instance(SyncFluentbitGateway)
+
+    def start(self):
+        self.local.start_time = time.time()
+
+    def stop(self, task: Task, state: str, result: Any = None):
+        # format the result into a dict (elasticsearch will error otherwise)
+        if result is not None and not isinstance(result, dict):
+            result = {"result": result}
+        try:
+            result_json = json.loads(json.dumps(result))
+        except TypeError:
+            result_json = None
+
+        try:
+            start_time = self.local.start_time
+        except AttributeError:
+            start_time = None
+
+        self.local.start_time = None
+
+        if start_time is not None:
+            duration = time.time() - start_time
+        else:
+            duration = None
+
+        try:
+            request = task.request
+            correlation_id = TaskHeaders.from_celery_request(request).correlation_id
+        except AttributeError:
+            request = None
+            correlation_id = None
+
+        log_dict = {
+            "tag_suffix": "task_log",
+            "time": start_time,
+            "task_id": getattr(request, "id", None),
+            "name": task.name,
+            "state": state,
+            "duration": duration,
+            "origin": getattr(request, "origin", None),
+            "retries": getattr(request, "retries", None),
+            "argsrepr": getattr(request, "argsrepr", None),
+            "kwargsrepr": getattr(request, "kwargsrepr", None),
+            "result": result_json,
+            "correlation_id": str(correlation_id) if correlation_id else None,
+        }
+
+        return self.gateway.add(log_dict)
+
+
+celery_logger: Optional[CeleryTaskLogger] = None
+
+
+def set_task_logger(logger: Optional[CeleryTaskLogger]):
+    global celery_logger
+    celery_logger = logger
+
+
+@task_prerun.connect
+def task_prerun_log(**kwargs):
+    if celery_logger is None:
+        return
+    celery_logger.start()
+
+
+@task_postrun.connect
+def task_postrun_log(sender: Task, state: str, **kwargs):
+    if celery_logger is None:
+        return
+    if state not in {None, SUCCESS, FAILURE, RETRY}:
+        celery_logger.stop(task=sender, state=state)
+
+
+@task_success.connect
+def task_success_log(sender: Task, result: Any, **kwargs):
+    if celery_logger is None:
+        return
+    celery_logger.stop(task=sender, state=SUCCESS, result=result)
+
+
+@task_failure.connect
+def task_failure_log(sender: Task, einfo: ExceptionInfo, **kwargs):
+    if celery_logger is None:
+        return
+    celery_logger.stop(
+        task=sender, state=FAILURE, result={"traceback": einfo.traceback}
+    )
+
+
+@task_revoked.connect(dispatch_uid="task_revoked_log")
+def task_revoked_log(sender: Task, **kwargs):
+    if celery_logger is None:
+        return
+    if str(kwargs["signum"]) == "Signals.SIGTERM":
+        # This is to filter out duplicate logging on task termination.
+        return
+    if kwargs["terminated"]:
+        state = "TERMINATED"
+    elif kwargs["expired"]:
+        state = "EXPIRED"
+    else:
+        state = REVOKED
+    celery_logger.stop(task=sender, state=state)
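
Enabling the logger is opt-in via `set_task_logger`. A minimal sketch for tests or local development, using the in-memory gateway as in the test suite (in production one would instead configure `inject` so that `SyncFluentbitGateway` resolves):

    from clean_python import InMemorySyncGateway
    from clean_python.celery import CeleryTaskLogger
    from clean_python.celery import set_task_logger

    set_task_logger(CeleryTaskLogger(InMemorySyncGateway([])))
    # ... run tasks; every finished task now yields one "task_log" record
    set_task_logger(None)  # detach again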

+ 48 - 0
clean_python/celery/kubernetes.py

@@ -0,0 +1,48 @@
+from pathlib import Path
+
+from celery import bootsteps
+from celery.signals import worker_ready
+from celery.signals import worker_shutdown
+
+__all__ = ["setup_kubernetes_probes"]
+
+
+HEARTBEAT_FILE = Path("/dev/shm/worker_heartbeat")
+READINESS_FILE = Path("/dev/shm/worker_ready")
+
+
+def register_readiness(**_):
+    READINESS_FILE.touch()
+
+
+def unregister_readiness(**_):
+    READINESS_FILE.unlink(missing_ok=True)
+
+
+class LivenessProbe(bootsteps.StartStopStep):
+    requires = {"celery.worker.components:Timer"}
+
+    def __init__(self, worker, **kwargs):
+        self.requests = []
+        self.tref = None
+
+    def start(self, worker):
+        self.tref = worker.timer.call_repeatedly(
+            1.0,
+            self.update_heartbeat_file,
+            (worker,),
+            priority=10,
+        )
+
+    def stop(self, worker):
+        HEARTBEAT_FILE.unlink(missing_ok=True)
+
+    def update_heartbeat_file(self, worker):
+        HEARTBEAT_FILE.touch()
+
+
+def setup_kubernetes_probes(app):
+    worker_ready.connect(register_readiness)
+    worker_shutdown.connect(unregister_readiness)
+
+    app.steps["worker"].add(LivenessProbe)
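
Wiring this into an application is a single call; the probe commands in the comments are an assumption about the deployment, not part of this commit:

    from celery import Celery

    from clean_python.celery import setup_kubernetes_probes

    app = Celery("example")
    setup_kubernetes_probes(app)

    # Kubernetes would then check the files, e.g. a readiness probe running
    # ["test", "-e", "/dev/shm/worker_ready"] and a liveness probe verifying
    # that /dev/shm/worker_heartbeat was touched within the last few seconds.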

+ 112 - 0
integration_tests/test_int_celery.py

@@ -0,0 +1,112 @@
+import time
+from uuid import UUID
+
+import pytest
+from celery.exceptions import Ignore
+from celery.exceptions import Reject
+
+from clean_python import ctx
+from clean_python import InMemorySyncGateway
+from clean_python import Tenant
+from clean_python.celery import BaseTask
+from clean_python.celery import CeleryTaskLogger
+from clean_python.celery import set_task_logger
+
+
+@pytest.fixture(scope="session")
+def celery_parameters():
+    return {"task_cls": BaseTask}
+
+
+@pytest.fixture(scope="session")
+def celery_worker_parameters():
+    return {"shutdown_timeout": 10}
+
+
+@pytest.fixture
+def celery_task(celery_app, celery_worker):
+    @celery_app.task(bind=True, base=BaseTask, name="testing")
+    def sleep_task(self: BaseTask, seconds: float, return_value=None, event="success"):
+        event = event.lower()
+        if event == "success":
+            time.sleep(seconds)
+        elif event == "crash":
+            import ctypes
+
+            ctypes.string_at(0)  # segfault
+        elif event == "ignore":
+            raise Ignore()
+        elif event == "reject":
+            raise Reject()
+        elif event == "retry":
+            raise self.retry(countdown=seconds, max_retries=1)
+        elif event == "context":
+            return {"tenant": ctx.tenant.id, "correlation_id": str(ctx.correlation_id)}
+        else:
+            raise ValueError(f"Unknown event '{event}'")
+
+        return {"value": return_value}
+
+    celery_worker.reload()
+    return sleep_task
+
+
+@pytest.fixture
+def task_logger():
+    logger = CeleryTaskLogger(InMemorySyncGateway([]))
+    set_task_logger(logger)
+    yield logger
+    set_task_logger(None)
+
+
+def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger):
+    result = celery_task.delay(0.0, return_value=16)
+
+    assert result.get(timeout=10) == {"value": 16}
+
+    (log,) = task_logger.gateway.filter([])
+    assert 0.0 < (time.time() - log["time"]) < 1.0
+    assert log["tag_suffix"] == "task_log"
+    assert log["task_id"] == result.id
+    assert log["state"] == "SUCCESS"
+    assert log["name"] == "testing"
+    assert log["duration"] > 0.0
+    assert log["argsrepr"] == "(0.0,)"
+    assert log["kwargsrepr"] == "{'return_value': 16}"
+    assert log["retries"] == 0
+    assert log["result"] == {"value": 16}
+    assert UUID(log["correlation_id"])  # generated
+
+
+def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger):
+    result = celery_task.delay(0.0, event="failure")
+
+    with pytest.raises(ValueError):
+        assert result.get(timeout=10)
+
+    (log,) = task_logger.gateway.filter([])
+    assert log["state"] == "FAILURE"
+    assert log["result"]["traceback"].startswith("Traceback")
+
+
+@pytest.fixture
+def custom_context():
+    ctx.correlation_id = UUID("b3089ea7-2585-43e5-a63c-ae30a6e9b5e4")
+    ctx.tenant = Tenant(id=2, name="custom")
+    yield ctx
+    ctx.correlation_id = None
+    ctx.tenant = None
+
+
+def test_context(celery_task: BaseTask, task_logger: CeleryTaskLogger, custom_context):
+    result = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0)
+    custom_context.correlation_id = None
+    custom_context.tenant = None
+
+    assert result.get(timeout=10) == {
+        "tenant": 2,
+        "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
+    }
+
+    (log,) = task_logger.gateway.filter([])
+    assert log["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"

+ 4 - 2
pyproject.toml

@@ -21,12 +21,14 @@ test = [
     "debugpy",
     "httpx",
     "uvicorn",
-    "python-multipart"
+    "python-multipart",
+    "pytest-celery"
 ]
 dramatiq = ["dramatiq"]
 fastapi = ["fastapi"]
 auth = ["pyjwt[crypto]==2.6.0"]
-celery = ["pika"]
+amqp = ["pika"]
+celery = ["celery"]
 fluentbit = ["fluent-logger"]
 sql = ["sqlalchemy==2.*", "asyncpg"]
 s3 = ["aioboto3", "boto3"]

+ 2 - 2
tests/test_celery_rmq_broker.py → tests/amqp/test_celery_rmq_broker.py

@@ -2,7 +2,7 @@ from unittest import mock
 
 import pytest
 
-from clean_python.celery import CeleryRmqBroker
+from clean_python.amqp import CeleryRmqBroker
 
 
 @pytest.fixture
@@ -10,7 +10,7 @@ def celery_rmq_broker():
     return CeleryRmqBroker("amqp://rmq:1234//", "some_queue", "host", False)
 
 
-@mock.patch("clean_python.celery.celery_rmq_broker.pika.BlockingConnection")
+@mock.patch("clean_python.amqp.celery_rmq_broker.pika.BlockingConnection")
 async def test_celery_rmq_broker(connection, celery_rmq_broker):
     await celery_rmq_broker.add({"task": "some.task", "args": ["foo", 15]})
 

+ 97 - 0
tests/celery/test_celery_task_logger.py

@@ -0,0 +1,97 @@
+from unittest import mock
+from uuid import uuid4
+
+import pytest
+from celery import Task
+
+from clean_python import InMemorySyncGateway
+from clean_python.celery import CeleryTaskLogger
+
+
+@pytest.fixture
+def celery_task_logger() -> CeleryTaskLogger:
+    return CeleryTaskLogger(InMemorySyncGateway([]))
+
+
+def test_log_minimal(celery_task_logger: CeleryTaskLogger):
+    celery_task_logger.stop(Task(), "STAAT")
+    (entry,) = celery_task_logger.gateway.filter([])
+    assert entry == {
+        "id": 1,
+        "tag_suffix": "task_log",
+        "task_id": None,
+        "name": None,
+        "state": "STAAT",
+        "duration": None,
+        "origin": None,
+        "argsrepr": None,
+        "kwargsrepr": None,
+        "result": None,
+        "time": None,
+        "correlation_id": None,
+        "retries": None,
+    }
+
+
+def test_log_with_duration(celery_task_logger: CeleryTaskLogger):
+    with mock.patch("time.time", return_value=1.0):
+        celery_task_logger.start()
+
+    with mock.patch("time.time", return_value=100.0):
+        celery_task_logger.stop(Task(), "STAAT")
+
+    (entry,) = celery_task_logger.gateway.filter([])
+    assert entry["time"] == 1.0
+    assert entry["duration"] == 99.0
+
+
+@pytest.fixture
+def celery_task():
+    # it seems impossible to instantiate a true celery Task object...
+    request = mock.Mock()
+    request.id = "abc123"
+    request.origin = "hostname"
+    request.retries = 25
+    request.argsrepr = "[1, 2]"
+    request.kwargsrepr = "{}"
+    request.headers = {
+        "clean_python_context": {
+            "tenant": None,
+            "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
+        }
+    }
+    task = mock.Mock()
+    task.name = "task_name"
+    task.request = request
+    return task
+
+
+def test_log_with_request(celery_task_logger: CeleryTaskLogger, celery_task):
+    celery_task_logger.stop(celery_task, "STAAT")
+
+    (entry,) = celery_task_logger.gateway.filter([])
+    assert entry["name"] == "task_name"
+    assert entry["task_id"] == "abc123"
+    assert entry["retries"] == 25
+    assert entry["argsrepr"] == "[1, 2]"
+    assert entry["kwargsrepr"] == "{}"
+    assert entry["origin"] == "hostname"
+    assert entry["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"
+
+
+@pytest.mark.parametrize(
+    "result,expected",
+    [
+        ({"a": "b"}, {"a": "b"}),
+        ("str", {"result": "str"}),  # str to dict
+        ([1], {"result": [1]}),  # list to dict
+        ({"a": uuid4()}, None),  # not-json-serializable
+    ],
+)
+def test_log_with_result(
+    celery_task_logger: CeleryTaskLogger, celery_task, result, expected
+):
+    celery_task_logger.stop(celery_task, "STAAT", result=result)
+
+    (entry,) = celery_task_logger.gateway.filter([])
+    assert entry["result"] == expected

+ 1 - 1
tests/test_context.py

@@ -18,7 +18,7 @@ def test_default_context():
 async def test_task_isolation():
     async def get_set(user):
         ctx.user = user
-        asyncio.sleep(0.01)
+        await asyncio.sleep(0.01)
         assert ctx.user == user
 
     await asyncio.gather(*[get_set(User(id=str(i), name="piet")) for i in range(10)])