test_int_celery.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import time
  2. from uuid import UUID
  3. import pytest
  4. from celery.exceptions import Ignore
  5. from celery.exceptions import Reject
  6. from clean_python import ctx
  7. from clean_python import InMemorySyncGateway
  8. from clean_python import Tenant
  9. from clean_python.celery import BaseTask
  10. from clean_python.celery import CeleryTaskLogger
  11. from clean_python.celery import set_task_logger
  @pytest.fixture(scope="session")
  def celery_parameters():
      """Return the init kwargs for the session's test Celery app.

      Overrides the Celery pytest plugin's ``celery_parameters`` fixture so that
      ``BaseTask`` becomes the app's default task class.
      """
      return dict(task_cls=BaseTask)
  @pytest.fixture(scope="session")
  def celery_worker_parameters():
      """Return the kwargs for the embedded test worker.

      Overrides the Celery pytest plugin's ``celery_worker_parameters`` fixture,
      passing ``shutdown_timeout=10`` to the worker.
      """
      return dict(shutdown_timeout=10)
  @pytest.fixture
  def celery_task(celery_app, celery_worker):
      """Register a configurable task named ``testing`` and return it.

      The task's behaviour is chosen by its ``event`` keyword (case-insensitive):

      - ``"success"`` (default): sleep ``seconds`` and return
        ``{"value": return_value}``.
      - ``"crash"``: read from address 0 via ``ctypes`` (segfaults the worker).
      - ``"ignore"`` / ``"reject"``: raise Celery's ``Ignore`` / ``Reject``.
      - ``"retry"``: ask for a retry after ``seconds`` with ``max_retries=1``.
      - ``"context"``: return the ``ctx`` tenant id and correlation id (as str)
        observed inside the task.
      - anything else: raise ``ValueError``.

      The worker is reloaded after registration, as in the Celery testing docs,
      so the running worker knows about the new task.
      """

      @celery_app.task(bind=True, base=BaseTask, name="testing")
      def sleep_task(
          self: BaseTask, seconds: float, return_value=None, event="success"
      ):
          event = event.lower()
          if event == "success":
              # Sleep the full, possibly fractional, duration; truncating with
              # int() would turn e.g. 0.5 s into no sleep at all.
              time.sleep(seconds)
          elif event == "crash":
              import ctypes

              ctypes.string_at(0)  # segfault
          elif event == "ignore":
              raise Ignore()
          elif event == "reject":
              raise Reject()
          elif event == "retry":
              raise self.retry(countdown=seconds, max_retries=1)
          elif event == "context":
              return {
                  "tenant": ctx.tenant.id,
                  "correlation_id": str(ctx.correlation_id),
              }
          else:
              raise ValueError(f"Unknown event '{event}'")
          return {"value": return_value}

      celery_worker.reload()
      return sleep_task
  @pytest.fixture
  def task_logger():
      """Install a ``CeleryTaskLogger`` backed by an in-memory gateway.

      Yields the logger so tests can read the entries via ``logger.gateway``.
      On teardown, ``set_task_logger(None)`` is called to uninstall it.
      """
      in_memory_gateway = InMemorySyncGateway([])
      installed_logger = CeleryTaskLogger(in_memory_gateway)
      set_task_logger(installed_logger)
      yield installed_logger
      # Teardown: uninstall the logger again.
      set_task_logger(None)
  def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger):
      """A successful task yields exactly one SUCCESS log entry with its metadata."""
      async_result = celery_task.delay(0.0, return_value=16)
      assert async_result.get(timeout=10) == {"value": 16}

      (entry,) = task_logger.gateway.filter([])

      # The entry must have been written within the last second.
      age = time.time() - entry["time"]
      assert 0.0 < age < 1.0

      expected = {
          "tag_suffix": "task_log",
          "task_id": async_result.id,
          "state": "SUCCESS",
          "name": "testing",
          "argsrepr": "(0.0,)",
          "kwargsrepr": "{'return_value': 16}",
          "retries": 0,
          "result": {"value": 16},
      }
      assert entry["duration"] > 0.0
      assert {key: entry[key] for key in expected} == expected

      # No correlation id was set by the caller, so one must have been generated;
      # UUID() raises if the logged value is not a valid UUID.
      UUID(entry["correlation_id"])
  def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger):
      """A task that raises is logged once with state FAILURE and a traceback."""
      # "failure" is not a recognised event, so the task raises ValueError.
      result = celery_task.delay(0.0, event="failure")
      # result.get() re-raises the task's exception in the caller. There is
      # nothing to assert on its return value, and ``match`` ensures the error
      # is the task's own rather than an unrelated ValueError.
      with pytest.raises(ValueError, match="Unknown event"):
          result.get(timeout=10)

      (log,) = task_logger.gateway.filter([])
      assert log["state"] == "FAILURE"
      assert log["result"]["traceback"].startswith("Traceback")
  @pytest.fixture
  def custom_context():
      """Set a fixed correlation id and tenant on ``ctx`` for one test.

      Yields the ``ctx`` object itself. On teardown, both attributes are set to
      ``None`` rather than restored to any earlier value.
      """
      fixed_correlation_id = UUID("b3089ea7-2585-43e5-a63c-ae30a6e9b5e4")
      fixed_tenant = Tenant(id=2, name="custom")

      ctx.correlation_id = fixed_correlation_id
      ctx.tenant = fixed_tenant
      yield ctx

      ctx.correlation_id = None
      ctx.tenant = None
  def test_context(celery_task: BaseTask, task_logger: CeleryTaskLogger, custom_context):
      """The tenant and correlation id set at submit time are seen by the task.

      The task is scheduled with a 1 s countdown, and the caller's context is
      cleared immediately afterwards. The expected values therefore presumably
      travel with the task message rather than being read from the caller's
      context at execution time.
      """
      expected_correlation_id = "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"
      result = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0)

      # Wipe the submitting side's context before the task gets to run.
      custom_context.correlation_id = None
      custom_context.tenant = None

      observed = result.get(timeout=10)
      assert observed == {"tenant": 2, "correlation_id": expected_correlation_id}

      (entry,) = task_logger.gateway.filter([])
      assert entry["correlation_id"] == expected_correlation_id