Browse Source

Fix BaseTask for retried tasks (#33)

Casper van der Wel 1 year ago
parent
commit
bb186cd387
3 changed files with 87 additions and 6 deletions
  1. 7 1
      CHANGES.md
  2. 10 5
      clean_python/celery/base_task.py
  3. 70 0
      tests/celery/test_celery_base_task.py

+ 7 - 1
CHANGES.md

@@ -1,7 +1,13 @@
 # Changelog of clean-python
 
 
-0.7.2 (unreleased)
+0.8.1 (unreleased)
+------------------
+
+- Fixed celery BaseTask for retried tasks.
+
+
+0.8.0 (2023-11-06)
 ------------------
 
 - Renamed clean_python.celery to clean_python.amqp; clean_python.celery now contains

+ 10 - 5
clean_python/celery/base_task.py

@@ -30,13 +30,18 @@ class TaskHeaders(ValueObject):
 class BaseTask(Task):
     def apply_async(self, args=None, kwargs=None, **options):
         # include correlation_id and tenant in the headers
-        if options.get("headers") is not None:
-            headers = options["headers"].copy()
+        if "headers" in options:
+            headers = options.pop("headers")
+            if headers is None:
+                headers = {}
+            else:
+                headers = headers.copy()
         else:
             headers = {}
-        headers[HEADER_FIELD] = TaskHeaders(
-            tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
-        ).model_dump(mode="json")
+        if HEADER_FIELD not in headers:
+            headers[HEADER_FIELD] = TaskHeaders(
+                tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
+            ).model_dump(mode="json")
         return super().apply_async(args, kwargs, headers=headers, **options)
 
     def __call__(self, *args, **kwargs):

+ 70 - 0
tests/celery/test_celery_base_task.py

@@ -0,0 +1,70 @@
+from unittest import mock
+from uuid import UUID
+from uuid import uuid4
+
+import pytest
+from celery import Task
+
+from clean_python import ctx
+from clean_python import Tenant
+from clean_python.celery import BaseTask
+from clean_python.celery.base_task import HEADER_FIELD
+
+
+@pytest.fixture
+def mocked_apply_async():
+    with mock.patch.object(Task, "apply_async") as m:
+        yield m
+
+
+@pytest.fixture
+def temp_context():
+    ctx.tenant = Tenant(id=2, name="test")
+    ctx.correlation_id = uuid4()
+    yield ctx
+    ctx.tenant = None
+    ctx.correlation_id = None
+
+
+def test_apply_async(mocked_apply_async):
+    BaseTask().apply_async(args="foo", kwargs="bar")
+
+    assert mocked_apply_async.call_count == 1
+    args, kwargs = mocked_apply_async.call_args
+    assert args == ("foo", "bar")
+    assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
+    UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"])  # generated
+
+
+def test_apply_async_with_context(mocked_apply_async, temp_context):
+    BaseTask().apply_async(args="foo", kwargs="bar")
+
+    assert mocked_apply_async.call_count == 1
+    _, kwargs = mocked_apply_async.call_args
+    assert kwargs["headers"][HEADER_FIELD]["tenant"] == temp_context.tenant.model_dump(
+        mode="json"
+    )
+    kwargs["headers"][HEADER_FIELD]["correlation_id"] == str(
+        temp_context.correlation_id
+    )
+
+
+def test_apply_async_headers_extended(mocked_apply_async):
+    headers = {"baz": 2}
+    BaseTask().apply_async(args="foo", kwargs="bar", headers=headers)
+
+    assert mocked_apply_async.call_count == 1
+    _, kwargs = mocked_apply_async.call_args
+    assert kwargs["headers"]["baz"] == 2
+    assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
+    UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"])  # generated
+
+    assert headers == {"baz": 2}  # not changed inplace
+
+
+def test_apply_async_headers_already_present(mocked_apply_async):
+    BaseTask().apply_async(args="foo", kwargs="bar", headers={HEADER_FIELD: "foo"})
+
+    assert mocked_apply_async.call_count == 1
+    _, kwargs = mocked_apply_async.call_args
+    assert kwargs["headers"] == {HEADER_FIELD: "foo"}