# (c) Nelen & Schuurmans

import os
import time
from typing import Awaitable
from typing import Callable
from typing import Optional
from uuid import UUID
from uuid import uuid4

import inject
from starlette.background import BackgroundTasks
from starlette.requests import Request
from starlette.responses import Response

from clean_python import Gateway
from clean_python.fluentbit import FluentbitGateway

__all__ = ["FastAPIAccessLogger", "get_correlation_id"]


# Raw ASGI header names are compared as lowercase bytes.
CORRELATION_ID_HEADER = b"x-correlation-id"


def get_view_name(request: Request) -> Optional[str]:
    """Return the ``name`` of the route object stored in ``scope["route"]``.

    Returns ``None`` when the scope has no ``"route"`` entry, or when that
    entry is ``None`` or has no ``name`` attribute.
    """
    route = request.scope.get("route")
    return getattr(route, "name", None)


def is_health_check(request: Request) -> bool:
    """Return True when ``request`` is routed to the view named ``health_check``."""
    return get_view_name(request) == "health_check"


def get_correlation_id(request: Request) -> Optional[UUID]:
    """Parse the ``x-correlation-id`` request header as a UUID.

    Returns ``None`` when the header is absent, is not valid UTF-8, or does
    not contain a valid UUID.
    """
    headers = dict(request.scope["headers"])
    try:
        return UUID(headers[CORRELATION_ID_HEADER].decode())
    except (KeyError, ValueError, UnicodeDecodeError):
        return None


def ensure_correlation_id(request: Request) -> None:
    """Make sure ``request`` carries a valid correlation id header.

    If the header is missing or does not parse as a UUID, a fresh UUID4 is
    generated and written into ``request.scope["headers"]`` in place. Any
    existing (invalid) correlation id header is replaced; every other header
    is kept, including repeated ones.
    """
    if get_correlation_id(request) is not None:
        return
    correlation_id = uuid4()
    # Rebuild the raw header list rather than round-tripping through a dict:
    # ASGI allows a header name to occur multiple times, and a dict would
    # silently keep only the last occurrence of each.
    headers = [
        (name, value)
        for name, value in request.scope["headers"]
        if name != CORRELATION_ID_HEADER
    ]
    headers.append((CORRELATION_ID_HEADER, str(correlation_id).encode()))
    request.scope["headers"] = headers


class FastAPIAccessLogger:
    """HTTP middleware callable that emits one access-log record per request.

    Instances are called with ``(request, call_next)``. Non-``http`` scopes
    and health-check requests are passed straight through. For all other
    requests a correlation id is ensured, the downstream app is timed, and
    ``log_access`` is scheduled as a background task on the response so it
    runs after the response has been sent.
    """

    def __init__(self, hostname: str, gateway_override: Optional[Gateway] = None):
        # NOTE(review): ``origin`` is not read anywhere in this module;
        # presumably consumed elsewhere — confirm.
        self.origin = f"{hostname}-{os.getpid()}"
        self.gateway_override = gateway_override

    @property
    def gateway(self) -> Gateway:
        """The log gateway: the override if given, else the injected FluentbitGateway."""
        return self.gateway_override or inject.instance(FluentbitGateway)

    async def __call__(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        if request.scope["type"] != "http" or is_health_check(request):
            return await call_next(request)

        ensure_correlation_id(request)

        # Wall-clock timestamp for the record's "time" field; a monotonic
        # clock for the duration so clock adjustments cannot skew it.
        time_received = time.time()
        started = time.perf_counter()
        response = await call_next(request)
        request_time = time.perf_counter() - started

        # Instead of logging directly, set it as background task so that it is
        # executed after the response. See https://www.starlette.io/background/.
        background = response.background
        if not isinstance(background, BackgroundTasks):
            # A plain BackgroundTask has no ``add_task``; wrap it so that it
            # still runs (before ours) instead of raising AttributeError.
            background = BackgroundTasks(
                tasks=[background] if background is not None else None
            )
            response.background = background
        background.add_task(
            log_access,
            self.gateway,
            request,
            response,
            time_received,
            request_time,
        )
        return response


async def log_access(
    gateway: Gateway,
    request: Request,
    response: Response,
    time_received: float,
    request_time: float,
) -> None:
    """Build an access-log record for a handled request and send it to ``gateway``.

    Args:
        gateway: destination of the record; ``gateway.add(item)`` is awaited.
        request: the handled request.
        response: the response returned for ``request``.
        time_received: timestamp stored as the record's ``time`` field
            (``FastAPIAccessLogger`` passes a ``time.time()`` value).
        request_time: duration stored as ``request_time``
            (``FastAPIAccessLogger`` passes seconds spent in the downstream app).
    """
    try:
        content_length: Optional[int] = int(response.headers.get("content-length"))
    except (TypeError, ValueError):
        # Header absent (int(None) -> TypeError) or not an integer.
        content_length = None

    correlation_id = get_correlation_id(request)
    item = {
        "tag_suffix": "access_log",
        "remote_address": getattr(request.client, "host", None),
        "method": request.method,
        "path": request.url.path,
        "portal": request.url.netloc,
        "referer": request.headers.get("referer"),
        "user_agent": request.headers.get("user-agent"),
        "query_params": request.url.query,
        "view_name": get_view_name(request),
        "status": response.status_code,
        "content_type": response.headers.get("content-type"),
        "content_length": content_length,
        "time": time_received,
        "request_time": request_time,
        # Log a real null rather than the string "None" if no id parses.
        "correlation_id": None if correlation_id is None else str(correlation_id),
    }
    await gateway.add(item)