test_async_actor.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
import inspect
import threading
from asyncio import BaseEventLoop
from unittest import mock

import pytest

from clean_python.dramatiq.async_actor import (
    async_actor,
    AsyncActor,
    AsyncMiddleware,
    EventLoopThread,
)
  11. @pytest.fixture
  12. def started_thread():
  13. thread = EventLoopThread()
  14. thread.start()
  15. yield thread
  16. thread.join()
  17. def test_event_loop_thread_start():
  18. try:
  19. thread = EventLoopThread()
  20. thread.start()
  21. assert isinstance(thread.loop, BaseEventLoop)
  22. assert thread.loop.is_running()
  23. finally:
  24. thread.join()
  25. def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread):
  26. result = {}
  27. async def get_thread_id():
  28. result["thread_id"] = threading.get_ident()
  29. started_thread.run_coroutine(get_thread_id())
  30. # the coroutine executed in the event loop thread
  31. assert result["thread_id"] == started_thread.ident
  32. def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread):
  33. async def raise_error():
  34. raise TypeError("bla")
  35. coro = raise_error()
  36. with pytest.raises(TypeError, match="bla"):
  37. started_thread.run_coroutine(coro)
  38. @mock.patch.object(EventLoopThread, "start")
  39. @mock.patch.object(EventLoopThread, "run_coroutine")
  40. def test_async_middleware_before_worker_boot(
  41. EventLoopThread_run_coroutine, EventLoopThread_start
  42. ):
  43. broker = mock.Mock()
  44. worker = mock.Mock()
  45. middleware = AsyncMiddleware()
  46. middleware.before_worker_boot(broker, worker)
  47. assert isinstance(middleware.event_loop_thread, EventLoopThread)
  48. EventLoopThread_start.assert_called_once()
  49. middleware.run_coroutine("foo")
  50. EventLoopThread_run_coroutine.assert_called_once_with("foo")
  51. # broker was patched with run_coroutine
  52. broker.run_coroutine("bar")
  53. EventLoopThread_run_coroutine.assert_called_with("bar")
  54. def test_async_middleware_after_worker_shutdown():
  55. broker = mock.Mock()
  56. broker.run_coroutine = lambda x: x
  57. worker = mock.Mock()
  58. event_loop_thread = mock.Mock()
  59. middleware = AsyncMiddleware()
  60. middleware.event_loop_thread = event_loop_thread
  61. middleware.after_worker_shutdown(broker, worker)
  62. event_loop_thread.join.assert_called_once()
  63. assert middleware.event_loop_thread is None
  64. assert not hasattr(broker, "run_coroutine")
  65. def test_async_actor():
  66. broker = mock.Mock()
  67. broker.actor_options = {"max_retries"}
  68. @async_actor(broker=broker)
  69. async def foo(*args, **kwargs):
  70. pass
  71. assert isinstance(foo, AsyncActor)
  72. foo(2, a="b")
  73. broker.run_coroutine.assert_called_once()
  74. # no recursion errors here:
  75. repr(foo)