# (c) Nelen & Schuurmans

import json
import signal
import threading
import time
from typing import Any
from typing import Optional

import inject
from billiard.einfo import ExceptionInfo
from celery import Task
from celery.signals import task_failure
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.signals import task_revoked
from celery.signals import task_success
from celery.states import FAILURE
from celery.states import RETRY
from celery.states import REVOKED
from celery.states import SUCCESS

from clean_python import SyncGateway
from clean_python.fluentbit import SyncFluentbitGateway

from .base_task import TaskHeaders

__all__ = ["CeleryTaskLogger", "set_task_logger"]


class CeleryTaskLogger:
    """Send one structured log record per Celery task run to a gateway.

    ``start`` records the wall-clock start time of the running task and
    ``stop`` builds the record (task metadata, final state, duration, result)
    and passes it to ``gateway.add``.
    """

    # Class-level thread-local: the start time is tracked per thread and is
    # shared by all CeleryTaskLogger instances.
    local = threading.local()

    def __init__(self, gateway_override: Optional[SyncGateway] = None):
        self.gateway_override = gateway_override

    @property
    def gateway(self) -> SyncGateway:
        """The override gateway if set, else the injected SyncFluentbitGateway."""
        return self.gateway_override or inject.instance(SyncFluentbitGateway)

    def start(self) -> None:
        """Record the current time as the start of the running task."""
        self.local.start_time = time.time()

    @staticmethod
    def _result_to_json(result: Any) -> Any:
        """Return a JSON-compatible copy of ``result``, or None if unserializable.

        Non-dict results are wrapped as ``{"result": result}`` because
        elasticsearch will error otherwise.
        """
        if result is not None and not isinstance(result, dict):
            result = {"result": result}
        try:
            return json.loads(json.dumps(result))
        except (TypeError, ValueError):
            # TypeError: unserializable objects or dict keys;
            # ValueError: circular references. Log without the result.
            return None

    def stop(self, task: Task, state: str, result: Any = None):
        """Build the log record for ``task`` and send it to the gateway.

        Args:
            task: the Celery task whose run is being logged.
            state: the state to record (e.g. SUCCESS, FAILURE, "TERMINATED").
            result: optional task result; logged as None when it cannot be
                serialized to JSON.

        Returns:
            Whatever ``self.gateway.add`` returns.
        """
        result_json = self._result_to_json(result)

        # Consume the start time so a later ``stop`` on this thread without a
        # matching ``start`` reports no duration; it is absent entirely when
        # ``start`` never ran on this thread.
        start_time = getattr(self.local, "start_time", None)
        self.local.start_time = None
        duration = None if start_time is None else time.time() - start_time

        try:
            request = task.request
            correlation_id = TaskHeaders.from_celery_request(request).correlation_id
        except AttributeError:
            request = None
            correlation_id = None

        log_dict = {
            "tag_suffix": "task_log",
            "time": start_time,
            "task_id": getattr(request, "id", None),
            "name": task.name,
            "state": state,
            "duration": duration,
            "origin": getattr(request, "origin", None),
            "retries": getattr(request, "retries", None),
            "argsrepr": getattr(request, "argsrepr", None),
            "kwargsrepr": getattr(request, "kwargsrepr", None),
            "result": result_json,
            "correlation_id": str(correlation_id) if correlation_id else None,
        }
        return self.gateway.add(log_dict)


celery_logger: Optional[CeleryTaskLogger] = None


def set_task_logger(logger: Optional[CeleryTaskLogger]) -> None:
    """Install (or, with None, disable) the logger used by the signal handlers."""
    global celery_logger
    celery_logger = logger


@task_prerun.connect
def task_prerun_log(**kwargs) -> None:
    """Mark the start of a task run."""
    if celery_logger is None:
        return
    celery_logger.start()


@task_postrun.connect
def task_postrun_log(sender: Task, state: str, **kwargs) -> None:
    """Log task runs that ended in a state not handled elsewhere."""
    if celery_logger is None:
        return
    # SUCCESS and FAILURE are logged by task_success_log / task_failure_log;
    # None and RETRY are deliberately not logged here.
    if state not in {None, SUCCESS, FAILURE, RETRY}:
        celery_logger.stop(task=sender, state=state)


@task_success.connect
def task_success_log(sender: Task, result: Any, **kwargs) -> None:
    """Log a successful task run together with its result."""
    if celery_logger is None:
        return
    celery_logger.stop(task=sender, state=SUCCESS, result=result)


@task_failure.connect
def task_failure_log(sender: Task, einfo: ExceptionInfo, **kwargs) -> None:
    """Log a failed task run with its traceback as the result."""
    if celery_logger is None:
        return
    celery_logger.stop(
        task=sender, state=FAILURE, result={"traceback": einfo.traceback}
    )


@task_revoked.connect(dispatch_uid="task_revoked_log")
def task_revoked_log(sender: Task, **kwargs) -> None:
    """Log a revoked task as TERMINATED, EXPIRED or REVOKED."""
    if celery_logger is None:
        return
    # Skip SIGTERM terminations to filter out duplicate logging on task
    # termination. Compare by value instead of via str(): signal.Signals is an
    # IntEnum, so this matches both the enum member and a plain int signum,
    # regardless of how the enum member renders as text.
    if kwargs.get("signum") == signal.SIGTERM:
        return
    if kwargs.get("terminated"):
        state = "TERMINATED"
    elif kwargs.get("expired"):
        state = "EXPIRED"
    else:
        state = REVOKED
    celery_logger.stop(task=sender, state=state)