1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use crate::db::{queries::api_tokens as api_tokens_queries, DbConnectionPool};
use crate::v0::{context::SharedContext, errors::HandleError};
use warp::{Filter, Rejection};

/// Name of the HTTP request header that must carry the (URL-safe base64 encoded) API token.
pub const API_TOKEN_HEADER: &str = "API-Token";

/// Newtype holding the raw (already decoded) bytes of an API token.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ApiToken(Vec<u8>);

/// Abstraction over how the service stores and checks API tokens.
///
/// The goal is to keep the rest of the service agnostic of the backend used for this task.
/// Today that backend is the database reached through a `DbConnectionPool` (SQLite at the
/// time of writing); in the future it may be something else, such as Redis or a hybrid system.
pub struct ApiTokenManager {
    /// Pool the token queries are run against.
    connection_pool: DbConnectionPool,
}

impl From<&[u8]> for ApiToken {
    fn from(data: &[u8]) -> Self {
        Self(data.to_vec())
    }
}

impl AsRef<[u8]> for ApiToken {
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl ApiToken {
    pub fn new(data: Vec<u8>) -> Self {
        Self(data)
    }
}

impl ApiTokenManager {
    fn new(connection_pool: DbConnectionPool) -> Self {
        Self { connection_pool }
    }

    async fn is_token_valid(&self, token: ApiToken) -> Result<bool, HandleError> {
        match api_tokens_queries::query_token(token, &self.connection_pool).await {
            Ok(Some(_)) => Ok(true),
            Ok(None) => Ok(false),
            Err(e) => Err(HandleError::InternalError(format!(
                "Error retrieving token: {}",
                e
            ))),
        }
    }

    #[allow(dead_code)]
    async fn revoke_token(&self, _token: ApiToken) -> Result<(), ()> {
        Ok(())
    }
}

async fn authorize_token(token: String, context: SharedContext) -> Result<(), Rejection> {
    let manager = ApiTokenManager::new(context.read().await.db_connection_pool.clone());

    let mut token_vec: Vec<u8> = Vec::new();
    base64::decode_config_buf(token.clone(), base64::URL_SAFE, &mut token_vec).map_err(|_err| {
        warp::reject::custom(HandleError::InvalidHeader(
            API_TOKEN_HEADER,
            "header should be base64 url safe decodable",
        ))
    })?;

    let api_token = ApiToken(token_vec);

    match manager.is_token_valid(api_token).await {
        Ok(true) => Ok(()),
        Ok(false) => {
            tracing::event!(
                tracing::Level::INFO,
                "Unauthorized token received: {}",
                token
            );
            Err(warp::reject::custom(HandleError::UnauthorizedToken))
        }
        Err(e) => Err(warp::reject::custom(e)),
    }
}

/// A warp filter that checks authorization through API tokens.
///
/// The `API_TOKEN_HEADER` header must be present and hold a valid token, otherwise the request
/// is rejected. On success the filter extracts nothing, so it can be prepended to any route.
pub async fn api_token_filter(
    context: SharedContext,
) -> impl Filter<Extract = (), Error = Rejection> + Clone {
    let with_context = warp::any().map(move || context.clone());
    warp::header::header(API_TOKEN_HEADER)
        .and(with_context)
        .and_then(authorize_token)
        // `authorize_token` resolves to `()`, so the filter extracts `((),)`; flatten it to `()`.
        .untuple_one()
}

#[cfg(test)]
mod test {
    use crate::v0::api_token::{api_token_filter, API_TOKEN_HEADER};
    use crate::v0::context::test::new_test_shared_context_from_url;
    use vit_servicing_station_tests::common::startup::db::DbBuilder;

    /// Sends the header value `foobar` and expects the filter to reject the request, either as
    /// an undecodable header or as an unknown token in the freshly built test database.
    #[tokio::test]
    async fn api_token_filter_reject() {
        let db_url = DbBuilder::new().build_async().await.unwrap();
        let filter = api_token_filter(new_test_shared_context_from_url(&db_url)).await;

        let outcome = warp::test::request()
            .header(API_TOKEN_HEADER, "foobar")
            .filter(&filter)
            .await;

        assert!(outcome.is_err());
    }
}