dramatiq_task_logger.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import os
  2. import threading
  3. import time
  4. from typing import Optional
  5. import inject
  6. from dramatiq import get_encoder
  7. from dramatiq import Middleware
  8. from dramatiq.errors import RateLimitExceeded
  9. from dramatiq.errors import Retry
  10. from dramatiq.message import Message
  11. from dramatiq.middleware import SkipMessage
  12. from clean_python.base.infrastructure.gateway import Gateway
  13. from clean_python.fluentbit.fluentbit_gateway import FluentbitGateway
# Public API of this module: only the middleware is exported.
__all__ = ["AsyncLoggingMiddleware"]
  15. class AsyncLoggingMiddleware(Middleware):
  16. def __init__(self, **kwargs):
  17. self.logger = DramatiqTaskLogger(**kwargs)
  18. def before_process_message(self, broker, message):
  19. broker.run_coroutine(self.logger.start())
  20. def after_skip_message(self, broker, message):
  21. broker.run_coroutine(self.logger.stop(message, None, SkipMessage()))
  22. def after_process_message(self, broker, message, *, result=None, exception=None):
  23. broker.run_coroutine(self.logger.stop(message, result, exception))
  24. class DramatiqTaskLogger:
  25. local = threading.local()
  26. def __init__(
  27. self,
  28. hostname: str,
  29. gateway_override: Optional[Gateway] = None,
  30. ):
  31. self.origin = f"{hostname}-{os.getpid()}"
  32. self.gateway_override = gateway_override
  33. @property
  34. def gateway(self):
  35. return self.gateway_override or inject.instance(FluentbitGateway)
  36. @property
  37. def encoder(self):
  38. return get_encoder()
  39. async def start(self):
  40. self.local.start_time = time.time()
  41. async def stop(self, message: Message, result=None, exception=None):
  42. if exception is None:
  43. state = "SUCCESS"
  44. elif isinstance(exception, Retry):
  45. state = "RETRY"
  46. elif isinstance(exception, SkipMessage):
  47. state = "EXPIRED"
  48. elif isinstance(exception, RateLimitExceeded):
  49. state = "TERMINATED"
  50. else:
  51. state = "FAILURE"
  52. try:
  53. duration = time.time() - self.local.start_time
  54. except AttributeError:
  55. duration = 0
  56. log_dict = {
  57. "tag_suffix": "task_log",
  58. "task_id": message.message_id,
  59. "name": message.actor_name,
  60. "state": state,
  61. "duration": duration,
  62. "retries": message.options.get("retries", 0),
  63. "origin": self.origin,
  64. "argsrepr": self.encoder.encode(message.args),
  65. "kwargsrepr": self.encoder.encode(message.kwargs),
  66. "result": result,
  67. }
  68. return await self.gateway.add(log_dict)