|
@@ -6,17 +6,51 @@ from typing import Awaitable
|
|
from typing import Callable
|
|
from typing import Callable
|
|
from typing import Optional
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
from uuid import UUID
|
|
|
|
+from uuid import uuid4
|
|
|
|
|
|
import inject
|
|
import inject
|
|
from starlette.background import BackgroundTasks
|
|
from starlette.background import BackgroundTasks
|
|
from starlette.requests import Request
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.responses import Response
|
|
|
|
|
|
-from clean_python import ctx
|
|
|
|
from clean_python import Gateway
|
|
from clean_python import Gateway
|
|
from clean_python.fluentbit import FluentbitGateway
|
|
from clean_python.fluentbit import FluentbitGateway
|
|
|
|
|
|
-__all__ = ["FastAPIAccessLogger"]
|
|
|
|
|
|
+__all__ = ["FastAPIAccessLogger", "get_correlation_id"]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+CORRELATION_ID_HEADER = b"x-correlation-id"
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_view_name(request: Request) -> Optional[str]:
|
|
|
|
+ try:
|
|
|
|
+ view_name = request.scope["route"].name
|
|
|
|
+ except KeyError:
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ return view_name
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def is_health_check(request: Request) -> bool:
|
|
|
|
+ return get_view_name(request) == "health_check"
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_correlation_id(request: Request) -> Optional[UUID]:
|
|
|
|
+ headers = dict(request.scope["headers"])
|
|
|
|
+ try:
|
|
|
|
+ return UUID(headers[CORRELATION_ID_HEADER].decode())
|
|
|
|
+ except (KeyError, ValueError, UnicodeDecodeError):
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def ensure_correlation_id(request: Request) -> None:
|
|
|
|
+ correlation_id = get_correlation_id(request)
|
|
|
|
+ if correlation_id is None:
|
|
|
|
+ # generate an id and update the request inplace
|
|
|
|
+ correlation_id = uuid4()
|
|
|
|
+ headers = dict(request.scope["headers"])
|
|
|
|
+ headers[CORRELATION_ID_HEADER] = str(correlation_id).encode()
|
|
|
|
+ request.scope["headers"] = list(headers.items())
|
|
|
|
|
|
|
|
|
|
class FastAPIAccessLogger:
|
|
class FastAPIAccessLogger:
|
|
@@ -31,6 +65,11 @@ class FastAPIAccessLogger:
|
|
async def __call__(
|
|
async def __call__(
|
|
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
|
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
|
) -> Response:
|
|
) -> Response:
|
|
|
|
+ if request.scope["type"] != "http" or is_health_check(request):
|
|
|
|
+ return await call_next(request)
|
|
|
|
+
|
|
|
|
+ ensure_correlation_id(request)
|
|
|
|
+
|
|
time_received = time.time()
|
|
time_received = time.time()
|
|
response = await call_next(request)
|
|
response = await call_next(request)
|
|
request_time = time.time() - time_received
|
|
request_time = time.time() - time_received
|
|
@@ -46,7 +85,6 @@ class FastAPIAccessLogger:
|
|
response,
|
|
response,
|
|
time_received,
|
|
time_received,
|
|
request_time,
|
|
request_time,
|
|
- ctx.correlation_id,
|
|
|
|
)
|
|
)
|
|
return response
|
|
return response
|
|
|
|
|
|
@@ -57,7 +95,6 @@ async def log_access(
|
|
response: Response,
|
|
response: Response,
|
|
time_received: float,
|
|
time_received: float,
|
|
request_time: float,
|
|
request_time: float,
|
|
- correlation_id: Optional[UUID] = None,
|
|
|
|
) -> None:
|
|
) -> None:
|
|
"""
|
|
"""
|
|
Create a dictionary with logging data.
|
|
Create a dictionary with logging data.
|
|
@@ -67,11 +104,6 @@ async def log_access(
|
|
except (TypeError, ValueError):
|
|
except (TypeError, ValueError):
|
|
content_length = None
|
|
content_length = None
|
|
|
|
|
|
- try:
|
|
|
|
- view_name = request.scope["route"].name
|
|
|
|
- except KeyError:
|
|
|
|
- view_name = None
|
|
|
|
-
|
|
|
|
item = {
|
|
item = {
|
|
"tag_suffix": "access_log",
|
|
"tag_suffix": "access_log",
|
|
"remote_address": getattr(request.client, "host", None),
|
|
"remote_address": getattr(request.client, "host", None),
|
|
@@ -81,12 +113,12 @@ async def log_access(
|
|
"referer": request.headers.get("referer"),
|
|
"referer": request.headers.get("referer"),
|
|
"user_agent": request.headers.get("user-agent"),
|
|
"user_agent": request.headers.get("user-agent"),
|
|
"query_params": request.url.query,
|
|
"query_params": request.url.query,
|
|
- "view_name": view_name,
|
|
|
|
|
|
+ "view_name": get_view_name(request),
|
|
"status": response.status_code,
|
|
"status": response.status_code,
|
|
"content_type": response.headers.get("content-type"),
|
|
"content_type": response.headers.get("content-type"),
|
|
"content_length": content_length,
|
|
"content_length": content_length,
|
|
"time": time_received,
|
|
"time": time_received,
|
|
"request_time": request_time,
|
|
"request_time": request_time,
|
|
- "correlation_id": str(correlation_id) if correlation_id else None,
|
|
|
|
|
|
+ "correlation_id": str(get_correlation_id(request)),
|
|
}
|
|
}
|
|
await gateway.add(item)
|
|
await gateway.add(item)
|