fastapi_access_logger.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # (c) Nelen & Schuurmans
  2. import os
  3. import time
  4. from datetime import datetime
  5. from typing import Awaitable
  6. from typing import Callable
  7. from typing import Optional
  8. import inject
  9. from starlette.background import BackgroundTasks
  10. from starlette.requests import Request
  11. from starlette.responses import Response
  12. from clean_python import Gateway
  13. from clean_python.fluentbit import FluentbitGateway
  14. __all__ = ["FastAPIAccessLogger"]
  15. class FastAPIAccessLogger:
  16. def __init__(self, hostname: str, gateway_override: Optional[Gateway] = None):
  17. self.origin = f"{hostname}-{os.getpid()}"
  18. self.gateway_override = gateway_override
  19. @property
  20. def gateway(self) -> Gateway:
  21. return self.gateway_override or inject.instance(FluentbitGateway)
  22. async def __call__(
  23. self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
  24. ) -> Response:
  25. time_received = time.time()
  26. response = await call_next(request)
  27. request_time = time.time() - time_received
  28. # Instead of logging directly, set it as background task so that it is
  29. # executed after the response. See https://www.starlette.io/background/.
  30. if response.background is None:
  31. response.background = BackgroundTasks()
  32. response.background.add_task(
  33. log_access, self.gateway, request, response, time_received, request_time
  34. )
  35. return response
  36. def fmt_timestamp(timestamp: float) -> str:
  37. return datetime.utcfromtimestamp(timestamp).isoformat() + "Z"
  38. async def log_access(
  39. gateway: Gateway,
  40. request: Request,
  41. response: Response,
  42. time_received: float,
  43. request_time: float,
  44. ) -> None:
  45. """
  46. Create a dictionary with logging data.
  47. """
  48. try:
  49. content_length = int(response.headers.get("content-length"))
  50. except (TypeError, ValueError):
  51. content_length = None
  52. try:
  53. view_name = request.scope["route"].name
  54. except KeyError:
  55. view_name = None
  56. item = {
  57. "tag_suffix": "access_log",
  58. "remote_address": getattr(request.client, "host", None),
  59. "method": request.method,
  60. "path": request.url.path,
  61. "portal": request.url.netloc,
  62. "referer": request.headers.get("referer"),
  63. "user_agent": request.headers.get("user-agent"),
  64. "query_params": request.url.query,
  65. "view_name": view_name,
  66. "status": response.status_code,
  67. "content_type": response.headers.get("content-type"),
  68. "content_length": content_length,
  69. "time": fmt_timestamp(time_received),
  70. "request_time": request_time,
  71. }
  72. await gateway.add(item)