"""Celery task base class that carries the clean_python context along.

The caller's ``ctx.tenant`` and ``ctx.correlation_id`` are stored under
``HEADER_FIELD`` in the task's keyword arguments when the task is sent, and
are restored onto ``ctx`` (and stripped from the kwargs) before the task body
runs.
"""

from contextvars import copy_context
from typing import Optional
from typing import Tuple
from uuid import UUID
from uuid import uuid4

from celery import Task

from clean_python import ctx
from clean_python import Json
from clean_python import Tenant
from clean_python import ValueObject

__all__ = ["BaseTask"]

# Reserved kwargs key under which the serialized context travels.
HEADER_FIELD = "clean_python_context"


class TaskHeaders(ValueObject):
    """Context values that are transported inside a task's kwargs."""

    tenant: Optional[Tenant]
    correlation_id: Optional[UUID]

    @classmethod
    def from_kwargs(cls, kwargs: Json) -> Tuple["TaskHeaders", Json]:
        """Split the context headers off a task's keyword arguments.

        Args:
            kwargs: The keyword arguments the task was called with; these may
                or may not contain an entry under ``HEADER_FIELD``.

        Returns:
            A ``(headers, kwargs)`` tuple. When ``HEADER_FIELD`` is present,
            ``headers`` is built from that entry and the returned kwargs are
            a copy without it (the input mapping is left untouched). When it
            is absent, both header fields are ``None`` and the input kwargs
            are returned as they are.
        """
        if HEADER_FIELD not in kwargs:
            return cls(tenant=None, correlation_id=None), kwargs

        # Pop from a copy so that the caller's mapping is not mutated.
        remaining = kwargs.copy()
        raw_headers = remaining.pop(HEADER_FIELD)
        # Construct through ``cls`` so that subclasses get their own type.
        return cls(**raw_headers), remaining


class BaseTask(Task):
    """Task that passes ``ctx.tenant`` / ``ctx.correlation_id`` via kwargs."""

    def apply_async(self, args=None, kwargs=None, **options):
        """Send the task with the current context added to its kwargs.

        The current ``ctx.tenant`` and ``ctx.correlation_id`` are stored
        under ``HEADER_FIELD``; if there is no correlation id (falsy), a new
        ``uuid4()`` is generated. The ``kwargs`` passed in are copied, not
        modified in place.

        Returns:
            Whatever ``Task.apply_async`` returns.
        """
        # include correlation_id and tenant in the kwargs
        # and NOT the headers as that is buggy in celery
        # see  https://github.com/celery/celery/issues/4875
        kwargs = {} if kwargs is None else kwargs.copy()
        headers = TaskHeaders(
            tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
        )
        # Dumped in "json" mode; presumably this yields plain JSON-compatible
        # values that survive task serialization -- confirm against
        # ValueObject.model_dump.
        kwargs[HEADER_FIELD] = headers.model_dump(mode="json")
        return super().apply_async(args, kwargs, **options)

    def __call__(self, *args, **kwargs):
        """Run the task body inside a copy of the current context.

        NOTE(review): assuming ``ctx`` is backed by ``contextvars``, the
        values set in ``_call_with_context`` then stay confined to the
        copied context instead of leaking into the caller's.
        """
        return copy_context().run(self._call_with_context, *args, **kwargs)

    def _call_with_context(self, *args, **kwargs):
        """Restore ``ctx`` from the kwargs, then run the actual task.

        The ``HEADER_FIELD`` entry (if any) is removed from ``kwargs`` before
        they are handed to ``Task.__call__``. If the entry is missing,
        ``ctx.tenant`` and ``ctx.correlation_id`` are both set to ``None``.
        """
        headers, kwargs = TaskHeaders.from_kwargs(kwargs)
        ctx.tenant = headers.tenant
        ctx.correlation_id = headers.correlation_id
        return super().__call__(*args, **kwargs)