1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
use std::time::{Duration, Instant};

use governor::{
    clock::DefaultClock,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter,
};

use color_eyre::Report;
use serde::Deserialize;
use tracing::debug;

use super::{HttpClient, HttpResponse};

/// An [`HttpClient`] wrapper that throttles requests before forwarding them to
/// the wrapped client.
#[derive(Debug)]
pub struct RateLimitClient<T>
where
    T: HttpClient,
{
    /// The client that actually performs the requests.
    inner: T,
    /// Direct (un-keyed) in-memory limiter; `None` means requests are forwarded
    /// without any throttling.
    limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}

impl<T: HttpClient> RateLimitClient<T> {
    /// Wraps `inner` with a limiter that replenishes one request permit every
    /// `request_interval_ms` milliseconds.
    ///
    /// A `request_interval_ms` of `0` disables rate limiting entirely, so every
    /// request is forwarded to `inner` immediately.
    pub fn new(inner: T, request_interval_ms: u64) -> Self {
        let limiter = (request_interval_ms != 0).then(|| {
            let period = Duration::from_millis(request_interval_ms);
            // `Quota::with_period` only returns `None` for a zero period, and the
            // `!= 0` check above has already ruled that out.
            let quota = Quota::with_period(period)
                .expect("request interval is non-zero, so the quota period is valid");
            RateLimiter::direct(quota)
        });
        Self { inner, limiter }
    }
}

impl<T: HttpClient> HttpClient for RateLimitClient<T> {
    /// Issues a GET for `path` through the wrapped client, first waiting for the
    /// rate limiter (if one is configured) to admit the request.
    ///
    /// While the limiter rejects the request, the calling thread sleeps for the
    /// wait time the limiter reports and then checks again.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped client's `get` returns.
    fn get<S>(&self, path: &str) -> Result<HttpResponse<S>, Report>
    where
        S: for<'a> Deserialize<'a>,
    {
        if let Some(limiter) = self.limiter.as_ref() {
            loop {
                match limiter.check() {
                    Ok(_) => break,
                    Err(not_until) => {
                        // Block the current thread until the limiter says a permit
                        // should be available, then re-check.
                        let time = not_until.wait_time_from(Instant::now());
                        debug!("waiting for rate limit: {time:?}");
                        std::thread::sleep(time);
                    }
                }
            }
        }
        self.inner.get(path)
    }
}