# (c) Nelen & Schuurmans

import json
import threading
import time
from typing import Any
from typing import Optional

import inject
from billiard.einfo import ExceptionInfo
from celery import Task
from celery.signals import task_failure
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.signals import task_revoked
from celery.signals import task_success
from celery.states import FAILURE
from celery.states import RETRY
from celery.states import REVOKED
from celery.states import SUCCESS
from clean_python import SyncGateway
from clean_python.fluentbit import SyncFluentbitGateway

from .base_task import TaskHeaders

__all__ = ["CeleryTaskLogger", "set_task_logger"]


def _json_roundtrip(value: Any) -> Any:
    """Return *value* converted to plain JSON types, or None if it cannot be.

    The value is serialized and parsed again so that only JSON-native types
    remain. ``json.dumps`` raises ``TypeError`` for unsupported types and
    ``ValueError`` for circular references; both are treated as "not
    loggable" so that logging stays best-effort.
    """
    try:
        return json.loads(json.dumps(value))
    except (TypeError, ValueError):
        return None


class CeleryTaskLogger:
    """Collect information about a finished Celery task and add it to a gateway.

    ``start()`` records a start time and ``stop()`` builds a log record and
    passes it to ``gateway.add``.
    """

    # Thread-local storage for the start time. It is a class attribute, so it
    # is shared by all instances of this class.
    local = threading.local()

    def __init__(self, gateway_override: Optional[SyncGateway] = None):
        """Store an optional gateway to use instead of the injected one."""
        self.gateway_override = gateway_override

    @property
    def gateway(self) -> SyncGateway:
        """Return the override gateway if set, else the injected fluentbit one."""
        return self.gateway_override or inject.instance(SyncFluentbitGateway)

    def start(self):
        """Record the current time as the start time for the calling thread."""
        self.local.start_time = time.time()

    def stop(self, task: Task, state: str, result: Any = None):
        """Build a log record for ``task`` and add it to the gateway.

        Every field is gathered on a best-effort basis: anything that is
        missing or cannot be converted to JSON is logged as None.

        Args:
            task: The task that finished; ``task.request`` is read if present.
            state: The state to record for the task.
            result: Optional task result. A non-dict result is wrapped as
                ``{"result": result}``.

        Returns:
            Whatever ``gateway.add`` returns.
        """
        # format the result into a dict (elasticsearch will error otherwise)
        if result is not None and not isinstance(result, dict):
            result = {"result": result}
        result_json = _json_roundtrip(result)

        # Read and then clear the start time, so that a later stop() without
        # a matching start() reports no duration.
        start_time = getattr(self.local, "start_time", None)
        self.local.start_time = None
        duration = None if start_time is None else time.time() - start_time

        request = getattr(task, "request", None)

        try:
            headers, kwargs = TaskHeaders.from_kwargs(request.kwargs)
        except AttributeError:
            headers = kwargs = None  # type: ignore

        try:
            tenant_id = headers.tenant.id  # type: ignore
        except AttributeError:
            tenant_id = None

        try:
            correlation_id = headers.correlation_id
        except AttributeError:
            correlation_id = None

        log_dict = {
            "tag_suffix": "task_log",
            "time": start_time,
            "task_id": getattr(request, "id", None),
            "name": task.name,
            "state": state,
            "duration": duration,
            "origin": getattr(request, "origin", None),
            "retries": getattr(request, "retries", None),
            "args": _json_roundtrip(getattr(request, "args", None)),
            "kwargs": _json_roundtrip(kwargs),
            "result": result_json,
            "tenant_id": tenant_id,
            "correlation_id": None if correlation_id is None else str(correlation_id),
        }

        return self.gateway.add(log_dict)


# Module-level logger used by the signal handlers below; None disables logging.
celery_logger: Optional[CeleryTaskLogger] = None


def set_task_logger(logger: Optional[CeleryTaskLogger]):
    """Set (or, with None, clear) the logger used by the signal handlers."""
    global celery_logger
    celery_logger = logger


@task_prerun.connect
def task_prerun_log(**kwargs):
    """Start timing when a task is about to run."""
    if celery_logger is None:
        return
    celery_logger.start()


@task_postrun.connect
def task_postrun_log(sender: Task, state: str, **kwargs):
    """Log a task after it ran, unless its state is None, SUCCESS, FAILURE or RETRY.

    SUCCESS and FAILURE are logged by the task_success and task_failure
    handlers in this module.
    """
    if celery_logger is None:
        return
    if state not in {None, SUCCESS, FAILURE, RETRY}:
        celery_logger.stop(task=sender, state=state)


@task_success.connect
def task_success_log(sender: Task, result: Any, **kwargs):
    """Log a successful task together with its result."""
    if celery_logger is None:
        return
    celery_logger.stop(task=sender, state=SUCCESS, result=result)


@task_failure.connect
def task_failure_log(sender: Task, einfo: ExceptionInfo, **kwargs):
    """Log a failed task with its traceback as the result."""
    if celery_logger is None:
        return
    celery_logger.stop(
        task=sender, state=FAILURE, result={"traceback": einfo.traceback}
    )


@task_revoked.connect(dispatch_uid="task_revoked_log")
def task_revoked_log(sender: Task, **kwargs):
    """Log a revoked task as TERMINATED, EXPIRED or REVOKED.

    Reads the ``signum``, ``terminated`` and ``expired`` keyword arguments
    of the signal.
    """
    if celery_logger is None:
        return
    # NOTE(review): on Python 3.11+ str() of an IntEnum member (such as
    # signal.Signals) is the number, not "Signals.SIGTERM" - confirm which
    # type is passed as signum and whether this comparison still matches.
    if str(kwargs["signum"]) == "Signals.SIGTERM":
        # This to filter out duplicate logging on task termination.
        return
    if kwargs["terminated"]:
        state = "TERMINATED"
    elif kwargs["expired"]:
        state = "EXPIRED"
    else:
        state = REVOKED
    celery_logger.stop(task=sender, state=state)