cat_gateway/service/utilities/middleware/tracing_mw.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401
//! Full Tracing and metrics middleware.
use std::{sync::LazyLock, time::Instant};
use cpu_time::ProcessTime; // ThreadTime doesn't work.
use poem::{
http::{header, HeaderMap},
web::RealIp,
Endpoint, Error, FromRequest, IntoResponse, Middleware, PathPattern, Request, Response, Result,
};
use poem_openapi::OperationId;
use prometheus::{
default_registry, register_histogram_vec, register_int_counter_vec, HistogramVec,
IntCounterVec, Registry,
};
use tracing::{error, field, Instrument, Level, Span};
use ulid::Ulid;
use uuid::Uuid;
use crate::{settings::Settings, utils::blake2b_hash::generate_uuid_string_from_data};
/// Label names attached to every per-endpoint HTTP metric, in the order the
/// label values are supplied when a metric is updated.
const METRIC_LABELS: [&'static str; 3] = ["endpoint", "method", "status_code"];
/// Label names attached to the per-client request counter, in the order the
/// label values are supplied when the counter is updated.
const CLIENT_METRIC_LABELS: [&'static str; 2] = ["client", "status_code"];
// Prometheus metrics maintained by the service.
/// Histogram of HTTP request wall-clock durations, in milliseconds, labelled
/// by `METRIC_LABELS`. It is created and registered with the default
/// Prometheus registry the first time it is accessed.
static HTTP_REQ_DURATION_MS: LazyLock<HistogramVec> = LazyLock::new(|| {
    // NOTE(review): presumably the macro expansion trips this pedantic lint.
    #[allow(clippy::ignored_unit_patterns)]
    let registered = register_histogram_vec!(
        "http_request_duration_ms",
        "Duration of HTTP requests in milliseconds",
        &METRIC_LABELS
    );
    registered.unwrap()
});
/// Histogram of CPU time spent while handling HTTP requests, in milliseconds,
/// labelled by `METRIC_LABELS`. It is created and registered with the default
/// Prometheus registry the first time it is accessed.
static HTTP_REQ_CPU_TIME_MS: LazyLock<HistogramVec> = LazyLock::new(|| {
    // NOTE(review): presumably the macro expansion trips this pedantic lint.
    #[allow(clippy::ignored_unit_patterns)]
    let registered = register_histogram_vec!(
        "http_request_cpu_time_ms",
        "CPU Time of HTTP requests in milliseconds",
        &METRIC_LABELS
    );
    registered.unwrap()
});
// No Tacho implemented to enable this.
// static ref HTTP_REQUEST_RATE: GaugeVec = register_gauge_vec!(
// "http_request_rate",
// "Rate of HTTP requests per second",
// &METRIC_LABELS
// )
// .unwrap();
/// HTTP request counter (an `IntCounterVec`, not a histogram).
///
/// Incremented once per request and labelled by `METRIC_LABELS`. It is created
/// and registered with the default Prometheus registry the first time it is
/// accessed; registration failure panics via `unwrap`.
static HTTP_REQUEST_COUNT: LazyLock<IntCounterVec> = LazyLock::new(|| {
    #[allow(clippy::ignored_unit_patterns)]
    register_int_counter_vec!(
        "http_request_count",
        "Number of HTTP requests",
        &METRIC_LABELS
    )
    .unwrap()
});
/// Per-client HTTP request counter (an `IntCounterVec`, not a histogram).
///
/// Labelled by `CLIENT_METRIC_LABELS`, where the `client` label is the
/// anonymized client ID. This metric has one series per distinct client, so its
/// cardinality grows with the number of clients seen. It is created and
/// registered with the default Prometheus registry the first time it is
/// accessed; registration failure panics via `unwrap`.
static CLIENT_REQUEST_COUNT: LazyLock<IntCounterVec> = LazyLock::new(|| {
    #[allow(clippy::ignored_unit_patterns)]
    register_int_counter_vec!(
        "client_request_count",
        "Number of HTTP requests per client",
        &CLIENT_METRIC_LABELS
    )
    .unwrap()
});
// Currently no way to get these values. TODO.
// Panic Request Count counter.
// static PANIC_REQUEST_COUNT: LazyLock<IntCounterVec> = LazyLock::new(|| {
// #[allow(clippy::ignored_unit_patterns)]
// register_int_counter_vec!(
// "panic_request_count",
// "Number of HTTP requests that panicked",
// &METRIC_LABELS
// )
// .unwrap()
// });
// Currently no way to get these values without reading the whole response which is BAD.
// static ref HTTP_REQUEST_SIZE_BYTES: HistogramVec = register_histogram_vec!(
// "http_request_size_bytes",
// "Size of HTTP requests in bytes",
// &METRIC_LABELS
// )
// .unwrap();
// static ref HTTP_RESPONSE_SIZE_BYTES: HistogramVec = register_histogram_vec!(
// "http_response_size_bytes",
// "Size of HTTP responses in bytes",
// &METRIC_LABELS
// )
// .unwrap();
/// Poem middleware for [`tracing`](https://crates.io/crates/tracing).
///
/// Wraps each request of the endpoint it is applied to in a tracing span and
/// records Prometheus metrics for it (see [`TracingEndpoint`]).
#[derive(Default)]
pub(crate) struct Tracing;
impl<E: Endpoint> Middleware<E> for Tracing {
    type Output = TracingEndpoint<E>;
    /// Wrap `inner` so its requests are traced and measured.
    fn transform(&self, inner: E) -> Self::Output {
        TracingEndpoint { inner }
    }
}
/// Endpoint produced by the [`Tracing`] middleware.
pub(crate) struct TracingEndpoint<E> {
    /// The wrapped endpoint that actually handles the request.
    inner: E,
}
/// Turn a client's IP address (as text) into an anonymized identifier.
///
/// The address is hashed together with the configured client-ID key via
/// `generate_uuid_string_from_data`, so the result cannot be linked back to the
/// address without that key.
fn anonymize_ip_address(remote_addr: &str) -> String {
    let hashed_input = vec![String::from(remote_addr)];
    generate_uuid_string_from_data(Settings::client_id_key(), &hashed_input)
}
/// Get an anonymized client ID from the request.
///
/// The client address comes from poem's `RealIp` extractor when it yields an
/// address. Otherwise the request's `remote_addr()` is used. That address is
/// then combined with a secret key and hashed.
///
/// The hash is stable per client address, but it cannot be reversed or
/// correlated without both the address and the key.
async fn anonymous_client_id(req: &Request) -> String {
    let real_ip = RealIp::from_request_without_body(req)
        .await
        .ok()
        .and_then(|extracted| extracted.0);
    let remote_addr = match real_ip {
        Some(ip) => ip.to_string(),
        // No usable real IP: fall back to the socket peer address.
        None => req.remote_addr().to_string(),
    };
    anonymize_ip_address(&remote_addr)
}
/// Measurements and identifiers collected for a finished request. They are
/// used as the observed values and labels of the Prometheus metrics.
struct ResponseData {
    /// Wall-clock time taken by the inner endpoint, in milliseconds.
    duration: f64,
    /// Process CPU time elapsed while the inner endpoint ran, in milliseconds.
    cpu_time: f64,
    /// Numeric HTTP status code of the response.
    status_code: u16,
    /// Endpoint path pattern with trailing `/` removed, or `"Unknown"`.
    endpoint: String,
    // Not tracked yet: whether the request panicked (`panic: bool`).
}
impl ResponseData {
    /// Create a new `ResponseData` set from the response.
    ///
    /// While doing so, record the response-side span fields on `span`:
    /// - `duration_ms`, `cpu_time_ms`, `oid`, `status` and `endpoint`;
    /// - `panic`, when `panic` is `Some`;
    /// - the `resp_*` header fields.
    ///
    /// * `duration` - wall-clock duration of the request, in milliseconds.
    /// * `cpu_time` - CPU time used while handling the request, in milliseconds.
    /// * `resp` - the response to inspect.
    /// * `panic` - ID of the panic that produced this response, if any.
    /// * `span` - the request span to record into.
    ///
    /// NOTE(review): `oid` and `endpoint` are recorded unconditionally. If the
    /// response lacks them, `"Unknown"` is recorded even when `mk_request_span`
    /// already recorded a value from the request.
    fn new(
        duration: f64, cpu_time: f64, resp: &Response, panic: Option<Uuid>, span: &Span,
    ) -> Self {
        // The OpenAPI Operation ID of this request.
        // `map_or_else` only allocates the fallback when it is actually needed.
        let oid = resp
            .data::<OperationId>()
            .map_or_else(|| "Unknown".to_string(), std::string::ToString::to_string);
        let status = resp.status().as_u16();
        // Use the endpoint's path pattern rather than the concrete path, so metric
        // label cardinality stays bounded. Path patterns can carry a trailing slash;
        // strip it so `/foo` and `/foo/` share one label.
        let endpoint = resp.data::<PathPattern>().map_or_else(
            || "Unknown".to_string(),
            |pattern| pattern.0.trim_end_matches('/').to_string(),
        );
        span.record("duration_ms", duration);
        span.record("cpu_time_ms", cpu_time);
        span.record("oid", &oid);
        span.record("status", status);
        span.record("endpoint", &endpoint);
        // Record the panic field in the span if it was set.
        if let Some(panic) = panic {
            span.record("panic", panic.to_string());
        }
        add_interesting_headers_to_span(span, "resp", resp.headers());
        Self {
            duration,
            cpu_time,
            status_code: status,
            endpoint,
            // panic: panic.is_some(),
        }
    }
}
/// Copy the content-length, content-type and content-encoding headers into the
/// span fields `<prefix>_size`, `<prefix>_content_type` and `<prefix>_encoding`.
///
/// The same logic serves requests (`prefix = "req"`) and responses
/// (`prefix = "resp"`). A header is skipped when it is missing or its value is
/// not visible ASCII (`to_str` fails).
fn add_interesting_headers_to_span(span: &Span, prefix: &str, headers: &HeaderMap) {
    let interesting = [
        (header::CONTENT_LENGTH, "_size"),
        (header::CONTENT_TYPE, "_content_type"),
        (header::CONTENT_ENCODING, "_encoding"),
    ];
    for (name, suffix) in interesting {
        let Some(value) = headers.get(&name).and_then(|raw| raw.to_str().ok()) else {
            continue;
        };
        let field_name = format!("{prefix}{suffix}");
        span.record(field_name.as_str(), value);
    }
}
/// Make the per-request tracing span.
///
/// The span has target `Endpoint`, level `INFO` and name `request`. It starts
/// with the anonymized client ID, a fresh ULID connection ID, the HTTP
/// version, the method and the path.
///
/// Fields filled in here, when available:
/// - `query_size`;
/// - the `req_*` header fields;
/// - `endpoint` and `oid`, if the request already carries a path pattern or
///   operation ID.
///
/// The response-side fields are declared as `field::Empty`, because tracing
/// can only record fields that were declared when the span was created.
///
/// Returns `(span, uri_path, method, client_id)` so the caller can reuse these
/// values as metric labels.
async fn mk_request_span(req: &Request) -> (Span, String, String, String) {
    let client_id = anonymous_client_id(req).await;
    let conn_id = Ulid::new();
    let uri_path = req.uri().path().to_string();
    // Only the query length is needed, so don't copy the query string itself.
    let query_size = req.uri().query().map_or(0, str::len);
    let method = req.method().to_string();
    let span = tracing::span!(
        target: "Endpoint",
        Level::INFO,
        "request",
        client = %client_id,
        conn = %conn_id,
        version = ?req.version(),
        method = %method,
        path = %uri_path,
        query_size = field::Empty,
        req_size = field::Empty,
        req_content_type = field::Empty,
        req_encoding = field::Empty,
        resp_size = field::Empty,
        resp_content_type = field::Empty,
        resp_encoding = field::Empty,
        endpoint = field::Empty,
        duration_ms = field::Empty,
        cpu_time_ms = field::Empty,
        oid = field::Empty,
        status = field::Empty,
        panic = field::Empty,
    );
    // Record query size (to see if we are sent enormous queries).
    if query_size > 0 {
        span.record("query_size", query_size);
    }
    add_interesting_headers_to_span(&span, "req", req.headers());
    // Try and get the endpoint as a path pattern (this prevents metrics explosion).
    // Trailing slashes are stripped so `/foo` and `/foo/` are the same endpoint.
    if let Some(endpoint) = req.data::<PathPattern>() {
        span.record("endpoint", endpoint.0.trim_end_matches('/'));
    }
    // Try and get the OpenAPI operation ID of the endpoint.
    if let Some(oid) = req.data::<OperationId>() {
        span.record("oid", oid.0.to_string());
    }
    (span, uri_path, method, client_id)
}
impl<E: Endpoint> Endpoint for TracingEndpoint<E> {
    type Output = Response;
    /// Handle one request through the inner endpoint, traced and measured.
    ///
    /// The inner endpoint runs inside the request span, and its wall-clock and
    /// CPU time are measured. The response fields are recorded on the span, and
    /// the request metrics are then updated:
    /// - the duration and CPU-time histograms;
    /// - the request counter;
    /// - the per-client counter.
    ///
    /// Errors from the inner endpoint are logged. Each one is converted to a
    /// response so it can be measured, then rebuilt as an `Error` with the
    /// original error message preserved (when non-empty).
    async fn call(&self, req: Request) -> Result<Self::Output> {
        // Construct the span from the request.
        let (span, uri_path, method, client_id) = mk_request_span(&req).await;
        let inner_span = span.clone();
        let (response, resp_data) = async move {
            let now = Instant::now();
            // `ProcessTime` covers the whole process, so CPU used by concurrent
            // requests is included in this measurement (ThreadTime doesn't work
            // here, per the note on the import).
            let now_proc = ProcessTime::now();
            let resp = self.inner.call(req).await;
            #[allow(clippy::cast_precision_loss)] // Precision loss is acceptable for this timing.
            let duration_proc = now_proc.elapsed().as_micros() as f64 / 1000.0;
            #[allow(clippy::cast_precision_loss)] // Precision loss is acceptable for this timing.
            let duration = now.elapsed().as_micros() as f64 / 1000.0;
            match resp {
                Ok(resp) => {
                    // TODO: detect panics (e.g. downcast to `ServerError`) and pass
                    // the panic ID instead of `None`.
                    let resp = resp.into_response();
                    let response_data =
                        ResponseData::new(duration, duration_proc, &resp, None, &inner_span);
                    (Ok(resp), response_data)
                },
                Err(err) => {
                    // TODO: detect panics (e.g. downcast to `ServerError`) and set
                    // the panic ID here.
                    let panic: Option<Uuid> = None;
                    // Get the error message, and also for now, log the error to ensure we are
                    // preserving all data.
                    error!(err = ?err, "HTTP Response Error:");
                    // Only way I can see to get the message, may not be perfect.
                    let error_message = err.to_string();
                    // Convert the error into a response, so we can deal with the error.
                    let resp = err.into_response();
                    let response_data =
                        ResponseData::new(duration, duration_proc, &resp, panic, &inner_span);
                    // Convert the response back to an error, and try and recover the message.
                    let mut error = Error::from_response(resp);
                    if !error_message.is_empty() {
                        error.set_error_message(&error_message);
                    }
                    (Err(error), response_data)
                },
            }
        }
        .instrument(span.clone())
        .await;
        span.in_scope(|| {
            // Label with the path pattern from the response. `ResponseData` already
            // substitutes "Unknown" when no pattern exists. So this fallback to the
            // raw request path only fires when the pattern trims to "" (the root
            // pattern `/`). That keeps label cardinality bounded.
            let path = if resp_data.endpoint.is_empty() {
                uri_path
            } else {
                resp_data.endpoint
            };
            // Format the status code once and share it across all metrics.
            let status_code = resp_data.status_code.to_string();
            let labels = [path.as_str(), method.as_str(), status_code.as_str()];
            HTTP_REQ_DURATION_MS
                .with_label_values(&labels)
                .observe(resp_data.duration);
            HTTP_REQ_CPU_TIME_MS
                .with_label_values(&labels)
                .observe(resp_data.cpu_time);
            HTTP_REQUEST_COUNT.with_label_values(&labels).inc();
            CLIENT_REQUEST_COUNT
                .with_label_values(&[client_id.as_str(), status_code.as_str()])
                .inc();
            // The request-rate and request/response-size metrics are not recorded;
            // see the commented-out definitions near the top of this file.
        });
        response
    }
}
/// Initialize Prometheus metrics.
///
/// The metrics in this module are registered with Prometheus' default registry
/// by the `register_*` macros, lazily, the first time each one is used.
///
/// ## Returns
///
/// A clone of the default prometheus registry.
#[must_use]
pub(crate) fn init_prometheus() -> Registry {
    let registry = default_registry();
    registry.clone()
}