Kaynağa Gözat

Fix celery header issues (#35)

Casper van der Wel 1 yıl önce
ebeveyn
işleme
a60940c86b

+ 2 - 0
CHANGES.md

@@ -8,6 +8,8 @@
 
 - Fix access logging of correlation id.
 
+- Workaround celery issues with message headers: use the body (kwargs) instead.
+
 
 0.8.1 (2023-11-06)
 ------------------

+ 17 - 19
clean_python/celery/base_task.py

@@ -1,11 +1,13 @@
 from contextvars import copy_context
 from typing import Optional
+from typing import Tuple
 from uuid import UUID
 from uuid import uuid4
 
 from celery import Task
 
 from clean_python import ctx
+from clean_python import Json
 from clean_python import Tenant
 from clean_python import ValueObject
 
@@ -20,35 +22,31 @@ class TaskHeaders(ValueObject):
     correlation_id: Optional[UUID]
 
     @classmethod
-    def from_celery_request(cls, request) -> "TaskHeaders":
-        if request.headers and HEADER_FIELD in request.headers:
-            return TaskHeaders(**request.headers[HEADER_FIELD])
+    def from_kwargs(cls, kwargs: Json) -> Tuple["TaskHeaders", Json]:
+        if HEADER_FIELD in kwargs:
+            kwargs = kwargs.copy()
+            headers = kwargs.pop(HEADER_FIELD)
+            return TaskHeaders(**headers), kwargs
         else:
-            return TaskHeaders(tenant=None, correlation_id=None)
+            return TaskHeaders(tenant=None, correlation_id=None), kwargs
 
 
 class BaseTask(Task):
     def apply_async(self, args=None, kwargs=None, **options):
-        # include correlation_id and tenant in the headers
-        if "headers" in options:
-            headers = options.pop("headers")
-            if headers is None:
-                headers = {}
-            else:
-                headers = headers.copy()
-        else:
-            headers = {}
-        if HEADER_FIELD not in headers:
-            headers[HEADER_FIELD] = TaskHeaders(
-                tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
-            ).model_dump(mode="json")
-        return super().apply_async(args, kwargs, headers=headers, **options)
+        # include correlation_id and tenant in the kwargs
+        # and NOT the headers as that is buggy in celery
+        # see  https://github.com/celery/celery/issues/4875
+        kwargs = {} if kwargs is None else kwargs.copy()
+        kwargs[HEADER_FIELD] = TaskHeaders(
+            tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
+        ).model_dump(mode="json")
+        return super().apply_async(args, kwargs, **options)
 
     def __call__(self, *args, **kwargs):
         return copy_context().run(self._call_with_context, *args, **kwargs)
 
     def _call_with_context(self, *args, **kwargs):
-        headers = TaskHeaders.from_celery_request(self.request)
+        headers, kwargs = TaskHeaders.from_kwargs(kwargs)
         ctx.tenant = headers.tenant
         ctx.correlation_id = headers.correlation_id
         return super().__call__(*args, **kwargs)

+ 28 - 4
clean_python/celery/celery_task_logger.py

@@ -63,11 +63,34 @@ class CeleryTaskLogger:
 
         try:
             request = task.request
-            correlation_id = TaskHeaders.from_celery_request(request).correlation_id
         except AttributeError:
             request = None
+
+        try:
+            headers, kwargs = TaskHeaders.from_kwargs(request.kwargs)
+        except AttributeError:
+            headers = kwargs = None  # type: ignore
+
+        try:
+            tenant_id = headers.tenant.id  # type: ignore
+        except AttributeError:
+            tenant_id = None
+
+        try:
+            correlation_id = headers.correlation_id
+        except AttributeError:
             correlation_id = None
 
+        try:
+            args = json.loads(json.dumps(request.args))
+        except (AttributeError, TypeError):
+            args = None
+
+        try:
+            kwargs = json.loads(json.dumps(kwargs))
+        except TypeError:
+            kwargs = None
+
         log_dict = {
             "tag_suffix": "task_log",
             "time": start_time,
@@ -77,10 +100,11 @@ class CeleryTaskLogger:
             "duration": duration,
             "origin": getattr(request, "origin", None),
             "retries": getattr(request, "retries", None),
-            "argsrepr": getattr(request, "argsrepr", None),
-            "kwargsrepr": getattr(request, "kwargsrepr", None),
+            "args": args,
+            "kwargs": kwargs,
             "result": result_json,
-            "correlation_id": str(correlation_id) if correlation_id else None,
+            "tenant_id": tenant_id,
+            "correlation_id": None if correlation_id is None else str(correlation_id),
         }
 
         return self.gateway.add(log_dict)

+ 11 - 8
integration_tests/test_int_celery.py

@@ -15,7 +15,7 @@ from clean_python.celery import set_task_logger
 
 @pytest.fixture(scope="session")
 def celery_parameters():
-    return {"task_cls": BaseTask}
+    return {"task_cls": BaseTask, "strict_typing": False}
 
 
 @pytest.fixture(scope="session")
@@ -41,7 +41,10 @@ def celery_task(celery_app, celery_worker):
         elif event == "retry":
             raise self.retry(countdown=seconds, max_retries=1)
         elif event == "context":
-            return {"tenant": ctx.tenant.id, "correlation_id": str(ctx.correlation_id)}
+            return {
+                "tenant_id": ctx.tenant.id,
+                "correlation_id": str(ctx.correlation_id),
+            }
         else:
             raise ValueError(f"Unknown event '{event}'")
 
@@ -71,11 +74,12 @@ def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger):
     assert log["state"] == "SUCCESS"
     assert log["name"] == "testing"
     assert log["duration"] > 0.0
-    assert log["argsrepr"] == "(0.0,)"
-    assert log["kwargsrepr"] == "{'return_value': 16}"
+    assert log["args"] == [0.0]
+    assert log["kwargs"] == {"return_value": 16}
     assert log["retries"] == 0
     assert log["result"] == {"value": 16}
     assert UUID(log["correlation_id"])  # generated
+    assert log["tenant_id"] is None
 
 
 def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger):
@@ -98,15 +102,14 @@ def custom_context():
     ctx.tenant = None
 
 
-def test_context(celery_task: BaseTask, task_logger: CeleryTaskLogger, custom_context):
+def test_context(celery_task: BaseTask, custom_context, task_logger):
     result = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0)
-    custom_context.correlation_id = None
-    custom_context.tenant = None
 
     assert result.get(timeout=10) == {
-        "tenant": 2,
+        "tenant_id": 2,
         "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
     }
 
     (log,) = task_logger.gateway.filter([])
     assert log["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"
+    assert log["tenant_id"] == 2

+ 18 - 35
tests/celery/test_celery_base_task.py

@@ -26,45 +26,28 @@ def temp_context():
     ctx.correlation_id = None
 
 
-def test_apply_async(mocked_apply_async):
-    BaseTask().apply_async(args="foo", kwargs="bar")
+@mock.patch(
+    "clean_python.celery.base_task.uuid4",
+    return_value=UUID("479156af-a302-48fc-89ed-8c426abadc4c"),
+)
+def test_apply_async(uuid4, mocked_apply_async):
+    BaseTask().apply_async(args=("foo",), kwargs={"a": "bar"})
 
     assert mocked_apply_async.call_count == 1
-    args, kwargs = mocked_apply_async.call_args
-    assert args == ("foo", "bar")
-    assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
-    UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"])  # generated
+    (args, kwargs), _ = mocked_apply_async.call_args
+    assert args == ("foo",)
+    assert kwargs["a"] == "bar"
+    assert kwargs[HEADER_FIELD] == {
+        "tenant": None,
+        "correlation_id": "479156af-a302-48fc-89ed-8c426abadc4c",
+    }
 
 
 def test_apply_async_with_context(mocked_apply_async, temp_context):
-    BaseTask().apply_async(args="foo", kwargs="bar")
+    BaseTask().apply_async(args=("foo",), kwargs={"a": "bar"})
 
     assert mocked_apply_async.call_count == 1
-    _, kwargs = mocked_apply_async.call_args
-    assert kwargs["headers"][HEADER_FIELD]["tenant"] == temp_context.tenant.model_dump(
-        mode="json"
-    )
-    kwargs["headers"][HEADER_FIELD]["correlation_id"] == str(
-        temp_context.correlation_id
-    )
-
-
-def test_apply_async_headers_extended(mocked_apply_async):
-    headers = {"baz": 2}
-    BaseTask().apply_async(args="foo", kwargs="bar", headers=headers)
-
-    assert mocked_apply_async.call_count == 1
-    _, kwargs = mocked_apply_async.call_args
-    assert kwargs["headers"]["baz"] == 2
-    assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
-    UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"])  # generated
-
-    assert headers == {"baz": 2}  # not changed inplace
-
-
-def test_apply_async_headers_already_present(mocked_apply_async):
-    BaseTask().apply_async(args="foo", kwargs="bar", headers={HEADER_FIELD: "foo"})
-
-    assert mocked_apply_async.call_count == 1
-    _, kwargs = mocked_apply_async.call_args
-    assert kwargs["headers"] == {HEADER_FIELD: "foo"}
+    (_, kwargs), _ = mocked_apply_async.call_args
+    assert kwargs["a"] == "bar"
+    assert kwargs[HEADER_FIELD]["tenant"] == temp_context.tenant.model_dump(mode="json")
+    assert kwargs[HEADER_FIELD]["correlation_id"] == str(temp_context.correlation_id)

+ 7 - 7
tests/celery/test_celery_task_logger.py

@@ -24,10 +24,11 @@ def test_log_minimal(celery_task_logger: CeleryTaskLogger):
         "state": "STAAT",
         "duration": None,
         "origin": None,
-        "argsrepr": None,
-        "kwargsrepr": None,
+        "args": None,
+        "kwargs": None,
         "result": None,
         "time": None,
+        "tenant_id": None,
         "correlation_id": None,
         "retries": None,
     }
@@ -52,9 +53,8 @@ def celery_task():
     request.id = "abc123"
     request.origin = "hostname"
     request.retries = 25
-    request.argsrepr = "[1, 2]"
-    request.kwargsrepr = "{}"
-    request.headers = {
+    request.args = [1, 2]
+    request.kwargs = {
         "clean_python_context": {
             "tenant": None,
             "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
@@ -73,8 +73,8 @@ def test_log_with_request(celery_task_logger: CeleryTaskLogger, celery_task):
     assert entry["name"] == "task_name"
     assert entry["task_id"] == "abc123"
     assert entry["retries"] == 25
-    assert entry["argsrepr"] == "[1, 2]"
-    assert entry["kwargsrepr"] == "{}"
+    assert entry["args"] == [1, 2]
+    assert entry["kwargs"] == {}
     assert entry["origin"] == "hostname"
     assert entry["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"