Переглянути джерело

Fix issues with logging contextvars (#31)

Casper van der Wel 1 рік тому
батько
коміт
cba405a1b1

+ 3 - 1
CHANGES.md

@@ -4,7 +4,9 @@
 ## 0.7.1 (unreleased)
 ---------------------
 
-- Nothing changed yet.
+- Automatically dump and restore correlation_id in dramatiq actors.
+
+- Fixed logging of correlation_id in fastapi access logger.
 
 
 ## 0.7.0 (2023-11-01)

+ 13 - 1
clean_python/dramatiq/dramatiq_task_logger.py

@@ -4,13 +4,15 @@ import os
 import threading
 import time
 from typing import Optional
+from uuid import UUID
+from uuid import uuid4
 
 import inject
 from dramatiq import get_encoder
+from dramatiq import Message
 from dramatiq import Middleware
 from dramatiq.errors import RateLimitExceeded
 from dramatiq.errors import Retry
-from dramatiq.message import Message
 from dramatiq.middleware import SkipMessage
 
 from clean_python import ctx
@@ -24,14 +26,24 @@ class AsyncLoggingMiddleware(Middleware):
     def __init__(self, **kwargs):
         self.logger = DramatiqTaskLogger(**kwargs)
 
+    def before_enqueue(self, broker, message: Message, delay):
+        if ctx.correlation_id is not None:
+            message.options["correlation_id"] = str(ctx.correlation_id)
+
     def before_process_message(self, broker, message):
+        if message.options.get("correlation_id") is not None:
+            ctx.correlation_id = UUID(message.options["correlation_id"])
+        else:
+            ctx.correlation_id = uuid4()
         broker.run_coroutine(self.logger.start())
 
     def after_skip_message(self, broker, message):
         broker.run_coroutine(self.logger.stop(message, None, SkipMessage()))
+        ctx.correlation_id = None
 
     def after_process_message(self, broker, message, *, result=None, exception=None):
         broker.run_coroutine(self.logger.stop(message, result, exception))
+        ctx.correlation_id = None
 
 
 class DramatiqTaskLogger:

+ 10 - 2
clean_python/fastapi/fastapi_access_logger.py

@@ -5,6 +5,7 @@ import time
 from typing import Awaitable
 from typing import Callable
 from typing import Optional
+from uuid import UUID
 
 import inject
 from starlette.background import BackgroundTasks
@@ -39,7 +40,13 @@ class FastAPIAccessLogger:
         if response.background is None:
             response.background = BackgroundTasks()
         response.background.add_task(
-            log_access, self.gateway, request, response, time_received, request_time
+            log_access,
+            self.gateway,
+            request,
+            response,
+            time_received,
+            request_time,
+            ctx.correlation_id,
         )
         return response
 
@@ -50,6 +57,7 @@ async def log_access(
     response: Response,
     time_received: float,
     request_time: float,
+    correlation_id: Optional[UUID] = None,
 ) -> None:
     """
     Create a dictionary with logging data.
@@ -79,6 +87,6 @@ async def log_access(
         "content_length": content_length,
         "time": time_received,
         "request_time": request_time,
-        "correlation_id": str(ctx.correlation_id) if ctx.correlation_id else None,
+        "correlation_id": str(correlation_id) if correlation_id else None,
     }
     await gateway.add(item)

+ 13 - 0
tests/test_async_actor.py

@@ -1,5 +1,6 @@
 import threading
 from asyncio import BaseEventLoop
+from contextvars import ContextVar
 from unittest import mock
 
 import pytest
@@ -104,3 +105,15 @@ def test_async_actor():
 
     # no recursion errors here:
     repr(foo)
+
+
+foo_var: ContextVar[int] = ContextVar("foo", default=42)
+
+
+def test_run_coroutine_keeps_context(started_thread: EventLoopThread):
+    async def return_foo_var():
+        return foo_var.get()
+
+    foo_var.set(31)
+
+    assert started_thread.run_coroutine(return_foo_var()) == 31