Kaynağa Gözat

Skip health check access logging (#34)

Casper van der Wel 1 yıl önce
ebeveyn
işleme
6e8845ee09

+ 3 - 1
CHANGES.md

@@ -4,7 +4,9 @@
 0.8.2 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Skip health check access logs.
+
+- Fix access logging of correlation id.
 
 
 0.8.1 (2023-11-06)

+ 43 - 11
clean_python/fastapi/fastapi_access_logger.py

@@ -6,17 +6,51 @@ from typing import Awaitable
 from typing import Callable
 from typing import Optional
 from uuid import UUID
+from uuid import uuid4
 
 import inject
 from starlette.background import BackgroundTasks
 from starlette.requests import Request
 from starlette.responses import Response
 
-from clean_python import ctx
 from clean_python import Gateway
 from clean_python.fluentbit import FluentbitGateway
 
-__all__ = ["FastAPIAccessLogger"]
+__all__ = ["FastAPIAccessLogger", "get_correlation_id"]
+
+
+CORRELATION_ID_HEADER = b"x-correlation-id"
+
+
+def get_view_name(request: Request) -> Optional[str]:
+    try:
+        view_name = request.scope["route"].name
+    except KeyError:
+        return None
+
+    return view_name
+
+
+def is_health_check(request: Request) -> bool:
+    return get_view_name(request) == "health_check"
+
+
+def get_correlation_id(request: Request) -> Optional[UUID]:
+    headers = dict(request.scope["headers"])
+    try:
+        return UUID(headers[CORRELATION_ID_HEADER].decode())
+    except (KeyError, ValueError, UnicodeDecodeError):
+        return None
+
+
+def ensure_correlation_id(request: Request) -> None:
+    correlation_id = get_correlation_id(request)
+    if correlation_id is None:
+        # generate an id and update the request inplace
+        correlation_id = uuid4()
+        headers = dict(request.scope["headers"])
+        headers[CORRELATION_ID_HEADER] = str(correlation_id).encode()
+        request.scope["headers"] = list(headers.items())
 
 
 class FastAPIAccessLogger:
@@ -31,6 +65,11 @@ class FastAPIAccessLogger:
     async def __call__(
         self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
     ) -> Response:
+        if request.scope["type"] != "http" or is_health_check(request):
+            return await call_next(request)
+
+        ensure_correlation_id(request)
+
         time_received = time.time()
         response = await call_next(request)
         request_time = time.time() - time_received
@@ -46,7 +85,6 @@ class FastAPIAccessLogger:
             response,
             time_received,
             request_time,
-            ctx.correlation_id,
         )
         return response
 
@@ -57,7 +95,6 @@ async def log_access(
     response: Response,
     time_received: float,
     request_time: float,
-    correlation_id: Optional[UUID] = None,
 ) -> None:
     """
     Create a dictionary with logging data.
@@ -67,11 +104,6 @@ async def log_access(
     except (TypeError, ValueError):
         content_length = None
 
-    try:
-        view_name = request.scope["route"].name
-    except KeyError:
-        view_name = None
-
     item = {
         "tag_suffix": "access_log",
         "remote_address": getattr(request.client, "host", None),
@@ -81,12 +113,12 @@ async def log_access(
         "referer": request.headers.get("referer"),
         "user_agent": request.headers.get("user-agent"),
         "query_params": request.url.query,
-        "view_name": view_name,
+        "view_name": get_view_name(request),
         "status": response.status_code,
         "content_type": response.headers.get("content-type"),
         "content_length": content_length,
         "time": time_received,
         "request_time": request_time,
-        "correlation_id": str(correlation_id) if correlation_id else None,
+        "correlation_id": str(get_correlation_id(request)),
     }
     await gateway.add(item)

+ 2 - 5
clean_python/fastapi/service.py

@@ -6,12 +6,9 @@ from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Set
-from uuid import UUID
-from uuid import uuid4
 
 from fastapi import Depends
 from fastapi import FastAPI
-from fastapi import Header
 from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from starlette.types import ASGIApp
@@ -36,6 +33,7 @@ from .error_responses import unauthorized_handler
 from .error_responses import validation_error_handler
 from .error_responses import ValidationErrorResponse
 from .fastapi_access_logger import FastAPIAccessLogger
+from .fastapi_access_logger import get_correlation_id
 from .resource import APIVersion
 from .resource import clean_resources
 from .resource import Resource
@@ -68,12 +66,11 @@ def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> Dict[str,
 async def set_context(
     request: Request,
     token: Token = Depends(get_token),
-    x_correlation_id: UUID = Header(default_factory=uuid4),
 ) -> None:
     ctx.path = request.url
     ctx.user = token.user
     ctx.tenant = token.tenant
-    ctx.correlation_id = x_correlation_id
+    ctx.correlation_id = get_correlation_id(request)
 
 
 async def health_check():

+ 5 - 0
docker-compose.yaml

@@ -18,3 +18,8 @@ services:
       MINIO_ROOT_PASSWORD: cleanpython
     ports:
       - "9000:9000"
+
+  fluentbit:
+    image: fluent/fluent-bit:1.9
+    ports:
+      - "24224:24224"

+ 63 - 16
tests/fastapi/test_fastapi_access_logger.py

@@ -7,9 +7,11 @@ from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.responses import StreamingResponse
 
-from clean_python import ctx
 from clean_python import InMemoryGateway
 from clean_python.fastapi import FastAPIAccessLogger
+from clean_python.fastapi import get_correlation_id
+
+SOME_UUID = uuid4()
 
 
 @pytest.fixture
@@ -38,6 +40,7 @@ def req():
             (b"accept-encoding", b"gzip, deflate, br"),
             (b"accept-language", b"en-US,en;q=0.9"),
             (b"cookie", b"..."),
+            (b"x-correlation-id", str(SOME_UUID).encode()),
         ],
         "state": {},
         "method": "GET",
@@ -64,23 +67,14 @@ def response():
 @pytest.fixture
 def call_next(response):
     async def func(request):
+        assert get_correlation_id(request) == SOME_UUID
         return response
 
     return func
 
 
-@pytest.fixture
-def correlation_id():
-    uid = uuid4()
-    ctx.correlation_id = uid
-    yield uid
-    ctx.correlation_id = None
-
-
 @mock.patch("time.time", return_value=0.0)
-async def test_logging(
-    time, fastapi_access_logger, req, response, call_next, correlation_id
-):
+async def test_logging(time, fastapi_access_logger, req, response, call_next):
     await fastapi_access_logger(req, call_next)
     assert len(fastapi_access_logger.gateway.data) == 0
     await response.background()
@@ -101,7 +95,7 @@ async def test_logging(
         "content_length": 13,
         "time": 0.0,
         "request_time": 0.0,
-        "correlation_id": str(correlation_id),
+        "correlation_id": str(SOME_UUID),
     }
 
 
@@ -116,7 +110,7 @@ def req_minimal():
         "scheme": "http",
         "path": "/",
         "query_string": "",
-        "headers": [],
+        "headers": [(b"abc", b"def")],
     }
     return Request(scope)
 
@@ -135,16 +129,27 @@ def streaming_response():
 @pytest.fixture
 def call_next_streaming(streaming_response):
     async def func(request):
+        assert get_correlation_id(request) == SOME_UUID
         return streaming_response
 
     return func
 
 
 @mock.patch("time.time", return_value=0.0)
+@mock.patch("clean_python.fastapi.fastapi_access_logger.uuid4", return_value=SOME_UUID)
 async def test_logging_minimal(
-    time, fastapi_access_logger, req_minimal, streaming_response, call_next_streaming
+    time,
+    uuid4,
+    fastapi_access_logger,
+    req_minimal,
+    streaming_response,
+    call_next_streaming,
 ):
     await fastapi_access_logger(req_minimal, call_next_streaming)
+    assert req_minimal["headers"] == [
+        (b"abc", b"def"),
+        (b"x-correlation-id", str(SOME_UUID).encode()),
+    ]
     assert len(fastapi_access_logger.gateway.data) == 0
     await streaming_response.background()
     (actual,) = fastapi_access_logger.gateway.data.values()
@@ -164,5 +169,47 @@ async def test_logging_minimal(
         "content_length": None,
         "time": 0.0,
         "request_time": 0.0,
-        "correlation_id": None,
+        "correlation_id": str(SOME_UUID),
     }
+
+
+@pytest.fixture
+def req_health():
+    scope = {
+        "type": "http",
+        "asgi": {"version": "3.0"},
+        "http_version": "1.1",
+        "method": "GET",
+        "scheme": "http",
+        "path": "/",
+        "query_string": "",
+        "headers": [],
+        "route": APIRoute(
+            endpoint=lambda x: x,
+            path="/health",
+            name="health_check",
+            methods=["GET"],
+        ),
+    }
+    return Request(scope)
+
+
+@pytest.fixture
+def call_next_no_correlation_id(response):
+    async def func(request):
+        assert get_correlation_id(request) is None
+        return response
+
+    return func
+
+
+@mock.patch("time.time", return_value=0.0)
+async def test_logging_health_check_skipped(
+    time,
+    fastapi_access_logger,
+    req_health,
+    streaming_response,
+    call_next_no_correlation_id,
+):
+    await fastapi_access_logger(req_health, call_next_no_correlation_id)
+    assert streaming_response.background is None