test_celery_base_task.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from unittest import mock
  2. from uuid import UUID
  3. from uuid import uuid4
  4. import pytest
  5. from celery import Task
  6. from clean_python import ctx
  7. from clean_python import Tenant
  8. from clean_python.celery import BaseTask
  9. from clean_python.celery.base_task import HEADER_FIELD
  10. @pytest.fixture
  11. def mocked_apply_async():
  12. with mock.patch.object(Task, "apply_async") as m:
  13. yield m
  14. @pytest.fixture
  15. def temp_context():
  16. ctx.tenant = Tenant(id=2, name="test")
  17. ctx.correlation_id = uuid4()
  18. yield ctx
  19. ctx.tenant = None
  20. ctx.correlation_id = None
  21. def test_apply_async(mocked_apply_async):
  22. BaseTask().apply_async(args="foo", kwargs="bar")
  23. assert mocked_apply_async.call_count == 1
  24. args, kwargs = mocked_apply_async.call_args
  25. assert args == ("foo", "bar")
  26. assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
  27. UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"]) # generated
  28. def test_apply_async_with_context(mocked_apply_async, temp_context):
  29. BaseTask().apply_async(args="foo", kwargs="bar")
  30. assert mocked_apply_async.call_count == 1
  31. _, kwargs = mocked_apply_async.call_args
  32. assert kwargs["headers"][HEADER_FIELD]["tenant"] == temp_context.tenant.model_dump(
  33. mode="json"
  34. )
  35. kwargs["headers"][HEADER_FIELD]["correlation_id"] == str(
  36. temp_context.correlation_id
  37. )
  38. def test_apply_async_headers_extended(mocked_apply_async):
  39. headers = {"baz": 2}
  40. BaseTask().apply_async(args="foo", kwargs="bar", headers=headers)
  41. assert mocked_apply_async.call_count == 1
  42. _, kwargs = mocked_apply_async.call_args
  43. assert kwargs["headers"]["baz"] == 2
  44. assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
  45. UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"]) # generated
  46. assert headers == {"baz": 2} # not changed inplace
  47. def test_apply_async_headers_already_present(mocked_apply_async):
  48. BaseTask().apply_async(args="foo", kwargs="bar", headers={HEADER_FIELD: "foo"})
  49. assert mocked_apply_async.call_count == 1
  50. _, kwargs = mocked_apply_async.call_args
  51. assert kwargs["headers"] == {HEADER_FIELD: "foo"}