# test_int_celery.py: integration tests for clean_python.celery (BaseTask, CeleryTaskLogger).

import time
from uuid import UUID

import pytest
from celery.exceptions import Ignore
from celery.exceptions import Reject

from clean_python import ctx
from clean_python import InMemorySyncGateway
from clean_python import Tenant
from clean_python.celery import BaseTask
from clean_python.celery import CeleryTaskLogger
from clean_python.celery import set_task_logger
  12. @pytest.fixture(scope="session")
  13. def celery_parameters():
  14. return {"task_cls": BaseTask, "strict_typing": False}
  15. @pytest.fixture(scope="session")
  16. def celery_worker_parameters():
  17. return {"shutdown_timeout": 10}
  18. @pytest.fixture
  19. def celery_task(celery_app, celery_worker):
  20. @celery_app.task(bind=True, base=BaseTask, name="testing")
  21. def sleep_task(self: BaseTask, seconds: float, return_value=None, event="success"):
  22. event = event.lower()
  23. if event == "success":
  24. time.sleep(int(seconds))
  25. elif event == "crash":
  26. import ctypes
  27. ctypes.string_at(0) # segfault
  28. elif event == "ignore":
  29. raise Ignore()
  30. elif event == "reject":
  31. raise Reject()
  32. elif event == "retry":
  33. raise self.retry(countdown=seconds, max_retries=1)
  34. elif event == "context":
  35. return {
  36. "tenant_id": ctx.tenant.id,
  37. "correlation_id": str(ctx.correlation_id),
  38. }
  39. else:
  40. raise ValueError(f"Unknown event '{event}'")
  41. return {"value": return_value}
  42. celery_worker.reload()
  43. return sleep_task
  44. @pytest.fixture
  45. def task_logger():
  46. logger = CeleryTaskLogger(InMemorySyncGateway([]))
  47. set_task_logger(logger)
  48. yield logger
  49. set_task_logger(None)
  50. def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger):
  51. result = celery_task.delay(0.0, return_value=16)
  52. assert result.get(timeout=10) == {"value": 16}
  53. (log,) = task_logger.gateway.filter([])
  54. assert 0.0 < (time.time() - log["time"]) < 1.0
  55. assert log["tag_suffix"] == "task_log"
  56. assert log["task_id"] == result.id
  57. assert log["state"] == "SUCCESS"
  58. assert log["name"] == "testing"
  59. assert log["duration"] > 0.0
  60. assert log["args"] == [0.0]
  61. assert log["kwargs"] == {"return_value": 16}
  62. assert log["retries"] == 0
  63. assert log["result"] == {"value": 16}
  64. assert UUID(log["correlation_id"]) # generated
  65. assert log["tenant_id"] is None
  66. def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger):
  67. result = celery_task.delay(0.0, event="failure")
  68. with pytest.raises(ValueError):
  69. assert result.get(timeout=10)
  70. (log,) = task_logger.gateway.filter([])
  71. assert log["state"] == "FAILURE"
  72. assert log["result"]["traceback"].startswith("Traceback")
  73. @pytest.fixture
  74. def custom_context():
  75. ctx.correlation_id = UUID("b3089ea7-2585-43e5-a63c-ae30a6e9b5e4")
  76. ctx.tenant = Tenant(id=2, name="custom")
  77. yield ctx
  78. ctx.correlation_id = None
  79. ctx.tenant = None
  80. def test_context(celery_task: BaseTask, custom_context, task_logger):
  81. result = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0)
  82. assert result.get(timeout=10) == {
  83. "tenant_id": 2,
  84. "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
  85. }
  86. (log,) = task_logger.gateway.filter([])
  87. assert log["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"
  88. assert log["tenant_id"] == 2