base_task.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from contextvars import copy_context
  2. from typing import Optional
  3. from uuid import UUID
  4. from uuid import uuid4
  5. from celery import Task
  6. from clean_python import ctx
  7. from clean_python import Tenant
  8. from clean_python import ValueObject
  9. __all__ = ["BaseTask"]
  10. HEADER_FIELD = "clean_python_context"
  11. class TaskHeaders(ValueObject):
  12. tenant: Optional[Tenant]
  13. correlation_id: Optional[UUID]
  14. @classmethod
  15. def from_celery_request(cls, request) -> "TaskHeaders":
  16. if request.headers and HEADER_FIELD in request.headers:
  17. return TaskHeaders(**request.headers[HEADER_FIELD])
  18. else:
  19. return TaskHeaders(tenant=None, correlation_id=None)
  20. class BaseTask(Task):
  21. def apply_async(self, args=None, kwargs=None, **options):
  22. # include correlation_id and tenant in the headers
  23. if options.get("headers") is not None:
  24. headers = options["headers"].copy()
  25. else:
  26. headers = {}
  27. headers[HEADER_FIELD] = TaskHeaders(
  28. tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
  29. ).model_dump(mode="json")
  30. return super().apply_async(args, kwargs, headers=headers, **options)
  31. def __call__(self, *args, **kwargs):
  32. return copy_context().run(self._call_with_context, *args, **kwargs)
  33. def _call_with_context(self, *args, **kwargs):
  34. headers = TaskHeaders.from_celery_request(self.request)
  35. ctx.tenant = headers.tenant
  36. ctx.correlation_id = headers.correlation_id
  37. return super().__call__(*args, **kwargs)