test_async_actor.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import threading
  2. from asyncio import BaseEventLoop
  3. from contextvars import ContextVar
  4. from unittest import mock
  5. import pytest
  6. from clean_python.dramatiq import async_actor
  7. from clean_python.dramatiq import AsyncActor
  8. from clean_python.dramatiq import AsyncMiddleware
  9. from clean_python.dramatiq.async_actor import EventLoopThread
  10. @pytest.fixture
  11. def started_thread():
  12. thread = EventLoopThread()
  13. thread.start()
  14. yield thread
  15. thread.join()
  16. def test_event_loop_thread_start():
  17. try:
  18. thread = EventLoopThread()
  19. thread.start()
  20. assert isinstance(thread.loop, BaseEventLoop)
  21. assert thread.loop.is_running()
  22. finally:
  23. thread.join()
  24. def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread):
  25. result = {}
  26. async def get_thread_id():
  27. result["thread_id"] = threading.get_ident()
  28. started_thread.run_coroutine(get_thread_id())
  29. # the coroutine executed in the event loop thread
  30. assert result["thread_id"] == started_thread.ident
  31. def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread):
  32. async def raise_error():
  33. raise TypeError("bla")
  34. coro = raise_error()
  35. with pytest.raises(TypeError, match="bla"):
  36. started_thread.run_coroutine(coro)
  37. @mock.patch.object(EventLoopThread, "start")
  38. @mock.patch.object(EventLoopThread, "run_coroutine")
  39. def test_async_middleware_before_worker_boot(
  40. EventLoopThread_run_coroutine, EventLoopThread_start
  41. ):
  42. broker = mock.Mock()
  43. worker = mock.Mock()
  44. middleware = AsyncMiddleware()
  45. middleware.before_worker_boot(broker, worker)
  46. assert isinstance(middleware.event_loop_thread, EventLoopThread)
  47. EventLoopThread_start.assert_called_once()
  48. middleware.run_coroutine("foo")
  49. EventLoopThread_run_coroutine.assert_called_once_with("foo")
  50. # broker was patched with run_coroutine
  51. broker.run_coroutine("bar")
  52. EventLoopThread_run_coroutine.assert_called_with("bar")
  53. def test_async_middleware_after_worker_shutdown():
  54. broker = mock.Mock()
  55. broker.run_coroutine = lambda x: x
  56. worker = mock.Mock()
  57. event_loop_thread = mock.Mock()
  58. middleware = AsyncMiddleware()
  59. middleware.event_loop_thread = event_loop_thread
  60. middleware.after_worker_shutdown(broker, worker)
  61. event_loop_thread.join.assert_called_once()
  62. assert middleware.event_loop_thread is None
  63. assert not hasattr(broker, "run_coroutine")
  64. def test_async_actor():
  65. broker = mock.Mock()
  66. broker.actor_options = {"max_retries"}
  67. @async_actor(broker=broker)
  68. async def foo(*args, **kwargs):
  69. pass
  70. assert isinstance(foo, AsyncActor)
  71. foo(2, a="b")
  72. broker.run_coroutine.assert_called_once()
  73. # no recursion errors here:
  74. repr(foo)
  75. foo_var: ContextVar[int] = ContextVar("foo", default=42)
  76. def test_run_coroutine_keeps_context(started_thread: EventLoopThread):
  77. async def return_foo_var():
  78. return foo_var.get()
  79. foo_var.set(31)
  80. assert started_thread.run_coroutine(return_foo_var()) == 31