fastapi_access_logger.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # (c) Nelen & Schuurmans
  2. import os
  3. import time
  4. from typing import Awaitable
  5. from typing import Callable
  6. from typing import Optional
  7. from uuid import UUID
  8. import inject
  9. from starlette.background import BackgroundTasks
  10. from starlette.requests import Request
  11. from starlette.responses import Response
  12. from clean_python import ctx
  13. from clean_python import Gateway
  14. from clean_python.fluentbit import FluentbitGateway
  15. __all__ = ["FastAPIAccessLogger"]
  16. class FastAPIAccessLogger:
  17. def __init__(self, hostname: str, gateway_override: Optional[Gateway] = None):
  18. self.origin = f"{hostname}-{os.getpid()}"
  19. self.gateway_override = gateway_override
  20. @property
  21. def gateway(self) -> Gateway:
  22. return self.gateway_override or inject.instance(FluentbitGateway)
  23. async def __call__(
  24. self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
  25. ) -> Response:
  26. time_received = time.time()
  27. response = await call_next(request)
  28. request_time = time.time() - time_received
  29. # Instead of logging directly, set it as background task so that it is
  30. # executed after the response. See https://www.starlette.io/background/.
  31. if response.background is None:
  32. response.background = BackgroundTasks()
  33. response.background.add_task(
  34. log_access,
  35. self.gateway,
  36. request,
  37. response,
  38. time_received,
  39. request_time,
  40. ctx.correlation_id,
  41. )
  42. return response
  43. async def log_access(
  44. gateway: Gateway,
  45. request: Request,
  46. response: Response,
  47. time_received: float,
  48. request_time: float,
  49. correlation_id: Optional[UUID] = None,
  50. ) -> None:
  51. """
  52. Create a dictionary with logging data.
  53. """
  54. try:
  55. content_length = int(response.headers.get("content-length"))
  56. except (TypeError, ValueError):
  57. content_length = None
  58. try:
  59. view_name = request.scope["route"].name
  60. except KeyError:
  61. view_name = None
  62. item = {
  63. "tag_suffix": "access_log",
  64. "remote_address": getattr(request.client, "host", None),
  65. "method": request.method,
  66. "path": request.url.path,
  67. "portal": request.url.netloc,
  68. "referer": request.headers.get("referer"),
  69. "user_agent": request.headers.get("user-agent"),
  70. "query_params": request.url.query,
  71. "view_name": view_name,
  72. "status": response.status_code,
  73. "content_type": response.headers.get("content-type"),
  74. "content_length": content_length,
  75. "time": time_received,
  76. "request_time": request_time,
  77. "correlation_id": str(correlation_id) if correlation_id else None,
  78. }
  79. await gateway.add(item)