dramatiq_task_logger.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # (c) Nelen & Schuurmans
  2. import os
  3. import threading
  4. import time
  5. from typing import Optional
  6. import inject
  7. from dramatiq import get_encoder
  8. from dramatiq import Middleware
  9. from dramatiq.errors import RateLimitExceeded
  10. from dramatiq.errors import Retry
  11. from dramatiq.message import Message
  12. from dramatiq.middleware import SkipMessage
  13. from clean_python import ctx
  14. from clean_python import Gateway
  15. from clean_python.fluentbit import FluentbitGateway
  # Public names exported by `from <this module> import *`.
  __all__ = ["AsyncLoggingMiddleware", "DramatiqTaskLogger"]
  class AsyncLoggingMiddleware(Middleware):
      """Dramatiq middleware that emits a task-log record for each message.

      Every hook hands a ``DramatiqTaskLogger`` coroutine to
      ``broker.run_coroutine``.

      NOTE(review): ``run_coroutine`` is not visible on the broker here; this
      assumes another component attaches it to the broker. Confirm.
      """

      def __init__(self, **kwargs):
          # Every keyword argument is forwarded as-is to DramatiqTaskLogger.
          self.logger = DramatiqTaskLogger(**kwargs)

      def before_process_message(self, broker, message):
          """Record the start time of the message about to be processed."""
          run = broker.run_coroutine
          run(self.logger.start())

      def after_skip_message(self, broker, message):
          """Log a skipped message.

          There is no real exception in this hook, so a fresh ``SkipMessage``
          instance stands in for it.
          """
          run = broker.run_coroutine
          run(self.logger.stop(message, None, SkipMessage()))

      def after_process_message(self, broker, message, *, result=None, exception=None):
          """Log the outcome (result or exception) of a processed message."""
          run = broker.run_coroutine
          run(self.logger.stop(message, result, exception))
  class DramatiqTaskLogger:
      """Build a task-log record for a Dramatiq message and send it to a gateway.

      ``start()`` records the start time. ``stop()`` turns the message and its
      outcome into a dict and passes it to ``gateway.add``.
      """

      # Class-level thread-local storage. It is shared by all instances, but
      # each thread sees its own attributes. It carries ``start_time`` from
      # start() to stop().
      local = threading.local()

      def __init__(
          self,
          hostname: str,
          gateway_override: Optional[Gateway] = None,
      ):
          """
          Args:
              hostname: Combined with the current process id to form the
                  ``origin`` field of every record.
              gateway_override: Gateway to use instead of the injected
                  ``FluentbitGateway``.
          """
          self.origin = f"{hostname}-{os.getpid()}"
          self.gateway_override = gateway_override

      @property
      def gateway(self):
          """The override gateway if one is set (truthy), else the injected FluentbitGateway."""
          return self.gateway_override or inject.instance(FluentbitGateway)

      @property
      def encoder(self):
          """The encoder returned by ``dramatiq.get_encoder()``; used to serialize args/kwargs."""
          return get_encoder()

      @staticmethod
      def _state_for(exception) -> str:
          """Map the exception a message finished with to a task-log state string."""
          if exception is None:
              return "SUCCESS"
          if isinstance(exception, Retry):
              return "RETRY"
          if isinstance(exception, SkipMessage):
              return "EXPIRED"
          if isinstance(exception, RateLimitExceeded):
              return "TERMINATED"
          return "FAILURE"

      async def start(self):
          """Store the current wall-clock time as the start of the current task."""
          self.local.start_time = time.time()

      async def stop(self, message: Message, result=None, exception=None):
          """Build the task-log record for ``message`` and send it to the gateway.

          Args:
              message: The Dramatiq message that finished (or was skipped).
              result: The actor's return value, stored as-is in the record.
              exception: The exception the message ended with, or None on success.

          Returns:
              Whatever ``self.gateway.add`` returns.
          """
          state = self._state_for(exception)

          # Consume the recorded start time. If it were left in place, a later
          # stop() with no matching start() would reuse this message's start
          # time. That happens when an earlier middleware skips a message before
          # start() runs. Such a stop() now falls back to duration 0 / time None.
          start_time = getattr(self.local, "start_time", None)
          self.local.start_time = None
          duration = 0 if start_time is None else time.time() - start_time

          log_dict = {
              "tag_suffix": "task_log",
              "task_id": message.message_id,
              "name": message.actor_name,
              "state": state,
              "duration": duration,
              "retries": message.options.get("retries", 0),
              "origin": self.origin,
              "argsrepr": self.encoder.encode(message.args),
              "kwargsrepr": self.encoder.encode(message.kwargs),
              "result": result,
              "time": start_time,
              "correlation_id": str(ctx.correlation_id) if ctx.correlation_id else None,
          }
          return await self.gateway.add(log_dict)