123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- # (c) Nelen & Schuurmans
- import os
- import threading
- import time
- from typing import Optional
- from uuid import UUID
- from uuid import uuid4
- import inject
- from dramatiq import get_encoder
- from dramatiq import Message
- from dramatiq import Middleware
- from dramatiq.errors import RateLimitExceeded
- from dramatiq.errors import Retry
- from dramatiq.middleware import SkipMessage
- from clean_python import ctx
- from clean_python import Gateway
- from clean_python.fluentbit import FluentbitGateway
- __all__ = ["AsyncLoggingMiddleware", "DramatiqTaskLogger"]
- class AsyncLoggingMiddleware(Middleware):
- def __init__(self, **kwargs):
- self.logger = DramatiqTaskLogger(**kwargs)
- def before_enqueue(self, broker, message: Message, delay):
- if ctx.correlation_id is not None:
- message.options["correlation_id"] = str(ctx.correlation_id)
- def before_process_message(self, broker, message):
- if message.options.get("correlation_id") is not None:
- ctx.correlation_id = UUID(message.options["correlation_id"])
- else:
- ctx.correlation_id = uuid4()
- broker.run_coroutine(self.logger.start())
- def after_skip_message(self, broker, message):
- broker.run_coroutine(self.logger.stop(message, None, SkipMessage()))
- ctx.correlation_id = None
- def after_process_message(self, broker, message, *, result=None, exception=None):
- broker.run_coroutine(self.logger.stop(message, result, exception))
- ctx.correlation_id = None
- class DramatiqTaskLogger:
- local = threading.local()
- def __init__(
- self,
- hostname: str,
- gateway_override: Optional[Gateway] = None,
- ):
- self.origin = f"{hostname}-{os.getpid()}"
- self.gateway_override = gateway_override
- @property
- def gateway(self):
- return self.gateway_override or inject.instance(FluentbitGateway)
- @property
- def encoder(self):
- return get_encoder()
- async def start(self):
- self.local.start_time = time.time()
- async def stop(self, message: Message, result=None, exception=None):
- if exception is None:
- state = "SUCCESS"
- elif isinstance(exception, Retry):
- state = "RETRY"
- elif isinstance(exception, SkipMessage):
- state = "EXPIRED"
- elif isinstance(exception, RateLimitExceeded):
- state = "TERMINATED"
- else:
- state = "FAILURE"
- try:
- duration = time.time() - self.local.start_time
- except AttributeError:
- duration = 0
- try:
- start_time = self.local.start_time
- except AttributeError:
- start_time = None
- log_dict = {
- "tag_suffix": "task_log",
- "task_id": message.message_id,
- "name": message.actor_name,
- "state": state,
- "duration": duration,
- "retries": message.options.get("retries", 0),
- "origin": self.origin,
- "argsrepr": self.encoder.encode(message.args),
- "kwargsrepr": self.encoder.encode(message.kwargs),
- "result": result,
- "time": start_time,
- "correlation_id": str(ctx.correlation_id) if ctx.correlation_id else None,
- }
- return await self.gateway.add(log_dict)
|