cat_gateway/service/utilities/middleware/
catch_panic.rs

1//! A middleware to catch and log panics.
2
3use std::{any::Any, backtrace::Backtrace, cell::RefCell, panic::AssertUnwindSafe};
4
5use futures::FutureExt;
6use panic_message::panic_message;
7use poem::{
8    Endpoint, IntoResponse, Middleware, Request, Response,
9    http::{HeaderMap, Method, StatusCode, Uri},
10};
11use poem_openapi::payload::Json;
12use tracing::{debug, error};
13
14use crate::{
15    service::{
16        common::responses::code_500_internal_server_error::InternalServerError,
17        utilities::health::{get_live_counter, inc_live_counter, set_not_live},
18    },
19    settings::Settings,
20};
21
// Per-thread slots used to hand panic details from the panic hook to the code that
// handles the caught panic, so the details can be included in logs.
thread_local! {
    /// Formatted backtrace stored by the panic hook and taken (leaving `None`) by
    /// `panic_response`.
    static BACKTRACE: RefCell<Option<String>> = const { RefCell::new(None) };
    /// Formatted panic location stored by the panic hook and taken (leaving `None`) by
    /// `panic_response`.
    static LOCATION: RefCell<Option<String>> = const { RefCell::new(None) };
}
27
28/// Sets a custom panic hook to capture the Backtrace and Panic Location for logging
29/// purposes. This hook gets called BEFORE we catch it.  So the thread local variables
30/// stored here are valid when processing the panic capture.
31pub(crate) fn set_panic_hook() {
32    std::panic::set_hook(Box::new(|panic_info| {
33        // Get the backtrace and format it.
34        let raw_trace = Backtrace::force_capture();
35        let trace = format!("{raw_trace}");
36        BACKTRACE.with(move |b| b.borrow_mut().replace(trace));
37
38        // Get the location and format it.
39        let location = match panic_info.location() {
40            Some(location) => format!("{location}"),
41            None => "Unknown".to_string(),
42        };
43        LOCATION.with(move |l| l.borrow_mut().replace(location));
44    }));
45}
46
/// A middleware for catching and logging panics, transforming them into a 500 error
/// response. The panic will not crash the service, but the service becomes not live
/// after the number of caught panics exceeds the
/// `Settings::service_live_counter_threshold()` value. That should cause Kubernetes to
/// restart the service.
///
/// The type carries no state, so it can be built with either [`CatchPanicMiddleware::new`]
/// or `Default::default()`.
#[derive(Debug, Default, Clone)]
pub struct CatchPanicMiddleware {}

impl CatchPanicMiddleware {
    /// Creates a new middleware instance.
    pub fn new() -> Self {
        Self {}
    }
}
59
60impl<E: Endpoint> Middleware<E> for CatchPanicMiddleware {
61    type Output = CatchPanicEndpoint<E>;
62
63    fn transform(
64        &self,
65        ep: E,
66    ) -> Self::Output {
67        CatchPanicEndpoint { inner: ep }
68    }
69}
70
/// The endpoint produced by the `CatchPanicMiddleware` middleware.
///
/// It wraps another endpoint; its `Endpoint::call` implementation forwards the request
/// to that endpoint and turns a panic raised while handling it into a 500 response.
pub struct CatchPanicEndpoint<E> {
    /// The wrapped endpoint which actually handles the request.
    inner: E,
}
76
77impl<E: Endpoint> Endpoint for CatchPanicEndpoint<E> {
78    type Output = Response;
79
80    async fn call(
81        &self,
82        req: Request,
83    ) -> poem::Result<Self::Output> {
84        // Preserve all the data that we want to potentially log because a request is consumed.
85        let method = req.method().clone();
86        let uri = req.uri().clone();
87        let headers = req.headers().clone();
88
89        match AssertUnwindSafe(self.inner.call(req)).catch_unwind().await {
90            Ok(resp) => resp.map(IntoResponse::into_response),
91            Err(err) => Ok(panic_response(&err, &method, &uri, &headers)),
92        }
93    }
94}
95
96/// Converts a panic to a response.
97fn panic_response(
98    err: &Box<dyn Any + Send + 'static>,
99    method: &Method,
100    uri: &Uri,
101    headers: &HeaderMap,
102) -> Response {
103    // Increment the counter used for liveness checks.
104    inc_live_counter();
105
106    let current_count = get_live_counter();
107    debug!(
108        live_counter = current_count,
109        "Handling service panic response"
110    );
111
112    // If current count is above the threshold, then flag the system as NOT live.
113    if current_count > Settings::service_live_counter_threshold() {
114        set_not_live();
115    }
116
117    let server_err = InternalServerError::new(None);
118
119    // Get the unique identifier for this panic, so we can find it in the logs.
120    let panic_identifier = server_err.id().to_string();
121
122    // Get the message from the panic as best we can.
123    let err_msg = panic_message(err);
124
125    // This is the location of the panic.
126    let location = match LOCATION.with(|l| l.borrow_mut().take()) {
127        Some(location) => location,
128        None => "Unknown".to_string(),
129    };
130
131    // This is the backtrace of the panic.
132    let backtrace = match BACKTRACE.with(|b| b.borrow_mut().take()) {
133        Some(backtrace) => backtrace,
134        None => "Unknown".to_string(),
135    };
136
137    error!(
138        panic = true,
139        backtrace = backtrace,
140        location = location,
141        id = panic_identifier,
142        message = err_msg,
143        method = ?method,
144        uri = ?uri,
145        headers = ?headers,
146    );
147
148    let mut resp = Json(server_err).into_response();
149    resp.set_status(StatusCode::INTERNAL_SERVER_ERROR);
150    resp
151}