1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
//! REST API of the node
#[cfg(feature = "prometheus-metrics")]
mod prometheus;
pub mod v0;
mod v1;

use crate::context::{Context, ContextLock, ServerStopper};
use futures::{channel::mpsc, prelude::*};
use jormungandr_lib::interfaces::{Cors, Tls};
use std::{error::Error, net::SocketAddr, time::Duration};
use warp::Filter;

/// Settings used to start the node's REST API server.
pub struct Config {
    /// Socket address the server binds to.
    pub listen: SocketAddr,
    /// Certificate and private-key files to serve over TLS; when `None` the
    /// server is bound without TLS.
    pub tls: Option<Tls>,
    /// CORS policy (allowed origins, headers, methods, optional max-age)
    /// applied to every route; when `None` no CORS filter is installed.
    pub cors: Option<Cors>,
    /// When `true`, the filter from the `prometheus` module is added next to
    /// the API routes. Only present with the `prometheus-metrics` feature.
    #[cfg(feature = "prometheus-metrics")]
    pub enable_prometheus: bool,
}

/// Start the REST API server and run it until it is asked to stop.
///
/// A stop channel is created and its sending half is registered in `context`
/// through a [`ServerStopper`]. The server shuts down gracefully once the
/// receiving half yields its first message or the channel is closed.
///
/// Routes from the `v0` and `v1` modules are mounted under the `api` path
/// prefix. Every request is wrapped in a `rest_api_request` tracing span that
/// records:
/// - the HTTP method, path and version;
/// - the remote address, when known;
/// - the trace/span/parent-span ids, when `http_zipkin::get_trace_context`
///   finds them in the request headers.
pub async fn start_rest_server(config: Config, context: ContextLock) {
    let (stop_sender, stop_receiver) = mpsc::channel::<()>(0);
    // `into_future` resolves with the first item, or with `None` once the
    // stream ends; either outcome is treated as the shutdown signal.
    let shutdown_signal = stop_receiver.into_future().map(|_| ());
    context
        .write()
        .await
        .set_rest_server_stopper(ServerStopper::new(stop_sender));

    let routes = v0::filter(context.clone()).or(v1::filter(context.clone()));

    let request_span = warp::filters::trace::trace(|info| {
        use http_zipkin::get_trace_context;
        use tracing::field::Empty;

        // Fields that are only known for some requests start out `Empty` and
        // are filled in below when the data is available.
        let span = tracing::span!(
            tracing::Level::DEBUG,
            "rest_api_request",
            method = %info.method(),
            path = info.path(),
            version = ?info.version(),
            remote_addr = Empty,
            trace_id = Empty,
            span_id = Empty,
            parent_span_id = Empty,
        );

        if let Some(peer) = info.remote_addr() {
            span.record("remote_addr", peer.to_string().as_str());
        }

        if let Some(zipkin) = get_trace_context(info.request_headers()) {
            span.record("trace_id", zipkin.trace_id().to_string().as_str());
            span.record("span_id", zipkin.span_id().to_string().as_str());
            if let Some(parent) = zipkin.parent_id() {
                span.record("parent_span_id", parent.to_string().as_str());
            }
        }

        span
    });

    let api = warp::path!("api" / ..).and(routes).with(request_span);

    setup_prometheus(api, config, context, shutdown_signal).await;
}

/// Pass `app` on to [`setup_cors`], first adding the filter from the
/// `prometheus` module when `config.enable_prometheus` is set.
///
/// Compiled only with the `prometheus-metrics` feature.
#[cfg(feature = "prometheus-metrics")]
async fn setup_prometheus<App>(
    app: App,
    config: Config,
    context: ContextLock,
    shutdown_signal: impl Future<Output = ()> + Send + 'static,
) where
    App: Filter<Error = warp::Rejection> + Clone + Send + Sync + 'static,
    App::Extract: warp::Reply,
{
    if config.enable_prometheus {
        // `context` is owned and not used after this point, so move it into
        // the filter instead of cloning it.
        let prometheus = prometheus::filter(context);
        // Requests go to `app` first; the prometheus filter only sees
        // requests that `app` rejects.
        setup_cors(app.or(prometheus), config, shutdown_signal).await;
    } else {
        setup_cors(app, config, shutdown_signal).await;
    }
}

/// Pass `app` on to [`setup_cors`] unchanged.
///
/// This variant is compiled when the `prometheus-metrics` feature is
/// disabled. There is no metrics route to add, so the context handle is
/// accepted only so that `start_rest_server` can call either variant with the
/// same arguments.
#[cfg(not(feature = "prometheus-metrics"))]
async fn setup_prometheus<App>(
    app: App,
    config: Config,
    _unused_context: ContextLock,
    shutdown_signal: impl Future<Output = ()> + Send + 'static,
) where
    App: Filter<Error = warp::Rejection> + Clone + Send + Sync + 'static,
    App::Extract: warp::Reply,
{
    setup_cors(app, config, shutdown_signal).await
}

async fn setup_cors<App>(
    app: App,
    config: Config,
    shutdown_signal: impl Future<Output = ()> + Send + 'static,
) where
    App: Filter<Error = warp::Rejection> + Clone + Send + Sync + 'static,
    App::Extract: warp::Reply,
{
    tracing::info!(listen_address = %config.listen, "listening for REST API requests");

    if let Some(cors_config) = config.cors {
        let allowed_origins: Vec<&str> = cors_config
            .allowed_origins
            .iter()
            .map(AsRef::as_ref)
            .collect();

        let mut cors = warp::cors().allow_origins(allowed_origins);

        for header in cors_config.allowed_headers {
            cors = cors.allow_header(header.0);
        }

        for method in cors_config.allowed_methods {
            cors = cors.allow_method(method.0);
        }

        if let Some(max_age) = cors_config.max_age_secs {
            cors = cors.max_age(Duration::from_secs(max_age));
        }

        run_server_with_app(
            app.with(cors.build()),
            config.listen,
            config.tls,
            shutdown_signal,
        )
        .await;
    } else {
        run_server_with_app(app, config.listen, config.tls, shutdown_signal).await;
    }
}

/// Serve `app` on `listen_addr` until `shutdown_signal` resolves.
///
/// When `tls_config` is given, the server is configured with its certificate
/// file and private-key file. Otherwise it is bound without TLS. The bound
/// address returned by warp is not used.
///
/// NOTE(review): warp documents `bind_with_graceful_shutdown` as panicking
/// when the address cannot be bound; confirm that is acceptable for callers.
async fn run_server_with_app<App>(
    app: App,
    listen_addr: SocketAddr,
    tls_config: Option<Tls>,
    shutdown_signal: impl Future<Output = ()> + Send + 'static,
) where
    App: Filter<Error = warp::Rejection> + Clone + Send + Sync + 'static,
    App::Extract: warp::Reply,
{
    let server = warp::serve(app);
    match tls_config {
        Some(tls) => {
            let (_bound_addr, serving) = server
                .tls()
                .cert_path(tls.cert_file)
                .key_path(tls.priv_key_file)
                .bind_with_graceful_shutdown(listen_addr, shutdown_signal);
            serving.await
        }
        None => {
            let (_bound_addr, serving) =
                server.bind_with_graceful_shutdown(listen_addr, shutdown_signal);
            serving.await
        }
    }
}

/// Render `err` and its chain of sources as a plain-text body for an
/// "internal server error" reply.
///
/// The output is built line by line, and every line ends with `\n`:
/// - the first line is `Internal server error: <err>`;
/// - then one `-> <source>` line follows for each error reachable through
///   [`Error::source`], outermost first.
///
/// Unsized error types (e.g. `&dyn Error`, or `&*boxed_error`) are accepted
/// as well as concrete ones.
///
/// If formatting fails, a short fallback message is returned instead of
/// panicking. Writing into a `String` cannot fail by itself, so this only
/// happens when one of the errors' `Display` implementations reports
/// `std::fmt::Error`.
fn display_internal_server_error(err: &(impl Error + ?Sized)) -> String {
    use std::fmt::{self, Write};

    fn error_to_body(err: &(impl Error + ?Sized)) -> Result<String, fmt::Error> {
        let mut reply_body = String::new();
        writeln!(reply_body, "Internal server error: {}", err)?;
        // Walk the `source()` chain, starting with the direct cause.
        for source in std::iter::successors(err.source(), |&cause| cause.source()) {
            writeln!(reply_body, "-> {}", source)?;
        }
        Ok(reply_body)
    }

    error_to_body(err).unwrap_or_else(|err| format!("failed to process internal error: {}", err))
}