test_int_celery.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import json
  2. import time
  3. from uuid import UUID
  4. import pytest
  5. from celery.exceptions import Ignore
  6. from celery.exceptions import Reject
  7. from clean_python import ctx
  8. from clean_python import InMemorySyncGateway
  9. from clean_python import Tenant
  10. from clean_python.celery import BaseTask
  11. from clean_python.celery import CeleryTaskLogger
  12. from clean_python.celery import set_task_logger
  13. @pytest.fixture(scope="session")
  14. def celery_parameters():
  15. return {"task_cls": BaseTask, "strict_typing": False}
  16. @pytest.fixture(scope="session")
  17. def celery_worker_parameters():
  18. return {"shutdown_timeout": 10}
  19. @pytest.fixture
  20. def celery_task(celery_app, celery_worker):
  21. @celery_app.task(bind=True, base=BaseTask, name="testing")
  22. def sleep_task(self: BaseTask, seconds: float, return_value=None, event="success"):
  23. event = event.lower()
  24. if event == "success":
  25. time.sleep(int(seconds))
  26. elif event == "crash":
  27. import ctypes
  28. ctypes.string_at(0) # segfault
  29. elif event == "ignore":
  30. raise Ignore()
  31. elif event == "reject":
  32. raise Reject()
  33. elif event == "retry":
  34. raise self.retry(countdown=seconds, max_retries=1)
  35. elif event == "context":
  36. return {
  37. "tenant_id": ctx.tenant.id,
  38. "correlation_id": str(ctx.correlation_id),
  39. }
  40. else:
  41. raise ValueError(f"Unknown event '{event}'")
  42. return {"value": return_value}
  43. celery_worker.reload()
  44. return sleep_task
  45. @pytest.fixture
  46. def task_logger():
  47. logger = CeleryTaskLogger(InMemorySyncGateway([]))
  48. set_task_logger(logger)
  49. yield logger
  50. set_task_logger(None)
  51. def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger):
  52. result = celery_task.delay(0.0, return_value=16)
  53. assert result.get(timeout=10) == {"value": 16}
  54. (log,) = task_logger.gateway.filter([])
  55. assert 0.0 < (time.time() - log["time"]) < 1.0
  56. assert log["tag_suffix"] == "task_log"
  57. assert log["task_id"] == result.id
  58. assert log["state"] == "SUCCESS"
  59. assert log["name"] == "testing"
  60. assert log["duration"] > 0.0
  61. assert json.loads(log["argsrepr"]) == [0.0]
  62. assert json.loads(log["kwargsrepr"]) == {"return_value": 16}
  63. assert log["retries"] == 0
  64. assert log["result"] == {"value": 16}
  65. assert UUID(log["correlation_id"]) # generated
  66. assert log["tenant_id"] is None
  67. def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger):
  68. result = celery_task.delay(0.0, event="failure")
  69. with pytest.raises(ValueError):
  70. assert result.get(timeout=10)
  71. (log,) = task_logger.gateway.filter([])
  72. assert log["state"] == "FAILURE"
  73. assert log["result"]["traceback"].startswith("Traceback")
  74. @pytest.fixture
  75. def custom_context():
  76. ctx.correlation_id = UUID("b3089ea7-2585-43e5-a63c-ae30a6e9b5e4")
  77. ctx.tenant = Tenant(id=2, name="custom")
  78. yield ctx
  79. ctx.correlation_id = None
  80. ctx.tenant = None
  81. def test_context(celery_task: BaseTask, custom_context, task_logger):
  82. result = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0)
  83. assert result.get(timeout=10) == {
  84. "tenant_id": 2,
  85. "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
  86. }
  87. (log,) = task_logger.gateway.filter([])
  88. assert log["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"
  89. assert log["tenant_id"] == 2