123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- import json
- import time
- from uuid import UUID
- import pytest
- from celery.exceptions import Ignore
- from celery.exceptions import Reject
- from clean_python import ctx
- from clean_python import InMemorySyncGateway
- from clean_python import Tenant
- from clean_python.celery import BaseTask
- from clean_python.celery import CeleryTaskLogger
- from clean_python.celery import set_task_logger
@pytest.fixture(scope="session")
def celery_parameters():
    """Return keyword parameters for the Celery test app.

    ``task_cls`` makes ``BaseTask`` the app's default task class. ``strict_typing``
    is switched off (presumably relaxing Celery's call-time argument checking —
    confirm against the Celery version in use).
    """
    params = dict(task_cls=BaseTask)
    params["strict_typing"] = False
    return params
@pytest.fixture(scope="session")
def celery_worker_parameters():
    """Return keyword parameters for the embedded test worker (10 s shutdown timeout)."""
    return dict(shutdown_timeout=10)
@pytest.fixture
def celery_task(celery_app, celery_worker):
    """Register a configurable task named ``"testing"`` and return it.

    The task's behaviour is chosen by its case-insensitive ``event`` argument:

    * ``"success"`` (default): sleep ``seconds`` and return ``{"value": return_value}``.
    * ``"crash"``: read from address 0 via ``ctypes`` to trigger a segfault.
    * ``"ignore"``: raise ``celery.exceptions.Ignore``.
    * ``"reject"``: raise ``celery.exceptions.Reject``.
    * ``"retry"``: raise ``self.retry(countdown=seconds, max_retries=1)``.
    * ``"context"``: return ``ctx.tenant.id`` and ``str(ctx.correlation_id)`` as
      seen inside the task (raises ``AttributeError`` if no tenant is set).
    * anything else: raise ``ValueError("Unknown event '<event>'")``.

    ``celery_worker.reload()`` is called after registration — presumably so the
    already-running test worker picks up the newly registered task.
    """

    @celery_app.task(bind=True, base=BaseTask, name="testing")
    def sleep_task(self: BaseTask, seconds: float, return_value=None, event="success"):
        event = event.lower()
        if event == "success":
            # Pass the float through unchanged: int() would truncate e.g.
            # 0.5 s to 0 s and silently skip fractional sleeps.
            time.sleep(seconds)
        elif event == "crash":
            import ctypes

            ctypes.string_at(0)  # segfault
        elif event == "ignore":
            raise Ignore()
        elif event == "reject":
            raise Reject()
        elif event == "retry":
            raise self.retry(countdown=seconds, max_retries=1)
        elif event == "context":
            return {
                "tenant_id": ctx.tenant.id,
                "correlation_id": str(ctx.correlation_id),
            }
        else:
            raise ValueError(f"Unknown event '{event}'")
        return {"value": return_value}

    celery_worker.reload()
    return sleep_task
@pytest.fixture
def task_logger():
    """Install a ``CeleryTaskLogger`` backed by an empty in-memory gateway.

    The installed logger is yielded so tests can inspect ``task_logger.gateway``;
    on teardown the global task logger is reset to ``None``.
    """
    gateway = InMemorySyncGateway([])
    installed = CeleryTaskLogger(gateway)
    set_task_logger(installed)
    yield installed
    set_task_logger(None)
def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger):
    """A successful task writes exactly one fully-populated ``task_log`` entry."""
    async_result = celery_task.delay(0.0, return_value=16)
    assert async_result.get(timeout=10) == {"value": 16}

    (entry,) = task_logger.gateway.filter([])

    # The entry's timestamp lies within the last second.
    age = time.time() - entry["time"]
    assert 0.0 < age < 1.0

    # Task identity and outcome.
    assert entry["tag_suffix"] == "task_log"
    assert entry["task_id"] == async_result.id
    assert entry["state"] == "SUCCESS"
    assert entry["name"] == "testing"
    assert entry["duration"] > 0.0

    # Call arguments are stored as JSON strings.
    assert json.loads(entry["argsrepr"]) == [0.0]
    assert json.loads(entry["kwargsrepr"]) == {"return_value": 16}

    assert entry["retries"] == 0
    assert entry["result"] == {"value": 16}

    # No correlation id was supplied, so one must have been generated;
    # UUID() raises if the stored value is not a valid UUID string.
    assert UUID(entry["correlation_id"])
    assert entry["tenant_id"] is None
def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger):
    """A task that raises is logged as ``FAILURE`` together with its traceback.

    ``event="failure"`` is not one of the task's recognised events, so the task
    raises ``ValueError("Unknown event 'failure'")``, which ``result.get``
    is expected to re-raise in the test process.
    """
    result = celery_task.delay(0.0, event="failure")
    # get() must raise, so its return value is never inspected; ``match`` ensures
    # it is the task's own error and not an unrelated ValueError.
    with pytest.raises(ValueError, match="Unknown event"):
        result.get(timeout=10)
    (log,) = task_logger.gateway.filter([])
    assert log["state"] == "FAILURE"
    assert log["result"]["traceback"].startswith("Traceback")
@pytest.fixture
def custom_context():
    """Set a fixed correlation id and tenant (id 2) on ``ctx`` for one test.

    Yields ``ctx``. On teardown both attributes are reset to ``None``; any
    values present before the fixture ran are not restored.
    """
    ctx.correlation_id = UUID("b3089ea7-2585-43e5-a63c-ae30a6e9b5e4")
    ctx.tenant = Tenant(id=2, name="custom")
    yield ctx
    for attr in ("correlation_id", "tenant"):
        setattr(ctx, attr, None)
def test_context(celery_task: BaseTask, custom_context, task_logger):
    """The caller's tenant and correlation id are visible in the task and its log entry."""
    expected_correlation_id = "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"
    expected_tenant_id = 2

    # countdown delays execution, so the task does not run inline with the call.
    pending = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0)
    assert pending.get(timeout=10) == {
        "tenant_id": expected_tenant_id,
        "correlation_id": expected_correlation_id,
    }

    (entry,) = task_logger.gateway.filter([])
    assert entry["correlation_id"] == expected_correlation_id
    assert entry["tenant_id"] == expected_tenant_id
|