|
@@ -1,11 +1,13 @@
|
|
from contextvars import copy_context
|
|
from contextvars import copy_context
|
|
from typing import Optional
|
|
from typing import Optional
|
|
|
|
+from typing import Tuple
|
|
from uuid import UUID
|
|
from uuid import UUID
|
|
from uuid import uuid4
|
|
from uuid import uuid4
|
|
|
|
|
|
from celery import Task
|
|
from celery import Task
|
|
|
|
|
|
from clean_python import ctx
|
|
from clean_python import ctx
|
|
|
|
+from clean_python import Json
|
|
from clean_python import Tenant
|
|
from clean_python import Tenant
|
|
from clean_python import ValueObject
|
|
from clean_python import ValueObject
|
|
|
|
|
|
@@ -20,35 +22,31 @@ class TaskHeaders(ValueObject):
|
|
correlation_id: Optional[UUID]
|
|
correlation_id: Optional[UUID]
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
- def from_celery_request(cls, request) -> "TaskHeaders":
|
|
|
|
- if request.headers and HEADER_FIELD in request.headers:
|
|
|
|
- return TaskHeaders(**request.headers[HEADER_FIELD])
|
|
|
|
|
|
+ def from_kwargs(cls, kwargs: Json) -> Tuple["TaskHeaders", Json]:
|
|
|
|
+ if HEADER_FIELD in kwargs:
|
|
|
|
+ kwargs = kwargs.copy()
|
|
|
|
+ headers = kwargs.pop(HEADER_FIELD)
|
|
|
|
+ return TaskHeaders(**headers), kwargs
|
|
else:
|
|
else:
|
|
- return TaskHeaders(tenant=None, correlation_id=None)
|
|
|
|
|
|
+ return TaskHeaders(tenant=None, correlation_id=None), kwargs
|
|
|
|
|
|
|
|
|
|
class BaseTask(Task):
|
|
class BaseTask(Task):
|
|
def apply_async(self, args=None, kwargs=None, **options):
|
|
def apply_async(self, args=None, kwargs=None, **options):
|
|
- # include correlation_id and tenant in the headers
|
|
|
|
- if "headers" in options:
|
|
|
|
- headers = options.pop("headers")
|
|
|
|
- if headers is None:
|
|
|
|
- headers = {}
|
|
|
|
- else:
|
|
|
|
- headers = headers.copy()
|
|
|
|
- else:
|
|
|
|
- headers = {}
|
|
|
|
- if HEADER_FIELD not in headers:
|
|
|
|
- headers[HEADER_FIELD] = TaskHeaders(
|
|
|
|
- tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
|
|
|
|
- ).model_dump(mode="json")
|
|
|
|
- return super().apply_async(args, kwargs, headers=headers, **options)
|
|
|
|
|
|
+ # include correlation_id and tenant in the kwargs
|
|
|
|
+ # and NOT the headers as that is buggy in celery
|
|
|
|
+ # see https://github.com/celery/celery/issues/4875
|
|
|
|
+ kwargs = {} if kwargs is None else kwargs.copy()
|
|
|
|
+ kwargs[HEADER_FIELD] = TaskHeaders(
|
|
|
|
+ tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
|
|
|
|
+ ).model_dump(mode="json")
|
|
|
|
+ return super().apply_async(args, kwargs, **options)
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
def __call__(self, *args, **kwargs):
|
|
return copy_context().run(self._call_with_context, *args, **kwargs)
|
|
return copy_context().run(self._call_with_context, *args, **kwargs)
|
|
|
|
|
|
def _call_with_context(self, *args, **kwargs):
|
|
def _call_with_context(self, *args, **kwargs):
|
|
- headers = TaskHeaders.from_celery_request(self.request)
|
|
|
|
|
|
+ headers, kwargs = TaskHeaders.from_kwargs(kwargs)
|
|
ctx.tenant = headers.tenant
|
|
ctx.tenant = headers.tenant
|
|
ctx.correlation_id = headers.correlation_id
|
|
ctx.correlation_id = headers.correlation_id
|
|
return super().__call__(*args, **kwargs)
|
|
return super().__call__(*args, **kwargs)
|