base_task.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from contextvars import copy_context
  2. from typing import Optional
  3. from typing import Tuple
  4. from uuid import UUID
  5. from uuid import uuid4
  6. from celery import Task
  7. from clean_python import ctx
  8. from clean_python import Json
  9. from clean_python import Tenant
  10. from clean_python import ValueObject
  # Names exported by ``from <module> import *``.
  __all__ = ["BaseTask"]

  # Key under which the serialized ``TaskHeaders`` (tenant and correlation id)
  # are stored inside a task's keyword arguments.
  HEADER_FIELD: str = "clean_python_context"
  class TaskHeaders(ValueObject):
      """Context values that travel with a Celery task invocation.

      ``BaseTask.apply_async`` stores ``model_dump(mode="json")`` of an instance
      in the task kwargs under ``HEADER_FIELD``. ``BaseTask`` strips it back out
      with ``from_kwargs`` when the task is called.
      """

      # Tenant that was active in ``ctx`` at dispatch time, or None.
      tenant: Optional[Tenant]
      # Correlation id that was active (or generated) at dispatch time, or None.
      correlation_id: Optional[UUID]

      @classmethod
      def from_kwargs(cls, kwargs: Json) -> Tuple["TaskHeaders", Json]:
          """Split the context headers off a task's keyword arguments.

          Args:
              kwargs: Keyword arguments as received by the task. Not mutated.

          Returns:
              A ``(headers, remaining_kwargs)`` pair. When ``HEADER_FIELD`` is
              absent, both header fields are ``None`` and ``kwargs`` is returned
              as-is. Otherwise ``remaining_kwargs`` is a copy without that key.

          Raises:
              NOTE(review): presumably the model's validation error if the
              stored headers do not match the fields, or ``TypeError`` if the
              stored value is not a mapping. Confirm against ``ValueObject``.
          """
          if HEADER_FIELD not in kwargs:
              return cls(tenant=None, correlation_id=None), kwargs
          # Copy before popping so the caller's dict is left untouched.
          remaining = kwargs.copy()
          stored = remaining.pop(HEADER_FIELD)
          # Use ``cls`` (not a hard-coded class name) so subclasses round-trip.
          return cls(**stored), remaining
  class BaseTask(Task):
      """Celery ``Task`` base class that carries the clean_python context.

      On dispatch, the current ``ctx.tenant`` and ``ctx.correlation_id`` are
      embedded in the task kwargs. On execution, they are stripped off again
      and assigned to ``ctx`` inside a copy of the current ``contextvars``
      context.
      """

      def apply_async(self, args=None, kwargs=None, **options):
          """Dispatch the task with the current context attached to its kwargs.

          A new correlation id is generated with ``uuid4()`` when ``ctx`` has
          none. The caller's ``kwargs`` is copied, never modified.
          """
          # Carry the context in the kwargs rather than in the message headers,
          # because header propagation is buggy in celery:
          # https://github.com/celery/celery/issues/4875
          if kwargs is None:
              outgoing = {}
          else:
              outgoing = kwargs.copy()
          context_headers = TaskHeaders(
              tenant=ctx.tenant,
              correlation_id=ctx.correlation_id or uuid4(),
          )
          outgoing[HEADER_FIELD] = context_headers.model_dump(mode="json")
          return super().apply_async(args, outgoing, **options)

      def __call__(self, *args, **kwargs):
          """Run the task body inside a copy of the current ``contextvars`` context."""
          # NOTE(review): presumably this keeps the ``ctx`` assignments made in
          # ``_call_with_context`` from leaking to the caller. That assumes
          # ``ctx`` is backed by context variables; confirm in clean_python.
          task_context = copy_context()
          return task_context.run(self._call_with_context, *args, **kwargs)

      def _call_with_context(self, *args, **kwargs):
          """Pop the context headers from ``kwargs``, apply them to ``ctx``, then run."""
          task_headers, task_kwargs = TaskHeaders.from_kwargs(kwargs)
          # When the kwargs carry no headers, both values are set to None.
          ctx.tenant = task_headers.tenant
          ctx.correlation_id = task_headers.correlation_id
          return super().__call__(*args, **task_kwargs)