base_task.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from contextvars import copy_context
  2. from typing import Optional
  3. from uuid import UUID
  4. from uuid import uuid4
  5. from celery import Task
  6. from clean_python import ctx
  7. from clean_python import Tenant
  8. from clean_python import ValueObject
  9. __all__ = ["BaseTask"]
  10. HEADER_FIELD = "clean_python_context"
  11. class TaskHeaders(ValueObject):
  12. tenant: Optional[Tenant]
  13. correlation_id: Optional[UUID]
  14. @classmethod
  15. def from_celery_request(cls, request) -> "TaskHeaders":
  16. if request.headers and HEADER_FIELD in request.headers:
  17. return TaskHeaders(**request.headers[HEADER_FIELD])
  18. else:
  19. return TaskHeaders(tenant=None, correlation_id=None)
  20. class BaseTask(Task):
  21. def apply_async(self, args=None, kwargs=None, **options):
  22. # include correlation_id and tenant in the headers
  23. if "headers" in options:
  24. headers = options.pop("headers")
  25. if headers is None:
  26. headers = {}
  27. else:
  28. headers = headers.copy()
  29. else:
  30. headers = {}
  31. if HEADER_FIELD not in headers:
  32. headers[HEADER_FIELD] = TaskHeaders(
  33. tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
  34. ).model_dump(mode="json")
  35. return super().apply_async(args, kwargs, headers=headers, **options)
  36. def __call__(self, *args, **kwargs):
  37. return copy_context().run(self._call_with_context, *args, **kwargs)
  38. def _call_with_context(self, *args, **kwargs):
  39. headers = TaskHeaders.from_celery_request(self.request)
  40. ctx.tenant = headers.tenant
  41. ctx.correlation_id = headers.correlation_id
  42. return super().__call__(*args, **kwargs)