1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
use chain_impl_mockchain::rewards;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{fmt, num::NonZeroU64, str::FromStr};
use thiserror::Error;

/// A `numerator/denominator` ratio as stored on the blockchain.
///
/// One use is expressing the ratio settings carried by a stake pool
/// registration certificate. The wrapped value always has a non-zero
/// denominator (it is typed as [`NonZeroU64`]).
#[derive(Clone, Copy, Debug)]
pub struct Ratio(rewards::Ratio);

impl Ratio {
    /// Builds a ratio from a numerator and an already-validated non-zero
    /// denominator. Being a `const fn`, it can be used to define constants.
    pub const fn new(numerator: u64, denominator: NonZeroU64) -> Self {
        Ratio(rewards::Ratio { numerator, denominator })
    }

    /// Builds a ratio from two plain integers.
    ///
    /// Returns `None` when `denominator` is zero, `Some(ratio)` otherwise.
    pub fn new_checked(numerator: u64, denominator: u64) -> Option<Self> {
        let non_zero = NonZeroU64::new(denominator)?;
        Some(Self::new(numerator, non_zero))
    }
}

/* ---------------- Display ------------------------------------------------ */

impl fmt::Display for Ratio {
    /// Writes the ratio as `numerator/denominator`, e.g. `3/4`; this is the
    /// same textual form accepted by the `FromStr` implementation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner = &self.0;
        write!(f, "{}/{}", inner.numerator, inner.denominator)
    }
}

#[derive(Clone, Debug, Error)]
pub enum ParseRatioError {
    #[error("{0}")]
    InvalidInt(#[from] std::num::ParseIntError),

    #[error("Missing numerator part of the Ratio")]
    MissingNumerator,

    #[error("Missing denominator part of the Ratio")]
    MissingDenominator,
}

impl FromStr for Ratio {
    type Err = ParseRatioError;

    /// Parses a ratio written as `numerator/denominator` (e.g. `"3/4"`), the
    /// same format produced by the `Display` implementation.
    ///
    /// # Errors
    ///
    /// * [`ParseRatioError::MissingNumerator`] if nothing precedes the `/`
    ///   (e.g. `""` or `"/4"`);
    /// * [`ParseRatioError::MissingDenominator`] if there is no `/` or nothing
    ///   follows it (e.g. `"3"` or `"3/"`);
    /// * [`ParseRatioError::InvalidInt`] if either part is not a valid unsigned
    ///   integer, if the denominator is zero, or if extra `/`-separated parts are
    ///   present (e.g. `"1/2/3"`, whose denominator `"2/3"` is not an integer).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split on the first '/' only: everything after it must form the
        // denominator, so trailing components are rejected instead of being
        // silently ignored.
        let mut parts = s.splitn(2, '/');

        let numerator = parts
            .next()
            .filter(|part| !part.is_empty())
            .ok_or(ParseRatioError::MissingNumerator)?
            .parse::<u64>()?;

        let denominator = parts
            .next()
            .filter(|part| !part.is_empty())
            .ok_or(ParseRatioError::MissingDenominator)?
            // Parsing straight into `NonZeroU64` rejects a zero denominator.
            .parse::<NonZeroU64>()?;

        Ok(Self::new(numerator, denominator))
    }
}

/* ---------------- Comparison ---------------------------------------------- */

impl PartialEq<Self> for Ratio {
    /// Structural equality: two ratios are equal only when both their
    /// numerators and their denominators match (so `1/2 != 2/4`).
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        (lhs.numerator, lhs.denominator) == (rhs.numerator, rhs.denominator)
    }
}

// Field-wise equality on integers is reflexive, so `Eq` holds.
impl Eq for Ratio {}

/* ---------------- AsRef -------------------------------------------------- */

impl AsRef<rewards::Ratio> for Ratio {
    fn as_ref(&self) -> &rewards::Ratio {
        &self.0
    }
}

/* ---------------- Conversion --------------------------------------------- */

impl From<rewards::Ratio> for Ratio {
    fn from(v: rewards::Ratio) -> Self {
        Ratio(v)
    }
}

impl From<Ratio> for rewards::Ratio {
    fn from(v: Ratio) -> Self {
        v.0
    }
}

/* ------------------- Serde ----------------------------------------------- */

impl Serialize for Ratio {
    /// Serializes the ratio as a string in its `Display` form
    /// (`numerator/denominator`), for both human-readable and binary formats.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(&text)
    }
}

impl<'de> Deserialize<'de> for Ratio {
    /// Deserializes a ratio from a string in `numerator/denominator` form.
    ///
    /// Errors raised by the underlying deserializer (e.g. the value is not a
    /// string) are propagated unchanged. A string that fails to parse as a
    /// [`Ratio`] is reported through `serde::de::Error::custom` using the
    /// [`ParseRatioError`] message.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error as _;

        // `String::deserialize` already yields `D::Error`; forward it as-is
        // instead of re-wrapping it with `custom`, which would discard the
        // deserializer's own error details.
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(D::Error::custom)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use quickcheck::{Arbitrary, Gen, TestResult};

    const NUMERATOR: u64 = 928_170;
    const DENOMINATOR: u64 = 1291;

    /// Fixed ratio shared by the deterministic tests below.
    fn sample_ratio() -> Ratio {
        // `new_checked` only fails on a zero denominator, so no `unsafe`
        // `new_unchecked` is needed to build this fixture.
        Ratio::new_checked(NUMERATOR, DENOMINATOR).expect("DENOMINATOR is non-zero")
    }

    impl Arbitrary for Ratio {
        fn arbitrary<G>(g: &mut G) -> Self
        where
            G: Gen,
        {
            // A generated zero denominator is replaced by 1 so every value is valid.
            Ratio(rewards::Ratio {
                numerator: Arbitrary::arbitrary(g),
                denominator: NonZeroU64::new(Arbitrary::arbitrary(g))
                    .unwrap_or_else(|| NonZeroU64::new(1).unwrap()),
            })
        }
    }

    #[test]
    fn value_display_as_u64() {
        assert_eq!(
            sample_ratio().to_string(),
            format!("{}/{}", NUMERATOR, DENOMINATOR)
        )
    }

    #[test]
    fn value_serde_as_u64() {
        assert_eq!(
            serde_yaml::to_string(&sample_ratio()).unwrap(),
            format!("---\n{}/{}\n", NUMERATOR, DENOMINATOR)
        );
    }

    #[test]
    fn parse_valid_ratio() {
        assert_eq!(
            "3/4".parse::<Ratio>().unwrap(),
            Ratio::new_checked(3, 4).unwrap()
        );
    }

    #[test]
    fn parse_rejects_zero_denominator() {
        assert!("1/0".parse::<Ratio>().is_err());
    }

    #[test]
    fn parse_rejects_non_numeric_parts() {
        assert!("abc/4".parse::<Ratio>().is_err());
        assert!("3/xyz".parse::<Ratio>().is_err());
        assert!("3".parse::<Ratio>().is_err());
    }

    quickcheck! {
        fn value_display_parse(value: Ratio) -> TestResult {
            let s = value.to_string();
            let value_dec: Ratio = s.parse().unwrap();

            TestResult::from_bool(value_dec == value)
        }

        fn value_serde_human_readable_encode_decode(value: Ratio) -> TestResult {
            let s = serde_yaml::to_string(&value).unwrap();
            let value_dec: Ratio = serde_yaml::from_str(&s).unwrap();

            TestResult::from_bool(value_dec == value)
        }

        fn value_serde_binary_encode_decode(value: Ratio) -> TestResult {
            let s = bincode::serialize(&value).unwrap();
            let value_dec: Ratio = bincode::deserialize(&s).unwrap();

            TestResult::from_bool(value_dec == value)
        }
    }
}