瀏覽代碼

Skip health check access logging (#34)

Casper van der Wel 1 年之前
父節點
當前提交
6e8845ee09

+ 3 - 1
CHANGES.md

@@ -4,7 +4,9 @@
 0.8.2 (unreleased)
 0.8.2 (unreleased)
 ------------------
 ------------------
 
 
-- Nothing changed yet.
+- Skip health check access logs.
+
+- Fix access logging of correlation id.
 
 
 
 
 0.8.1 (2023-11-06)
 0.8.1 (2023-11-06)

+ 43 - 11
clean_python/fastapi/fastapi_access_logger.py

@@ -6,17 +6,51 @@ from typing import Awaitable
 from typing import Callable
 from typing import Callable
 from typing import Optional
 from typing import Optional
 from uuid import UUID
 from uuid import UUID
+from uuid import uuid4
 
 
 import inject
 import inject
 from starlette.background import BackgroundTasks
 from starlette.background import BackgroundTasks
 from starlette.requests import Request
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.responses import Response
 
 
-from clean_python import ctx
 from clean_python import Gateway
 from clean_python import Gateway
 from clean_python.fluentbit import FluentbitGateway
 from clean_python.fluentbit import FluentbitGateway
 
 
-__all__ = ["FastAPIAccessLogger"]
+__all__ = ["FastAPIAccessLogger", "get_correlation_id"]
+
+
+CORRELATION_ID_HEADER = b"x-correlation-id"
+
+
+def get_view_name(request: Request) -> Optional[str]:
+    try:
+        view_name = request.scope["route"].name
+    except KeyError:
+        return None
+
+    return view_name
+
+
+def is_health_check(request: Request) -> bool:
+    return get_view_name(request) == "health_check"
+
+
+def get_correlation_id(request: Request) -> Optional[UUID]:
+    headers = dict(request.scope["headers"])
+    try:
+        return UUID(headers[CORRELATION_ID_HEADER].decode())
+    except (KeyError, ValueError, UnicodeDecodeError):
+        return None
+
+
+def ensure_correlation_id(request: Request) -> None:
+    correlation_id = get_correlation_id(request)
+    if correlation_id is None:
+        # generate an id and update the request inplace
+        correlation_id = uuid4()
+        headers = dict(request.scope["headers"])
+        headers[CORRELATION_ID_HEADER] = str(correlation_id).encode()
+        request.scope["headers"] = list(headers.items())
 
 
 
 
 class FastAPIAccessLogger:
 class FastAPIAccessLogger:
@@ -31,6 +65,11 @@ class FastAPIAccessLogger:
     async def __call__(
     async def __call__(
         self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
         self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
     ) -> Response:
     ) -> Response:
+        if request.scope["type"] != "http" or is_health_check(request):
+            return await call_next(request)
+
+        ensure_correlation_id(request)
+
         time_received = time.time()
         time_received = time.time()
         response = await call_next(request)
         response = await call_next(request)
         request_time = time.time() - time_received
         request_time = time.time() - time_received
@@ -46,7 +85,6 @@ class FastAPIAccessLogger:
             response,
             response,
             time_received,
             time_received,
             request_time,
             request_time,
-            ctx.correlation_id,
         )
         )
         return response
         return response
 
 
@@ -57,7 +95,6 @@ async def log_access(
     response: Response,
     response: Response,
     time_received: float,
     time_received: float,
     request_time: float,
     request_time: float,
-    correlation_id: Optional[UUID] = None,
 ) -> None:
 ) -> None:
     """
     """
     Create a dictionary with logging data.
     Create a dictionary with logging data.
@@ -67,11 +104,6 @@ async def log_access(
     except (TypeError, ValueError):
     except (TypeError, ValueError):
         content_length = None
         content_length = None
 
 
-    try:
-        view_name = request.scope["route"].name
-    except KeyError:
-        view_name = None
-
     item = {
     item = {
         "tag_suffix": "access_log",
         "tag_suffix": "access_log",
         "remote_address": getattr(request.client, "host", None),
         "remote_address": getattr(request.client, "host", None),
@@ -81,12 +113,12 @@ async def log_access(
         "referer": request.headers.get("referer"),
         "referer": request.headers.get("referer"),
         "user_agent": request.headers.get("user-agent"),
         "user_agent": request.headers.get("user-agent"),
         "query_params": request.url.query,
         "query_params": request.url.query,
-        "view_name": view_name,
+        "view_name": get_view_name(request),
         "status": response.status_code,
         "status": response.status_code,
         "content_type": response.headers.get("content-type"),
         "content_type": response.headers.get("content-type"),
         "content_length": content_length,
         "content_length": content_length,
         "time": time_received,
         "time": time_received,
         "request_time": request_time,
         "request_time": request_time,
-        "correlation_id": str(correlation_id) if correlation_id else None,
+        "correlation_id": str(get_correlation_id(request)),
     }
     }
     await gateway.add(item)
     await gateway.add(item)

+ 2 - 5
clean_python/fastapi/service.py

@@ -6,12 +6,9 @@ from typing import Dict
 from typing import List
 from typing import List
 from typing import Optional
 from typing import Optional
 from typing import Set
 from typing import Set
-from uuid import UUID
-from uuid import uuid4
 
 
 from fastapi import Depends
 from fastapi import Depends
 from fastapi import FastAPI
 from fastapi import FastAPI
-from fastapi import Header
 from fastapi import Request
 from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.exceptions import RequestValidationError
 from starlette.types import ASGIApp
 from starlette.types import ASGIApp
@@ -36,6 +33,7 @@ from .error_responses import unauthorized_handler
 from .error_responses import validation_error_handler
 from .error_responses import validation_error_handler
 from .error_responses import ValidationErrorResponse
 from .error_responses import ValidationErrorResponse
 from .fastapi_access_logger import FastAPIAccessLogger
 from .fastapi_access_logger import FastAPIAccessLogger
+from .fastapi_access_logger import get_correlation_id
 from .resource import APIVersion
 from .resource import APIVersion
 from .resource import clean_resources
 from .resource import clean_resources
 from .resource import Resource
 from .resource import Resource
@@ -68,12 +66,11 @@ def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> Dict[str,
 async def set_context(
 async def set_context(
     request: Request,
     request: Request,
     token: Token = Depends(get_token),
     token: Token = Depends(get_token),
-    x_correlation_id: UUID = Header(default_factory=uuid4),
 ) -> None:
 ) -> None:
     ctx.path = request.url
     ctx.path = request.url
     ctx.user = token.user
     ctx.user = token.user
     ctx.tenant = token.tenant
     ctx.tenant = token.tenant
-    ctx.correlation_id = x_correlation_id
+    ctx.correlation_id = get_correlation_id(request)
 
 
 
 
 async def health_check():
 async def health_check():

+ 5 - 0
docker-compose.yaml

@@ -18,3 +18,8 @@ services:
       MINIO_ROOT_PASSWORD: cleanpython
       MINIO_ROOT_PASSWORD: cleanpython
     ports:
     ports:
       - "9000:9000"
       - "9000:9000"
+
+  fluentbit:
+    image: fluent/fluent-bit:1.9
+    ports:
+      - "24224:24224"

+ 63 - 16
tests/fastapi/test_fastapi_access_logger.py

@@ -7,9 +7,11 @@ from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.responses import JSONResponse
 from starlette.responses import StreamingResponse
 from starlette.responses import StreamingResponse
 
 
-from clean_python import ctx
 from clean_python import InMemoryGateway
 from clean_python import InMemoryGateway
 from clean_python.fastapi import FastAPIAccessLogger
 from clean_python.fastapi import FastAPIAccessLogger
+from clean_python.fastapi import get_correlation_id
+
+SOME_UUID = uuid4()
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -38,6 +40,7 @@ def req():
             (b"accept-encoding", b"gzip, deflate, br"),
             (b"accept-encoding", b"gzip, deflate, br"),
             (b"accept-language", b"en-US,en;q=0.9"),
             (b"accept-language", b"en-US,en;q=0.9"),
             (b"cookie", b"..."),
             (b"cookie", b"..."),
+            (b"x-correlation-id", str(SOME_UUID).encode()),
         ],
         ],
         "state": {},
         "state": {},
         "method": "GET",
         "method": "GET",
@@ -64,23 +67,14 @@ def response():
 @pytest.fixture
 @pytest.fixture
 def call_next(response):
 def call_next(response):
     async def func(request):
     async def func(request):
+        assert get_correlation_id(request) == SOME_UUID
         return response
         return response
 
 
     return func
     return func
 
 
 
 
-@pytest.fixture
-def correlation_id():
-    uid = uuid4()
-    ctx.correlation_id = uid
-    yield uid
-    ctx.correlation_id = None
-
-
 @mock.patch("time.time", return_value=0.0)
 @mock.patch("time.time", return_value=0.0)
-async def test_logging(
-    time, fastapi_access_logger, req, response, call_next, correlation_id
-):
+async def test_logging(time, fastapi_access_logger, req, response, call_next):
     await fastapi_access_logger(req, call_next)
     await fastapi_access_logger(req, call_next)
     assert len(fastapi_access_logger.gateway.data) == 0
     assert len(fastapi_access_logger.gateway.data) == 0
     await response.background()
     await response.background()
@@ -101,7 +95,7 @@ async def test_logging(
         "content_length": 13,
         "content_length": 13,
         "time": 0.0,
         "time": 0.0,
         "request_time": 0.0,
         "request_time": 0.0,
-        "correlation_id": str(correlation_id),
+        "correlation_id": str(SOME_UUID),
     }
     }
 
 
 
 
@@ -116,7 +110,7 @@ def req_minimal():
         "scheme": "http",
         "scheme": "http",
         "path": "/",
         "path": "/",
         "query_string": "",
         "query_string": "",
-        "headers": [],
+        "headers": [(b"abc", b"def")],
     }
     }
     return Request(scope)
     return Request(scope)
 
 
@@ -135,16 +129,27 @@ def streaming_response():
 @pytest.fixture
 @pytest.fixture
 def call_next_streaming(streaming_response):
 def call_next_streaming(streaming_response):
     async def func(request):
     async def func(request):
+        assert get_correlation_id(request) == SOME_UUID
         return streaming_response
         return streaming_response
 
 
     return func
     return func
 
 
 
 
 @mock.patch("time.time", return_value=0.0)
 @mock.patch("time.time", return_value=0.0)
+@mock.patch("clean_python.fastapi.fastapi_access_logger.uuid4", return_value=SOME_UUID)
 async def test_logging_minimal(
 async def test_logging_minimal(
-    time, fastapi_access_logger, req_minimal, streaming_response, call_next_streaming
+    time,
+    uuid4,
+    fastapi_access_logger,
+    req_minimal,
+    streaming_response,
+    call_next_streaming,
 ):
 ):
     await fastapi_access_logger(req_minimal, call_next_streaming)
     await fastapi_access_logger(req_minimal, call_next_streaming)
+    assert req_minimal["headers"] == [
+        (b"abc", b"def"),
+        (b"x-correlation-id", str(SOME_UUID).encode()),
+    ]
     assert len(fastapi_access_logger.gateway.data) == 0
     assert len(fastapi_access_logger.gateway.data) == 0
     await streaming_response.background()
     await streaming_response.background()
     (actual,) = fastapi_access_logger.gateway.data.values()
     (actual,) = fastapi_access_logger.gateway.data.values()
@@ -164,5 +169,47 @@ async def test_logging_minimal(
         "content_length": None,
         "content_length": None,
         "time": 0.0,
         "time": 0.0,
         "request_time": 0.0,
         "request_time": 0.0,
-        "correlation_id": None,
+        "correlation_id": str(SOME_UUID),
     }
     }
+
+
+@pytest.fixture
+def req_health():
+    scope = {
+        "type": "http",
+        "asgi": {"version": "3.0"},
+        "http_version": "1.1",
+        "method": "GET",
+        "scheme": "http",
+        "path": "/",
+        "query_string": "",
+        "headers": [],
+        "route": APIRoute(
+            endpoint=lambda x: x,
+            path="/health",
+            name="health_check",
+            methods=["GET"],
+        ),
+    }
+    return Request(scope)
+
+
+@pytest.fixture
+def call_next_no_correlation_id(response):
+    async def func(request):
+        assert get_correlation_id(request) is None
+        return response
+
+    return func
+
+
+@mock.patch("time.time", return_value=0.0)
+async def test_logging_health_check_skipped(
+    time,
+    fastapi_access_logger,
+    req_health,
+    streaming_response,
+    call_next_no_correlation_id,
+):
+    await fastapi_access_logger(req_health, call_next_no_correlation_id)
+    assert streaming_response.background is None