dramatiq_task_logger.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # (c) Nelen & Schuurmans
  2. import os
  3. import threading
  4. import time
  5. from typing import Optional
  6. from uuid import UUID
  7. from uuid import uuid4
  8. import inject
  9. from dramatiq import get_encoder
  10. from dramatiq import Message
  11. from dramatiq import Middleware
  12. from dramatiq.errors import RateLimitExceeded
  13. from dramatiq.errors import Retry
  14. from dramatiq.middleware import SkipMessage
  15. from clean_python import ctx
  16. from clean_python import Gateway
  17. from clean_python.fluentbit import FluentbitGateway
  18. __all__ = ["AsyncLoggingMiddleware", "DramatiqTaskLogger"]
  19. class AsyncLoggingMiddleware(Middleware):
  20. def __init__(self, **kwargs):
  21. self.logger = DramatiqTaskLogger(**kwargs)
  22. def before_enqueue(self, broker, message: Message, delay):
  23. if ctx.correlation_id is not None:
  24. message.options["correlation_id"] = str(ctx.correlation_id)
  25. def before_process_message(self, broker, message):
  26. if message.options.get("correlation_id") is not None:
  27. ctx.correlation_id = UUID(message.options["correlation_id"])
  28. else:
  29. ctx.correlation_id = uuid4()
  30. broker.run_coroutine(self.logger.start())
  31. def after_skip_message(self, broker, message):
  32. broker.run_coroutine(self.logger.stop(message, None, SkipMessage()))
  33. ctx.correlation_id = None
  34. def after_process_message(self, broker, message, *, result=None, exception=None):
  35. broker.run_coroutine(self.logger.stop(message, result, exception))
  36. ctx.correlation_id = None
  37. class DramatiqTaskLogger:
  38. local = threading.local()
  39. def __init__(
  40. self,
  41. hostname: str,
  42. gateway_override: Optional[Gateway] = None,
  43. ):
  44. self.origin = f"{hostname}-{os.getpid()}"
  45. self.gateway_override = gateway_override
  46. @property
  47. def gateway(self):
  48. return self.gateway_override or inject.instance(FluentbitGateway)
  49. @property
  50. def encoder(self):
  51. return get_encoder()
  52. async def start(self):
  53. self.local.start_time = time.time()
  54. async def stop(self, message: Message, result=None, exception=None):
  55. if exception is None:
  56. state = "SUCCESS"
  57. elif isinstance(exception, Retry):
  58. state = "RETRY"
  59. elif isinstance(exception, SkipMessage):
  60. state = "EXPIRED"
  61. elif isinstance(exception, RateLimitExceeded):
  62. state = "TERMINATED"
  63. else:
  64. state = "FAILURE"
  65. try:
  66. duration = time.time() - self.local.start_time
  67. except AttributeError:
  68. duration = 0
  69. try:
  70. start_time = self.local.start_time
  71. except AttributeError:
  72. start_time = None
  73. log_dict = {
  74. "tag_suffix": "task_log",
  75. "task_id": message.message_id,
  76. "name": message.actor_name,
  77. "state": state,
  78. "duration": duration,
  79. "retries": message.options.get("retries", 0),
  80. "origin": self.origin,
  81. "argsrepr": self.encoder.encode(message.args),
  82. "kwargsrepr": self.encoder.encode(message.kwargs),
  83. "result": result,
  84. "time": start_time,
  85. "correlation_id": str(ctx.correlation_id) if ctx.correlation_id else None,
  86. }
  87. return await self.gateway.add(log_dict)