1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
//! Non-interactive Zero Knowledge proof of Discrete Logarithm
//! EQuality (DLEQ).
//!
//! The proof is the following:
//!
//! `NIZK{(base_1, base_2, point_1, point_2), (dlog): point_1 = base_1^dlog AND point_2 = base_2^dlog}`
//!
//! The statement consists of the two bases, `base_1` and `base_2`, and the two
//! points, `point_1` and `point_2`. The witness, on the other hand,
//! is the discrete logarithm, `dlog`.
#![allow(clippy::many_single_char_names)]
use super::challenge_context::ChallengeContext;
use crate::ec::ristretto255::{GroupElement, Scalar};
use rand_core::{CryptoRng, RngCore};

/// Proof of correct decryption, expressed as a non-interactive proof of
/// discrete logarithm equality (DLEQ).
///
/// The proof is stored in its compact form: the challenge together with the
/// response, rather than the two announcements together with the response.
///
/// Note: if the goal is to reduce the size of a proof, it is better to store the challenge
/// and the response. If on the other hand we want to allow for batch verification of
/// proofs, we should store the announcements and the response.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Zkp {
    // Challenge obtained from `ChallengeContext::first_challenge` over the two
    // announcements (see `generate`).
    challenge: Scalar,
    // Response computed in `generate` as `dlog * challenge + w`, where `w` is the
    // random scalar used to build the announcements.
    response: Scalar,
}

impl Zkp {
    /// Length in bytes of a serialised proof: the challenge followed by the response.
    pub const BYTES_LEN: usize = 2 * Scalar::BYTES_LEN;

    /// Generate a DLEQ proof.
    ///
    /// Proves that `point_1 = base_1 * dlog` and `point_2 = base_2 * dlog` for the
    /// same `dlog`. A random scalar drawn from `rng` is used to build the two
    /// announcements; the challenge is then derived from the statement and the
    /// announcements through [`ChallengeContext`], and the response is
    /// `dlog * challenge + nonce`.
    pub fn generate<R>(
        base_1: &GroupElement,
        base_2: &GroupElement,
        point_1: &GroupElement,
        point_2: &GroupElement,
        dlog: &Scalar,
        rng: &mut R,
    ) -> Self
    where
        R: CryptoRng + RngCore,
    {
        // One-time random scalar shared by both announcements.
        let nonce = Scalar::random(rng);
        let commitment_1 = base_1 * &nonce;
        let commitment_2 = base_2 * &nonce;

        let mut context = ChallengeContext::new(base_1, base_2, point_1, point_2);
        let challenge = context.first_challenge(&commitment_1, &commitment_2);
        let response = dlog * &challenge + &nonce;

        Self {
            challenge,
            response,
        }
    }

    /// Verify a DLEQ proof.
    ///
    /// Recomputes both announcements as `base_i * response - point_i * challenge`,
    /// derives the challenge again from the statement and those announcements, and
    /// returns `true` only if it equals the challenge stored in the proof.
    pub fn verify(
        &self,
        base_1: &GroupElement,
        base_2: &GroupElement,
        point_1: &GroupElement,
        point_2: &GroupElement,
    ) -> bool {
        let commitment_1 = base_1 * &self.response - point_1 * &self.challenge;
        let commitment_2 = base_2 * &self.response - point_2 * &self.challenge;

        let mut context = ChallengeContext::new(base_1, base_2, point_1, point_2);
        let expected = context.first_challenge(&commitment_1, &commitment_2);

        // A variable-time comparison is fine here: no need for constant time
        // equality because of the hash in challenge().
        self.challenge == expected
    }

    /// Serialise the proof into a fixed-size array of [`Self::BYTES_LEN`] bytes.
    pub fn to_bytes(&self) -> [u8; Self::BYTES_LEN] {
        let mut bytes = [0u8; Self::BYTES_LEN];
        self.write_to_bytes(&mut bytes);
        bytes
    }

    /// Write the serialised proof into `output`: challenge first, then response.
    ///
    /// # Panics
    ///
    /// Panics if `output.len()` is not exactly [`Self::BYTES_LEN`].
    pub fn write_to_bytes(&self, output: &mut [u8]) {
        assert_eq!(output.len(), Self::BYTES_LEN);
        let (challenge_out, response_out) = output.split_at_mut(Scalar::BYTES_LEN);
        challenge_out.copy_from_slice(&self.challenge.to_bytes());
        response_out.copy_from_slice(&self.response.to_bytes());
    }

    /// Deserialise a proof from `slice`.
    ///
    /// Returns `None` if `slice` is not exactly [`Self::BYTES_LEN`] bytes long, or
    /// if `Scalar::from_bytes` rejects either half.
    pub fn from_bytes(slice: &[u8]) -> Option<Self> {
        if slice.len() != Self::BYTES_LEN {
            return None;
        }
        let (challenge_in, response_in) = slice.split_at(Scalar::BYTES_LEN);
        let challenge = Scalar::from_bytes(challenge_in)?;
        let response = Scalar::from_bytes(response_in)?;

        Some(Self {
            challenge,
            response,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::OsRng;

    /// A valid DLEQ statement together with its witness.
    struct Fixture {
        base_1: GroupElement,
        base_2: GroupElement,
        point_1: GroupElement,
        point_2: GroupElement,
        dlog: Scalar,
    }

    /// Builds a statement over two bases hashed from *different* inputs, so the
    /// tests exercise the two-base case rather than `base_1 == base_2`.
    fn fixture(rng: &mut OsRng) -> Fixture {
        let dlog = Scalar::random(rng);
        let base_1 = GroupElement::from_hash(&[0u8]);
        let base_2 = GroupElement::from_hash(&[1u8]);
        let point_1 = &base_1 * &dlog;
        let point_2 = &base_2 * &dlog;

        Fixture {
            base_1,
            base_2,
            point_1,
            point_2,
            dlog,
        }
    }

    /// A proof generated for a valid statement verifies against that statement.
    #[test]
    pub fn it_works() {
        let mut r: OsRng = OsRng;
        let f = fixture(&mut r);

        let proof = Zkp::generate(&f.base_1, &f.base_2, &f.point_1, &f.point_2, &f.dlog, &mut r);

        assert!(proof.verify(&f.base_1, &f.base_2, &f.point_1, &f.point_2));
    }

    /// A proof must not verify against a statement whose second point uses a
    /// different discrete logarithm.
    #[test]
    fn rejects_wrong_statement() {
        let mut r: OsRng = OsRng;
        let f = fixture(&mut r);

        let proof = Zkp::generate(&f.base_1, &f.base_2, &f.point_1, &f.point_2, &f.dlog, &mut r);

        // A fresh random scalar equals `dlog` only with negligible probability.
        let unrelated_point = &f.base_2 * &Scalar::random(&mut r);

        assert!(!proof.verify(&f.base_1, &f.base_2, &f.point_1, &unrelated_point));
    }

    /// Serialising and deserialising yields an equal proof that still verifies.
    #[test]
    fn serialisation() {
        let mut r: OsRng = OsRng;
        let f = fixture(&mut r);

        let proof = Zkp::generate(&f.base_1, &f.base_2, &f.point_1, &f.point_2, &f.dlog, &mut r);

        let serialised_proof = proof.to_bytes();
        let deserialised_proof =
            Zkp::from_bytes(&serialised_proof).expect("a freshly serialised proof must decode");

        assert_eq!(deserialised_proof, proof);
        assert!(deserialised_proof.verify(&f.base_1, &f.base_2, &f.point_1, &f.point_2));
    }

    /// Input of the wrong length is rejected instead of being truncated or padded.
    #[test]
    fn from_bytes_rejects_wrong_length() {
        let mut r: OsRng = OsRng;
        let f = fixture(&mut r);

        let proof = Zkp::generate(&f.base_1, &f.base_2, &f.point_1, &f.point_2, &f.dlog, &mut r);
        let serialised_proof = proof.to_bytes();

        assert!(Zkp::from_bytes(&serialised_proof[..Zkp::BYTES_LEN - 1]).is_none());
        assert!(Zkp::from_bytes(&[]).is_none());
    }
}