123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- # (c) Nelen & Schuurmans
- import json
- import threading
- import time
- from typing import Any
- from typing import Optional
- import inject
- from billiard.einfo import ExceptionInfo
- from celery import Task
- from celery.signals import task_failure
- from celery.signals import task_postrun
- from celery.signals import task_prerun
- from celery.signals import task_revoked
- from celery.signals import task_success
- from celery.states import FAILURE
- from celery.states import RETRY
- from celery.states import REVOKED
- from celery.states import SUCCESS
- from clean_python import SyncGateway
- from clean_python.fluentbit import SyncFluentbitGateway
- from .base_task import TaskHeaders
# Public names of this module.
__all__ = [
    "CeleryTaskLogger",
    "set_task_logger",
]
class CeleryTaskLogger:
    """Emit one structured log record per finished Celery task.

    ``start`` remembers when the current thread began working on a task;
    ``stop`` turns the task, its final state and its result into a
    JSON-friendly dict and hands that dict to the gateway.
    """

    # Thread-local storage for the start time. It lives on the class, so all
    # instances share it, but every thread sees its own ``start_time``.
    local = threading.local()

    def __init__(self, gateway_override: Optional[SyncGateway] = None):
        """Store an optional gateway that is used instead of the injected one."""
        self.gateway_override = gateway_override

    @property
    def gateway(self) -> SyncGateway:
        """Return the override gateway if one was given, else the injected one."""
        # Compare against None explicitly: a gateway object may be falsy
        # (e.g. when it defines ``__len__``) and must still be honoured.
        if self.gateway_override is not None:
            return self.gateway_override
        return inject.instance(SyncFluentbitGateway)

    def start(self):
        """Record the current time as the task start for this thread."""
        self.local.start_time = time.time()

    @staticmethod
    def _dumps_or_none(value: Any) -> Optional[str]:
        """Return ``json.dumps(value)``, or None if *value* is not serializable."""
        try:
            return json.dumps(value)
        except (TypeError, ValueError):
            # TypeError: unsupported type; ValueError: circular reference.
            return None

    def stop(self, task: Task, state: str, result: Any = None):
        """Build the log record for a finished task and add it to the gateway.

        Args:
            task: The task that finished; ``task.request`` (when present)
                supplies the id, origin, retries, args and kwargs.
            state: The final state to record for the task.
            result: The task result. Non-dict values are wrapped as
                ``{"result": value}``; results that cannot be serialized to
                JSON are logged as None.

        Returns:
            Whatever ``gateway.add`` returns for the log record.
        """
        # format the result into a dict (elasticsearch will error otherwise)
        if result is not None and not isinstance(result, dict):
            result = {"result": result}
        # Round-trip through JSON so only JSON-native types end up in the record.
        result_repr = self._dumps_or_none(result)
        result_json = None if result_repr is None else json.loads(result_repr)

        # ``start_time`` is absent when ``start`` never ran on this thread.
        start_time = getattr(self.local, "start_time", None)
        # Reset, so a later ``stop`` without ``start`` does not reuse this value.
        self.local.start_time = None
        duration = None if start_time is None else time.time() - start_time

        request = getattr(task, "request", None)
        try:
            headers, kwargs = TaskHeaders.from_kwargs(request.kwargs)
        except (AttributeError, TypeError):
            headers = kwargs = None  # type: ignore

        # Every lookup below degrades to None when a link in the chain is missing.
        tenant_id = getattr(getattr(headers, "tenant", None), "id", None)
        correlation_id = getattr(headers, "correlation_id", None)
        argsrepr = self._dumps_or_none(getattr(request, "args", None))
        kwargsrepr = self._dumps_or_none(kwargs)

        log_dict = {
            "tag_suffix": "task_log",
            "time": start_time,
            "task_id": getattr(request, "id", None),
            "name": task.name,
            "state": state,
            "duration": duration,
            "origin": getattr(request, "origin", None),
            "retries": getattr(request, "retries", None),
            # A serialized None ("null") is logged as a real null value.
            "argsrepr": argsrepr if argsrepr != "null" else None,
            "kwargsrepr": kwargsrepr if kwargsrepr != "null" else None,
            "result": result_json,
            "tenant_id": tenant_id,
            "correlation_id": None if correlation_id is None else str(correlation_id),
        }
        return self.gateway.add(log_dict)
# Module-wide logger consulted by the signal handlers in this module;
# while it is None, those handlers return without logging anything.
celery_logger: Optional[CeleryTaskLogger] = None


def set_task_logger(logger: Optional[CeleryTaskLogger]):
    """Install *logger* as the module-wide task logger (None switches logging off)."""
    global celery_logger
    celery_logger = logger
@task_prerun.connect
def task_prerun_log(**kwargs):
    """Start timing the task when a task logger is configured."""
    if celery_logger is not None:
        celery_logger.start()
@task_postrun.connect
def task_postrun_log(sender: Task, state: str, **kwargs):
    """Log a task after it ran, unless its state is one this handler skips.

    A missing state and the SUCCESS, FAILURE and RETRY states are not logged
    here; every other state is passed to the task logger without a result.
    """
    if celery_logger is None:
        return
    skipped_states = {None, SUCCESS, FAILURE, RETRY}
    if state in skipped_states:
        return
    celery_logger.stop(task=sender, state=state)
@task_success.connect
def task_success_log(sender: Task, result: Any, **kwargs):
    """Log a successfully finished task together with its result."""
    if celery_logger is not None:
        celery_logger.stop(task=sender, state=SUCCESS, result=result)
@task_failure.connect
def task_failure_log(sender: Task, einfo: ExceptionInfo, **kwargs):
    """Log a failed task, storing the traceback from *einfo* as its result."""
    if celery_logger is not None:
        celery_logger.stop(
            task=sender,
            state=FAILURE,
            result={"traceback": einfo.traceback},
        )
@task_revoked.connect(dispatch_uid="task_revoked_log")
def task_revoked_log(sender: Task, **kwargs):
    """Log a revoked task as TERMINATED, EXPIRED or REVOKED.

    Reads the ``signum``, ``terminated`` and ``expired`` entries of the signal
    keyword arguments. Events whose ``signum`` is the ``signal.SIGTERM`` enum
    member are skipped to avoid logging a terminated task twice.
    """
    # Local import: standard library only, needed for the enum check below.
    import signal

    if celery_logger is None:
        return

    signum = kwargs["signum"]
    # Identify the enum member directly instead of relying on its ``str()``:
    # up to Python 3.10 ``str(signal.SIGTERM)`` is "Signals.SIGTERM", but from
    # 3.11 on it is "15", which made the original string comparison never
    # match. The string comparison is kept for backward compatibility.
    is_sigterm_member = (
        isinstance(signum, signal.Signals) and signum == signal.SIGTERM
    )
    if is_sigterm_member or str(signum) == "Signals.SIGTERM":
        # This is to filter out duplicate logging on task termination.
        return

    if kwargs["terminated"]:
        state = "TERMINATED"
    elif kwargs["expired"]:
        state = "EXPIRED"
    else:
        state = REVOKED
    celery_logger.stop(task=sender, state=state)
|