test_async_actor.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import threading
  2. from asyncio import BaseEventLoop
  3. from unittest import mock
  4. import pytest
  5. from clean_python.dramatiq import async_actor
  6. from clean_python.dramatiq import AsyncActor
  7. from clean_python.dramatiq import AsyncMiddleware
  8. from clean_python.dramatiq.async_actor import EventLoopThread
  9. @pytest.fixture
  10. def started_thread():
  11. thread = EventLoopThread()
  12. thread.start()
  13. yield thread
  14. thread.join()
  15. def test_event_loop_thread_start():
  16. try:
  17. thread = EventLoopThread()
  18. thread.start()
  19. assert isinstance(thread.loop, BaseEventLoop)
  20. assert thread.loop.is_running()
  21. finally:
  22. thread.join()
  23. def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread):
  24. result = {}
  25. async def get_thread_id():
  26. result["thread_id"] = threading.get_ident()
  27. started_thread.run_coroutine(get_thread_id())
  28. # the coroutine executed in the event loop thread
  29. assert result["thread_id"] == started_thread.ident
  30. def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread):
  31. async def raise_error():
  32. raise TypeError("bla")
  33. coro = raise_error()
  34. with pytest.raises(TypeError, match="bla"):
  35. started_thread.run_coroutine(coro)
  36. @mock.patch.object(EventLoopThread, "start")
  37. @mock.patch.object(EventLoopThread, "run_coroutine")
  38. def test_async_middleware_before_worker_boot(
  39. EventLoopThread_run_coroutine, EventLoopThread_start
  40. ):
  41. broker = mock.Mock()
  42. worker = mock.Mock()
  43. middleware = AsyncMiddleware()
  44. middleware.before_worker_boot(broker, worker)
  45. assert isinstance(middleware.event_loop_thread, EventLoopThread)
  46. EventLoopThread_start.assert_called_once()
  47. middleware.run_coroutine("foo")
  48. EventLoopThread_run_coroutine.assert_called_once_with("foo")
  49. # broker was patched with run_coroutine
  50. broker.run_coroutine("bar")
  51. EventLoopThread_run_coroutine.assert_called_with("bar")
  52. def test_async_middleware_after_worker_shutdown():
  53. broker = mock.Mock()
  54. broker.run_coroutine = lambda x: x
  55. worker = mock.Mock()
  56. event_loop_thread = mock.Mock()
  57. middleware = AsyncMiddleware()
  58. middleware.event_loop_thread = event_loop_thread
  59. middleware.after_worker_shutdown(broker, worker)
  60. event_loop_thread.join.assert_called_once()
  61. assert middleware.event_loop_thread is None
  62. assert not hasattr(broker, "run_coroutine")
  63. def test_async_actor():
  64. broker = mock.Mock()
  65. broker.actor_options = {"max_retries"}
  66. @async_actor(broker=broker)
  67. async def foo(*args, **kwargs):
  68. pass
  69. assert isinstance(foo, AsyncActor)
  70. foo(2, a="b")
  71. broker.run_coroutine.assert_called_once()
  72. # no recursion errors here:
  73. repr(foo)