celery_task_logger.py

# (c) Nelen & Schuurmans
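"""Log Celery task runs as structured records via Celery's task signals.

Importing this module connects the signal handlers; they do nothing until
``set_task_logger()`` is called with a ``CeleryTaskLogger`` instance.
"""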

import json
import threading
import time
from typing import Any
from typing import Optional

import inject
from billiard.einfo import ExceptionInfo
from celery import Task
from celery.signals import task_failure
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.signals import task_revoked
from celery.signals import task_success
from celery.states import FAILURE
from celery.states import RETRY
from celery.states import REVOKED
from celery.states import SUCCESS

from clean_python import SyncGateway
from clean_python.fluentbit import SyncFluentbitGateway

from .base_task import TaskHeaders

__all__ = ["CeleryTaskLogger", "set_task_logger"]


class CeleryTaskLogger:
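    """Time a Celery task run and write a structured log record to a gateway.

    Records go to the injected ``SyncFluentbitGateway`` unless another
    ``SyncGateway`` is passed as ``gateway_override``.
    """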
    local = threading.local()

    def __init__(self, gateway_override: Optional[SyncGateway] = None):
        self.gateway_override = gateway_override

    @property
    def gateway(self) -> SyncGateway:
        return self.gateway_override or inject.instance(SyncFluentbitGateway)

    def start(self):
        self.local.start_time = time.time()

    def stop(self, task: Task, state: str, result: Any = None):
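        """Build a structured log record for the finished task and send it.

        The record combines the thread-local start time, the task request
        metadata (id, origin, retries, args, headers) and the result.
        """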
        # format the result into a dict (elasticsearch will error otherwise)
        if result is not None and not isinstance(result, dict):
            result = {"result": result}
        try:
            result_json = json.loads(json.dumps(result))
        except TypeError:
            result_json = None
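
        # read and clear the thread-local start time set by start()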
        try:
            start_time = self.local.start_time
        except AttributeError:
            start_time = None
        self.local.start_time = None
        if start_time is not None:
            duration = time.time() - start_time
        else:
            duration = None
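
        # extract tenant and correlation info from the task headers, tolerating
        # tasks whose kwargs carry no TaskHeaders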
        try:
            request = task.request
        except AttributeError:
            request = None
        try:
            headers, kwargs = TaskHeaders.from_kwargs(request.kwargs)
        except (AttributeError, TypeError):
            headers = kwargs = None  # type: ignore
        try:
            tenant_id = headers.tenant.id  # type: ignore
        except AttributeError:
            tenant_id = None
        try:
            correlation_id = headers.correlation_id
        except AttributeError:
            correlation_id = None
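
        # args / kwargs are logged as JSON strings; values that are not
        # JSON-serialisable are dropped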
        try:
            argsrepr = json.dumps(request.args)
        except (AttributeError, TypeError):
            argsrepr = None
        try:
            kwargsrepr = json.dumps(kwargs)
        except TypeError:
            kwargsrepr = None
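
        # assemble the record; a repr of "null" (i.e. json.dumps(None)) is logged as None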
        log_dict = {
            "tag_suffix": "task_log",
            "time": start_time,
            "task_id": getattr(request, "id", None),
            "name": task.name,
            "state": state,
            "duration": duration,
            "origin": getattr(request, "origin", None),
            "retries": getattr(request, "retries", None),
            "argsrepr": argsrepr if argsrepr != "null" else None,
            "kwargsrepr": kwargsrepr if kwargsrepr != "null" else None,
            "result": result_json,
            "tenant_id": tenant_id,
            "correlation_id": None if correlation_id is None else str(correlation_id),
        }
        return self.gateway.add(log_dict)


celery_logger: Optional[CeleryTaskLogger] = None


def set_task_logger(logger: Optional[CeleryTaskLogger]):
    global celery_logger
    celery_logger = logger
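

# Celery signal handlers: forward task lifecycle events to the configured
# CeleryTaskLogger; they are no-ops while no logger has been set.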
@task_prerun.connect
def task_prerun_log(**kwargs):
    if celery_logger is None:
        return
    celery_logger.start()
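

# task_postrun fires for every task; SUCCESS and FAILURE get their own handlers
# below and RETRY is skipped, so only the remaining states are logged here.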
@task_postrun.connect
def task_postrun_log(sender: Task, state: str, **kwargs):
    if celery_logger is None:
        return
    if state not in {None, SUCCESS, FAILURE, RETRY}:
        celery_logger.stop(task=sender, state=state)


@task_success.connect
def task_success_log(sender: Task, result: Any, **kwargs):
    if celery_logger is None:
        return
    celery_logger.stop(task=sender, state=SUCCESS, result=result)
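

# on failure, the traceback is logged in place of a result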
@task_failure.connect
def task_failure_log(sender: Task, einfo: ExceptionInfo, **kwargs):
    if celery_logger is None:
        return
    celery_logger.stop(
        task=sender, state=FAILURE, result={"traceback": einfo.traceback}
    )
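

# task_revoked distinguishes termination (by signal), expiry, and plain revocation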
@task_revoked.connect(dispatch_uid="task_revoked_log")
def task_revoked_log(sender: Task, **kwargs):
    if celery_logger is None:
        return
    if str(kwargs["signum"]) == "Signals.SIGTERM":
        # This filters out duplicate logging on task termination.
        return
    if kwargs["terminated"]:
        state = "TERMINATED"
    elif kwargs["expired"]:
        state = "EXPIRED"
    else:
        state = REVOKED
    celery_logger.stop(task=sender, state=state)