fastapi_access_logger.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # (c) Nelen & Schuurmans
  2. import os
  3. import time
  4. from typing import Awaitable
  5. from typing import Callable
  6. from typing import Optional
  7. from uuid import UUID
  8. from uuid import uuid4
  9. import inject
  10. from starlette.background import BackgroundTasks
  11. from starlette.requests import Request
  12. from starlette.responses import Response
  13. from clean_python import Gateway
  14. from clean_python.fluentbit import FluentbitGateway
  15. __all__ = ["FastAPIAccessLogger", "get_correlation_id"]
  16. CORRELATION_ID_HEADER = b"x-correlation-id"
  17. def get_view_name(request: Request) -> Optional[str]:
  18. try:
  19. view_name = request.scope["route"].name
  20. except KeyError:
  21. return None
  22. return view_name
  23. def is_health_check(request: Request) -> bool:
  24. return get_view_name(request) == "health_check"
  25. def get_correlation_id(request: Request) -> Optional[UUID]:
  26. headers = dict(request.scope["headers"])
  27. try:
  28. return UUID(headers[CORRELATION_ID_HEADER].decode())
  29. except (KeyError, ValueError, UnicodeDecodeError):
  30. return None
  31. def ensure_correlation_id(request: Request) -> None:
  32. correlation_id = get_correlation_id(request)
  33. if correlation_id is None:
  34. # generate an id and update the request inplace
  35. correlation_id = uuid4()
  36. headers = dict(request.scope["headers"])
  37. headers[CORRELATION_ID_HEADER] = str(correlation_id).encode()
  38. request.scope["headers"] = list(headers.items())
  39. class FastAPIAccessLogger:
  40. def __init__(self, hostname: str, gateway_override: Optional[Gateway] = None):
  41. self.origin = f"{hostname}-{os.getpid()}"
  42. self.gateway_override = gateway_override
  43. @property
  44. def gateway(self) -> Gateway:
  45. return self.gateway_override or inject.instance(FluentbitGateway)
  46. async def __call__(
  47. self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
  48. ) -> Response:
  49. if request.scope["type"] != "http" or is_health_check(request):
  50. return await call_next(request)
  51. ensure_correlation_id(request)
  52. time_received = time.time()
  53. response = await call_next(request)
  54. request_time = time.time() - time_received
  55. # Instead of logging directly, set it as background task so that it is
  56. # executed after the response. See https://www.starlette.io/background/.
  57. if response.background is None:
  58. response.background = BackgroundTasks()
  59. response.background.add_task(
  60. log_access,
  61. self.gateway,
  62. request,
  63. response,
  64. time_received,
  65. request_time,
  66. )
  67. return response
  68. async def log_access(
  69. gateway: Gateway,
  70. request: Request,
  71. response: Response,
  72. time_received: float,
  73. request_time: float,
  74. ) -> None:
  75. """
  76. Create a dictionary with logging data.
  77. """
  78. try:
  79. content_length = int(response.headers.get("content-length"))
  80. except (TypeError, ValueError):
  81. content_length = None
  82. item = {
  83. "tag_suffix": "access_log",
  84. "remote_address": getattr(request.client, "host", None),
  85. "method": request.method,
  86. "path": request.url.path,
  87. "portal": request.url.netloc,
  88. "referer": request.headers.get("referer"),
  89. "user_agent": request.headers.get("user-agent"),
  90. "query_params": request.url.query,
  91. "view_name": get_view_name(request),
  92. "status": response.status_code,
  93. "content_type": response.headers.get("content-type"),
  94. "content_length": content_length,
  95. "time": time_received,
  96. "request_time": request_time,
  97. "correlation_id": str(get_correlation_id(request)),
  98. }
  99. await gateway.add(item)