dramatiq_task_logger.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # (c) Nelen & Schuurmans
  2. import os
  3. import threading
  4. import time
  5. from typing import Optional
  6. import inject
  7. from dramatiq import get_encoder
  8. from dramatiq import Middleware
  9. from dramatiq.errors import RateLimitExceeded
  10. from dramatiq.errors import Retry
  11. from dramatiq.message import Message
  12. from dramatiq.middleware import SkipMessage
  13. from clean_python import Gateway
  14. from clean_python.fluentbit import FluentbitGateway
  15. __all__ = ["AsyncLoggingMiddleware", "DramatiqTaskLogger"]
  16. class AsyncLoggingMiddleware(Middleware):
  17. def __init__(self, **kwargs):
  18. self.logger = DramatiqTaskLogger(**kwargs)
  19. def before_process_message(self, broker, message):
  20. broker.run_coroutine(self.logger.start())
  21. def after_skip_message(self, broker, message):
  22. broker.run_coroutine(self.logger.stop(message, None, SkipMessage()))
  23. def after_process_message(self, broker, message, *, result=None, exception=None):
  24. broker.run_coroutine(self.logger.stop(message, result, exception))
  25. class DramatiqTaskLogger:
  26. local = threading.local()
  27. def __init__(
  28. self,
  29. hostname: str,
  30. gateway_override: Optional[Gateway] = None,
  31. ):
  32. self.origin = f"{hostname}-{os.getpid()}"
  33. self.gateway_override = gateway_override
  34. @property
  35. def gateway(self):
  36. return self.gateway_override or inject.instance(FluentbitGateway)
  37. @property
  38. def encoder(self):
  39. return get_encoder()
  40. async def start(self):
  41. self.local.start_time = time.time()
  42. async def stop(self, message: Message, result=None, exception=None):
  43. if exception is None:
  44. state = "SUCCESS"
  45. elif isinstance(exception, Retry):
  46. state = "RETRY"
  47. elif isinstance(exception, SkipMessage):
  48. state = "EXPIRED"
  49. elif isinstance(exception, RateLimitExceeded):
  50. state = "TERMINATED"
  51. else:
  52. state = "FAILURE"
  53. try:
  54. duration = time.time() - self.local.start_time
  55. except AttributeError:
  56. duration = 0
  57. log_dict = {
  58. "tag_suffix": "task_log",
  59. "task_id": message.message_id,
  60. "name": message.actor_name,
  61. "state": state,
  62. "duration": duration,
  63. "retries": message.options.get("retries", 0),
  64. "origin": self.origin,
  65. "argsrepr": self.encoder.encode(message.args),
  66. "kwargsrepr": self.encoder.encode(message.kwargs),
  67. "result": result,
  68. }
  69. return await self.gateway.add(log_dict)