# celery_task_logger.py (4.0 KB)
  1. # (c) Nelen & Schuurmans
  2. import json
  3. import threading
  4. import time
  5. from typing import Any
  6. from typing import Optional
  7. import inject
  8. from billiard.einfo import ExceptionInfo
  9. from celery import Task
  10. from celery.signals import task_failure
  11. from celery.signals import task_postrun
  12. from celery.signals import task_prerun
  13. from celery.signals import task_revoked
  14. from celery.signals import task_success
  15. from celery.states import FAILURE
  16. from celery.states import RETRY
  17. from celery.states import REVOKED
  18. from celery.states import SUCCESS
  19. from clean_python import SyncGateway
  20. from clean_python.fluentbit import SyncFluentbitGateway
  21. from .base_task import TaskHeaders
  22. __all__ = ["CeleryTaskLogger", "set_task_logger"]
  23. class CeleryTaskLogger:
  24. local = threading.local()
  25. def __init__(self, gateway_override: Optional[SyncGateway] = None):
  26. self.gateway_override = gateway_override
  27. @property
  28. def gateway(self) -> SyncGateway:
  29. return self.gateway_override or inject.instance(SyncFluentbitGateway)
  30. def start(self):
  31. self.local.start_time = time.time()
  32. def stop(self, task: Task, state: str, result: Any = None):
  33. # format the result into a dict (elasticsearch will error otherwise)
  34. if result is not None and not isinstance(result, dict):
  35. result = {"result": result}
  36. try:
  37. result_json = json.loads(json.dumps(result))
  38. except TypeError:
  39. result_json = None
  40. try:
  41. start_time = self.local.start_time
  42. except AttributeError:
  43. start_time = None
  44. self.local.start_time = None
  45. if start_time is not None:
  46. duration = time.time() - start_time
  47. else:
  48. duration = None
  49. try:
  50. request = task.request
  51. correlation_id = TaskHeaders.from_celery_request(request).correlation_id
  52. except AttributeError:
  53. request = None
  54. correlation_id = None
  55. log_dict = {
  56. "tag_suffix": "task_log",
  57. "time": start_time,
  58. "task_id": getattr(request, "id", None),
  59. "name": task.name,
  60. "state": state,
  61. "duration": duration,
  62. "origin": getattr(request, "origin", None),
  63. "retries": getattr(request, "retries", None),
  64. "argsrepr": getattr(request, "argsrepr", None),
  65. "kwargsrepr": getattr(request, "kwargsrepr", None),
  66. "result": result_json,
  67. "correlation_id": str(correlation_id) if correlation_id else None,
  68. }
  69. return self.gateway.add(log_dict)
  70. celery_logger: Optional[CeleryTaskLogger] = None
  71. def set_task_logger(logger: Optional[CeleryTaskLogger]):
  72. global celery_logger
  73. celery_logger = logger
  74. @task_prerun.connect
  75. def task_prerun_log(**kwargs):
  76. if celery_logger is None:
  77. return
  78. celery_logger.start()
  79. @task_postrun.connect
  80. def task_postrun_log(sender: Task, state: str, **kwargs):
  81. if celery_logger is None:
  82. return
  83. if state not in {None, SUCCESS, FAILURE, RETRY}:
  84. celery_logger.stop(task=sender, state=state)
  85. @task_success.connect
  86. def task_success_log(sender: Task, result: Any, **kwargs):
  87. if celery_logger is None:
  88. return
  89. celery_logger.stop(task=sender, state=SUCCESS, result=result)
  90. @task_failure.connect
  91. def task_failure_log(sender: Task, einfo: ExceptionInfo, **kwargs):
  92. if celery_logger is None:
  93. return
  94. celery_logger.stop(
  95. task=sender, state=FAILURE, result={"traceback": einfo.traceback}
  96. )
  97. @task_revoked.connect(dispatch_uid="task_revoked_log")
  98. def task_revoked_log(sender: Task, **kwargs):
  99. if celery_logger is None:
  100. return
  101. if str(kwargs["signum"]) == "Signals.SIGTERM":
  102. # This to filter out duplicate logging on task termination.
  103. return
  104. if kwargs["terminated"]:
  105. state = "TERMINATED"
  106. elif kwargs["expired"]:
  107. state = "EXPIRED"
  108. else:
  109. state = REVOKED
  110. celery_logger.stop(task=sender, state=state)