async_actor.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. """Dramatiq configuration"""
  4. import asyncio
  5. import logging
  6. import threading
  7. import time
  8. from concurrent.futures import TimeoutError
  9. from typing import Any, Awaitable, Dict, Optional, TypeVar
  10. import dramatiq
  11. from asgiref.sync import sync_to_async
  12. from dramatiq.brokers.stub import StubBroker
  13. from dramatiq.middleware import Interrupt, Middleware
  14. __all__ = ["AsyncActor", "AsyncMiddleware", "async_actor"]
  15. logger = logging.getLogger(__name__)
  16. # Default broker (for testing)
  17. broker = StubBroker()
  18. broker.run_coroutine = lambda coro: asyncio.run(coro)
  19. dramatiq.set_broker(broker)
  20. R = TypeVar("R")
  21. class EventLoopThread(threading.Thread):
  22. """A thread that starts / stops an asyncio event loop.
  23. The method 'run_coroutine' should be used to run coroutines from a
  24. synchronous context.
  25. """
  26. EVENT_LOOP_START_TIMEOUT = 0.1 # seconds to wait for the event loop to start
  27. loop: Optional[asyncio.AbstractEventLoop] = None
  28. def __init__(self):
  29. super().__init__(target=self._start_event_loop)
  30. def _start_event_loop(self):
  31. """This method should run in the thread"""
  32. logger.info("Starting the event loop...")
  33. self.loop = asyncio.new_event_loop()
  34. try:
  35. self.loop.run_forever()
  36. finally:
  37. self.loop.close()
  38. def _stop_event_loop(self):
  39. """This method should run outside of the thread"""
  40. if self.loop is not None:
  41. logger.info("Stopping the event loop...")
  42. self.loop.call_soon_threadsafe(self.loop.stop)
  43. def run_coroutine(self, coro: Awaitable[R]) -> R:
  44. """To be called from outside the thread
  45. Blocks until the coroutine is finished.
  46. """
  47. if self.loop is None or not self.loop.is_running():
  48. raise RuntimeError("The event loop is not running")
  49. done = threading.Event()
  50. async def wrapped_coro() -> R:
  51. try:
  52. return await coro
  53. finally:
  54. done.set()
  55. future = asyncio.run_coroutine_threadsafe(wrapped_coro(), self.loop)
  56. try:
  57. while True:
  58. try:
  59. # Use a timeout to be able to catch asynchronously raised dramatiq
  60. # exceptions (Shutdown and TimeLimitExceeded).
  61. return future.result(timeout=1)
  62. except TimeoutError:
  63. continue
  64. except Interrupt:
  65. self.loop.call_soon_threadsafe(future.cancel)
  66. # The future will raise a CancelledError *before* the coro actually
  67. # finished cleanup. Wait for the event instead.
  68. done.wait()
  69. raise
  70. def start(self, *args, **kwargs):
  71. super().start(*args, **kwargs)
  72. time.sleep(self.EVENT_LOOP_START_TIMEOUT)
  73. if self.loop is None or not self.loop.is_running():
  74. logger.exception("The event loop failed to start")
  75. logger.info("Event loop is running.")
  76. def join(self, *args, **kwargs):
  77. self._stop_event_loop()
  78. return super().join(*args, **kwargs)
  79. class AsyncMiddleware(Middleware):
  80. """This middleware enables coroutines to be ran as dramatiq a actors.
  81. At its core, this middleware spins up a dedicated thread ('event_loop_thread'),
  82. which may be used to schedule the coroutines on from the worker threads.
  83. """
  84. event_loop_thread: Optional[EventLoopThread] = None
  85. def run_coroutine(self, coro: Awaitable[R]) -> R:
  86. assert self.event_loop_thread is not None
  87. return self.event_loop_thread.run_coroutine(coro)
  88. def before_worker_boot(self, broker, worker):
  89. self.event_loop_thread = EventLoopThread()
  90. self.event_loop_thread.start()
  91. broker.run_coroutine = self.run_coroutine
  92. def after_worker_shutdown(self, broker, worker):
  93. assert self.event_loop_thread is not None
  94. self.event_loop_thread.join()
  95. self.event_loop_thread = None
  96. delattr(broker, "run_coroutine")
  97. class AsyncActor(dramatiq.Actor):
  98. """To configure coroutines as a dramatiq actor.
  99. Requires AsyncMiddleware to be active.
  100. Example usage:
  101. >>> @dramatiq.actor(..., actor_class=AsyncActor)
  102. ... async def my_task(x):
  103. ... print(x)
  104. Notes:
  105. The async functions are scheduled on an event loop that is shared between
  106. worker threads. See AsyncMiddleware.
  107. This is compatible with ShutdownNotifications ("notify_shutdown") and
  108. TimeLimit ("time_limit"). Both result in an asyncio.CancelledError raised inside
  109. the async function. There is currently no way to tell the two apart.
  110. """
  111. def __init__(self, fn, *args, **kwargs):
  112. super().__init__(
  113. lambda *args, **kwargs: self.broker.run_coroutine(fn(*args, **kwargs)),
  114. *args,
  115. **kwargs,
  116. )
  117. @sync_to_async
  118. def send_async(self, *args, **kwargs) -> dramatiq.Message[R]:
  119. """See dramatiq.actor.Actor.send.
  120. Sending a message to a broker is potentially blocking, so @sync_to_async is used.
  121. """
  122. return super().send(*args, **kwargs)
  123. @sync_to_async
  124. def send_async_with_options(
  125. self,
  126. *,
  127. args: tuple = (),
  128. kwargs: Optional[Dict[str, Any]] = None,
  129. delay: Optional[int] = None,
  130. **options,
  131. ) -> dramatiq.Message[R]:
  132. """See dramatiq.actor.Actor.send_with_options.
  133. Sending a message to a broker is potentially blocking, so @sync_to_async is used.
  134. """
  135. return super().send_with_options(
  136. args=args, kwargs=kwargs, delay=delay, **options
  137. )
  138. def async_actor(awaitable=None, **kwargs):
  139. kwargs.setdefault("max_retries", 0)
  140. if awaitable:
  141. return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
  142. else:
  143. def wrapper(awaitable):
  144. return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
  145. return wrapper