cat_gateway/service/utilities/middleware/
tracing_mw.rs

1//! Full Tracing and metrics middleware.
2
3use std::time::Instant;
4
5use cpu_time::ProcessTime; // ThreadTime doesn't work.
6use poem::{
7    http::{header, HeaderMap},
8    web::RealIp,
9    Endpoint, Error, FromRequest, IntoResponse, Middleware, PathPattern, Request, Response, Result,
10};
11use poem_openapi::OperationId;
12use tracing::{error, field, Instrument, Level, Span};
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{
17    metrics::endpoint::{
18        CLIENT_REQUEST_COUNT, HTTP_REQUEST_COUNT, HTTP_REQ_CPU_TIME_MS, HTTP_REQ_DURATION_MS,
19    },
20    settings::Settings,
21    utils::blake2b_hash::generate_uuid_string_from_data,
22};
23
// TODO: There is currently no way to obtain these values.
// Panic request count (an integer counter of requests that panicked, not a histogram).
26// static PANIC_REQUEST_COUNT: LazyLock<IntCounterVec> = LazyLock::new(|| {
27// #[allow(clippy::ignored_unit_patterns)]
28// register_int_counter_vec!(
29// "panic_request_count",
30// "Number of HTTP requests that panicked",
31// &METRIC_LABELS
32// )
33// .unwrap()
34// });
35
// There is currently no way to get these values without reading the whole request/response
// body, which is far too expensive to do for every request.
37// static ref HTTP_REQUEST_SIZE_BYTES: HistogramVec = register_histogram_vec!(
38// "http_request_size_bytes",
39// "Size of HTTP requests in bytes",
40// &METRIC_LABELS
41// )
42// .unwrap();
43// static ref HTTP_RESPONSE_SIZE_BYTES: HistogramVec = register_histogram_vec!(
44// "http_response_size_bytes",
45// "Size of HTTP responses in bytes",
46// &METRIC_LABELS
47// )
48// .unwrap();
49
50/// Middleware for [`tracing`](https://crates.io/crates/tracing).
51#[derive(Default)]
52pub(crate) struct Tracing;
53
54impl<E: Endpoint> Middleware<E> for Tracing {
55    type Output = TracingEndpoint<E>;
56
57    fn transform(&self, ep: E) -> Self::Output {
58        TracingEndpoint { inner: ep }
59    }
60}
61
62/// Endpoint for `Tracing` middleware.
63pub(crate) struct TracingEndpoint<E> {
64    /// Inner endpoint
65    inner: E,
66}
67
68/// Given a Clients IP Address, return the anonymized version of it.
69fn anonymize_ip_address(remote_addr: &str) -> String {
70    let addr: Vec<String> = vec![remote_addr.to_string()];
71    generate_uuid_string_from_data(Settings::client_id_key(), &addr)
72}
73
74/// Get an anonymized client ID from the request.
75///
76/// This simply takes the clients IP address,
77/// adds a supplied key to it, and hashes the result.
78///
79/// The Hash is unique per client IP, but not able to
80/// be reversed or analyzed without both the client IP and the key.
81async fn anonymous_client_id(req: &Request) -> String {
82    let remote_addr = RealIp::from_request_without_body(req)
83        .await
84        .ok()
85        .and_then(|real_ip| real_ip.0)
86        .map_or_else(|| req.remote_addr().to_string(), |addr| addr.to_string());
87
88    anonymize_ip_address(&remote_addr)
89}
90
91/// Data we collected about the response
92struct ResponseData {
93    /// Duration of the request
94    duration: f64,
95    /// CPU time of the request
96    cpu_time: f64,
97    /// Status code returned
98    status_code: u16,
99    /// Endpoint name
100    endpoint: String,
101    // panic: bool,
102}
103
104impl ResponseData {
105    /// Create a new `ResponseData` set from the response.
106    /// In the process add relevant data to the span from the response.
107    fn new(
108        duration: f64, cpu_time: f64, resp: &Response, panic: Option<Uuid>, span: &Span,
109    ) -> Self {
110        // The OpenAPI Operation ID of this request.
111        let oid = resp
112            .data::<OperationId>()
113            .map_or("Unknown".to_string(), std::string::ToString::to_string);
114
115        let status = resp.status().as_u16();
116
117        // Get the endpoint (path pattern) (this prevents metrics explosion).
118        let endpoint = resp.data::<PathPattern>();
119        let endpoint = endpoint.map_or("Unknown".to_string(), |endpoint| {
120            // For some reason path patterns can have trailing slashes, so remove them.
121            endpoint.0.trim_end_matches('/').to_string()
122        });
123
124        // Distinguish between "internal" endpoints and truly unknown endpoints.
125
126        span.record("duration_ms", duration);
127        span.record("cpu_time_ms", cpu_time);
128        span.record("oid", &oid);
129        span.record("status", status);
130        span.record("endpoint", &endpoint);
131
132        // Record the panic field in the span if it was set.
133        if let Some(panic) = panic {
134            span.record("panic", panic.to_string());
135        }
136
137        add_interesting_headers_to_span(span, "resp", resp.headers());
138
139        Self {
140            duration,
141            cpu_time,
142            status_code: status,
143            endpoint,
144            // panic: panic.is_some(),
145        }
146    }
147}
148
149/// Add all interesting headers to the correct fields in a span.
150/// This logic is the same for both requests and responses.
151fn add_interesting_headers_to_span(span: &Span, prefix: &str, headers: &HeaderMap) {
152    let size_field = prefix.to_string() + "_size";
153    let content_type_field = prefix.to_string() + "_content_type";
154    let encoding_field = prefix.to_string() + "_encoding";
155
156    // Record request size if its specified.
157    if let Some(size) = headers.get(header::CONTENT_LENGTH) {
158        if let Ok(size) = size.to_str() {
159            span.record(size_field.as_str(), size);
160        }
161    }
162
163    // Record request content type if its specified.
164    if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
165        if let Ok(content_type) = content_type.to_str() {
166            span.record(content_type_field.as_str(), content_type);
167        }
168    }
169
170    // Record request encoding if its specified.
171    if let Some(encoding) = headers.get(header::CONTENT_ENCODING) {
172        if let Ok(encoding) = encoding.to_str() {
173            span.record(encoding_field.as_str(), encoding);
174        }
175    }
176}
177
178/// Make a span from the request
179async fn mk_request_span(req: &Request) -> (Span, String, String, String) {
180    let client_id = anonymous_client_id(req).await;
181    let conn_id = Ulid::new();
182
183    let uri_path = req.uri().path().to_string();
184    let uri_query = req.uri().query().unwrap_or("").to_string();
185
186    let method = req.method().to_string();
187
188    let span = tracing::span!(
189        target: "Endpoint",
190        Level::INFO,
191        "request",
192        client = %client_id,
193        conn = %conn_id,
194        version = ?req.version(),
195        method = %method,
196        path = %uri_path,
197        query_size = field::Empty,
198        req_size = field::Empty,
199        req_content_type = field::Empty,
200        req_encoding = field::Empty,
201        resp_size = field::Empty,
202        resp_content_type = field::Empty,
203        resp_encoding = field::Empty,
204        endpoint = field::Empty,
205        duration_ms = field::Empty,
206        cpu_time_ms = field::Empty,
207        oid = field::Empty,
208        status = field::Empty,
209        panic = field::Empty,
210    );
211
212    // Record query size (To see if we are sent enormous queries).
213    if !uri_query.is_empty() {
214        span.record("query_size", uri_query.len());
215    }
216
217    add_interesting_headers_to_span(&span, "req", req.headers());
218
219    // Try and get the endpoint as a path pattern (this prevents metrics explosion).
220    if let Some(endpoint) = req.data::<PathPattern>() {
221        let endpoint = endpoint.0.trim_end_matches('/').to_string();
222        span.record("endpoint", endpoint);
223    }
224
225    // Try and get the endpoint as a path pattern (this prevents metrics explosion).
226    if let Some(oid) = req.data::<OperationId>() {
227        span.record("oid", oid.0.to_string());
228    }
229
230    (span, uri_path, method, client_id)
231}
232
233impl<E: Endpoint> Endpoint for TracingEndpoint<E> {
234    type Output = Response;
235
236    async fn call(&self, req: Request) -> Result<Self::Output> {
237        // Construct the span from the request.
238        let (span, uri_path, method, client_id) = mk_request_span(&req).await;
239
240        let inner_span = span.clone();
241
242        let (response, resp_data) = async move {
243            let now = Instant::now();
244            let now_proc = ProcessTime::now();
245
246            let resp = self.inner.call(req).await;
247
248            #[allow(clippy::cast_precision_loss)] // Precision loss is acceptable for this timing.
249            let duration_proc = now_proc.elapsed().as_micros() as f64 / 1000.0;
250
251            #[allow(clippy::cast_precision_loss)] // Precision loss is acceptable for this timing.
252            let duration = now.elapsed().as_micros() as f64 / 1000.0;
253
254            match resp {
255                Ok(resp) => {
256                    // let panic = if let Some(panic) = resp.downcast_ref::<ServerError>() {
257                    // Add panic ID to the span.
258                    //    Some(panic.id());
259                    //} else {
260                    //    None
261                    //};
262
263                    let resp = resp.into_response();
264
265                    let response_data =
266                        ResponseData::new(duration, duration_proc, &resp, None, &inner_span);
267
268                    (Ok(resp), response_data)
269                },
270                Err(err) => {
271                    // let panic = if let Some(panic) = err.downcast_ref::<ServerError>() {
272                    // Add panic ID to the span.
273                    //    Some(panic.id());
274                    //} else {
275                    //    None
276                    //};
277                    let panic: Option<Uuid> = None;
278
279                    // Get the error message, and also for now, log the error to ensure we are
280                    // preserving all data.
281                    error!(err = ?err, "HTTP Response Error:");
282                    // Only way I can see to get the message, may not be perfect.
283                    let error_message = err.to_string();
284
285                    // Convert the error into a response, so we can deal with the error
286                    let resp = err.into_response();
287
288                    let response_data =
289                        ResponseData::new(duration, duration_proc, &resp, panic, &inner_span);
290
291                    // Convert the response back to an error, and try and recover the message.
292                    let mut error = Error::from_response(resp);
293                    if !error_message.is_empty() {
294                        error.set_error_message(&error_message);
295                    }
296
297                    (Err(error), response_data)
298                },
299            }
300        }
301        .instrument(span.clone())
302        .await;
303
304        span.in_scope(|| {
305            // We really want to use the path_pattern from the response, but if not set use the path
306            // from the request.
307            let path = if resp_data.endpoint.is_empty() {
308                uri_path
309            } else {
310                resp_data.endpoint
311            };
312
313            HTTP_REQ_DURATION_MS
314                .with_label_values(&[&path, &method, &resp_data.status_code.to_string()])
315                .observe(resp_data.duration);
316            HTTP_REQ_CPU_TIME_MS
317                .with_label_values(&[&path, &method, &resp_data.status_code.to_string()])
318                .observe(resp_data.cpu_time);
319            // HTTP_REQUEST_RATE
320            //.with_label_values(&[&uri_path, &method, &response.status_code.to_string()])
321            //.inc();
322            HTTP_REQUEST_COUNT
323                .with_label_values(&[&path, &method, &resp_data.status_code.to_string()])
324                .inc();
325            CLIENT_REQUEST_COUNT
326                .with_label_values(&[&client_id, &resp_data.status_code.to_string()])
327                .inc();
328            // HTTP_REQUEST_SIZE_BYTES
329            //.with_label_values(&[&uri_path, &method, &response.status_code.to_string()])
330            //.observe(response.body().len() as f64);
331        });
332
333        response
334    }
335}