1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
//! C509 Extension as a part of `TBSCertificate` used in C509 Certificate.
//!
//! Extension fallback of C509 OID extension:
//! if the given OID is not found in the registered OID table, it will be encoded as a PEN OID.
//! If the OID is not a PEN OID, it will be encoded as an unwrapped OID.
//!
//! Extensions and Extension can be encoded as follows:
//!
//! ```cddl
//! Extensions = [ * Extension ] / int
//! Extension = ( extensionID: int, extensionValue: any ) //
//! ( extensionID: ~oid, ? critical: true,
//!   extensionValue: bytes ) //
//! ( extensionID: pen, ? critical: true,
//!   extensionValue: bytes )
//! ```
//!
//! For more information about Extensions,
//! visit [C509 Certificate](https://datatracker.ietf.org/doc/draft-ietf-cose-cbor-encoded-cert/09/)

pub mod alt_name;
pub mod extension;

use std::fmt::Debug;

use asn1_rs::{oid, Oid};
use extension::{Extension, ExtensionValue};
use minicbor::{encode::Write, Decode, Decoder, Encode, Encoder};
use serde::{Deserialize, Serialize};

/// OID of the `KeyUsage` extension (`2.5.29.15`).
///
/// Compared against the first (and only) extension during encoding to decide whether the
/// compact single-integer form is used, and attached to the `Extension` rebuilt when that
/// integer form is decoded.
static KEY_USAGE_OID: Oid<'static> = oid!(2.5.29 .15);

/// A struct of C509 Extensions containing a vector of `Extension`.
///
/// On the wire this is either a CBOR array of `Extension`, or — when the list holds
/// exactly one `KeyUsage` extension — a single signed integer whose absolute value is the
/// `KeyUsage` value and whose negative sign marks the extension as critical.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct Extensions(Vec<Extension>);

impl Default for Extensions {
    fn default() -> Self {
        Self::new()
    }
}

impl Extensions {
    /// Create a new instance of `Extensions` as empty vector.
    #[must_use]
    pub fn new() -> Self {
        Self(Vec::new())
    }

    /// Add an `Extension` to the `Extensions`.
    pub fn add_ext(&mut self, extension: Extension) {
        self.0.push(extension);
    }

    /// Get the inner vector of `Extensions`.
    #[must_use]
    pub fn get_inner(&self) -> &Vec<Extension> {
        &self.0
    }
}

impl Encode<()> for Extensions {
    /// Encode the extensions as CBOR.
    ///
    /// When the list contains exactly one extension and it is `KeyUsage`, the compact form
    /// is written: a single integer whose absolute value is the `KeyUsage` value and whose
    /// sign is negative when the extension is critical. Otherwise a CBOR array of
    /// `Extension` is written.
    ///
    /// # Errors
    ///
    /// Fails if a lone `KeyUsage` extension does not hold an `ExtensionValue::Int`, or if
    /// writing to the encoder or encoding an inner `Extension` fails.
    fn encode<W: Write>(
        &self, e: &mut Encoder<W>, ctx: &mut (),
    ) -> Result<(), minicbor::encode::Error<W::Error>> {
        // Compact form: a lone KeyUsage extension collapses into one signed integer.
        if let [single] = self.0.as_slice() {
            if single.get_registered_oid().get_c509_oid().get_oid() == KEY_USAGE_OID {
                let magnitude = match single.get_value() {
                    ExtensionValue::Int(v) => *v,
                    _ => {
                        return Err(minicbor::encode::Error::message(
                            "KeyUsage extension value should be an integer",
                        ));
                    },
                };
                // The sign carries the critical flag: negative means critical.
                let signed = if single.get_critical() {
                    -magnitude
                } else {
                    magnitude
                };
                e.i64(signed)?;
                return Ok(());
            }
        }

        // General form: a definite-length array of `Extension`.
        e.array(self.0.len() as u64)?;
        self.0.iter().try_for_each(|ext| ext.encode(e, ctx))
    }
}

impl Decode<'_, ()> for Extensions {
    fn decode(d: &mut Decoder<'_>, _ctx: &mut ()) -> Result<Self, minicbor::decode::Error> {
        // If only KeyUsage is in the extension -> will only contain an int
        if d.datatype()? == minicbor::data::Type::U8 || d.datatype()? == minicbor::data::Type::I8 {
            // Check if it's a negative number (critical extension)
            let critical = d.datatype()? == minicbor::data::Type::I8;
            // Note that 'KeyUsage' BIT STRING is interpreted as an unsigned integer,
            // so we can absolute the value
            let value = d.i64()?.abs();

            let extension_value = ExtensionValue::Int(value);
            let mut extensions = Extensions::new();
            extensions.add_ext(Extension::new(
                KEY_USAGE_OID.clone(),
                extension_value,
                critical,
            ));
            return Ok(extensions);
        }
        // Handle array of extensions
        let len = d
            .array()?
            .ok_or_else(|| minicbor::decode::Error::message("Failed to get array length"))?;
        let mut extensions = Extensions::new();
        for _ in 0..len {
            let extension = Extension::decode(d, &mut ())?;
            extensions.add_ext(extension);
        }

        Ok(extensions)
    }
}

// ------------------Test----------------------

#[cfg(test)]
mod test_extensions {
    use super::*;

    #[test]
    fn one_extension_key_usage() {
        let mut buffer = Vec::new();
        let mut encoder = Encoder::new(&mut buffer);

        let mut exts = Extensions::new();
        exts.add_ext(Extension::new(
            oid!(2.5.29 .15),
            ExtensionValue::Int(2),
            false,
        ));
        exts.encode(&mut encoder, &mut ())
            .expect("Failed to encode Extensions");
        // 1 extension
        // value 2 : 0x02
        assert_eq!(hex::encode(buffer.clone()), "02");

        let mut decoder = Decoder::new(&buffer);
        let decoded_exts =
            Extensions::decode(&mut decoder, &mut ()).expect("Failed to decode Extensions");
        assert_eq!(decoded_exts, exts);
    }

    #[test]
    fn one_extension_key_usage_set_critical() {
        let mut buffer = Vec::new();
        let mut encoder = Encoder::new(&mut buffer);

        let mut exts = Extensions::new();
        exts.add_ext(Extension::new(
            oid!(2.5.29 .15),
            ExtensionValue::Int(2),
            true,
        ));
        exts.encode(&mut encoder, &mut ())
            .expect("Failed to encode Extensions");
        // 1 extension
        // value -2 : 0x21
        assert_eq!(hex::encode(buffer.clone()), "21");

        let mut decoder = Decoder::new(&buffer);
        let decoded_exts =
            Extensions::decode(&mut decoder, &mut ()).expect("Failed to decode Extensions");
        assert_eq!(decoded_exts, exts);
    }

    #[test]
    fn multiple_extensions() {
        let mut buffer = Vec::new();
        let mut encoder = Encoder::new(&mut buffer);

        let mut exts = Extensions::new();
        exts.add_ext(Extension::new(
            oid!(2.5.29 .15),
            ExtensionValue::Int(2),
            false,
        ));

        exts.add_ext(Extension::new(
            oid!(2.5.29 .14),
            ExtensionValue::Bytes([1, 2, 3, 4].to_vec()),
            false,
        ));
        exts.encode(&mut encoder, &mut ())
            .expect("Failed to encode Extensions");

        // 2 extensions (array of 2): 0x82
        // KeyUsage (extensionID 0x02) with value 2 (0x02): 0x0202
        // SubjectKeyIdentifier (extensionID 0x01) with value [1,2,3,4]
        // (bytes of length 4: 0x4401020304): 0x014401020304
        assert_eq!(hex::encode(buffer.clone()), "820202014401020304");

        let mut decoder = Decoder::new(&buffer);
        let decoded_exts =
            Extensions::decode(&mut decoder, &mut ()).expect("Failed to decode Extensions");
        assert_eq!(decoded_exts, exts);
    }

    #[test]
    fn one_extension_not_key_usage() {
        let mut buffer = Vec::new();
        let mut encoder = Encoder::new(&mut buffer);

        let mut exts = Extensions::new();
        exts.add_ext(Extension::new(
            oid!(2.5.29 .14),
            ExtensionValue::Bytes([1, 2, 3, 4].to_vec()),
            false,
        ));
        exts.encode(&mut encoder, &mut ())
            .expect("Failed to encode Extensions");

        // A single extension other than KeyUsage still uses the array form.
        // 1 extension (array of 1): 0x81
        // SubjectKeyIdentifier with value [1,2,3,4]: 0x014401020304
        assert_eq!(hex::encode(buffer.clone()), "81014401020304");

        let mut decoder = Decoder::new(&buffer);
        let decoded_exts =
            Extensions::decode(&mut decoder, &mut ()).expect("Failed to decode Extensions");
        assert_eq!(decoded_exts, exts);
    }

    #[test]
    fn zero_extensions() {
        let mut buffer = Vec::new();
        let mut encoder = Encoder::new(&mut buffer);

        let exts = Extensions::new();
        exts.encode(&mut encoder, &mut ())
            .expect("Failed to encode Extensions");
        assert_eq!(hex::encode(buffer.clone()), "80");

        let mut decoder = Decoder::new(&buffer);
        // Extensions can have 0 length
        let decoded_exts =
            Extensions::decode(&mut decoder, &mut ()).expect("Failed to decode Extensions");
        assert_eq!(decoded_exts, exts);
    }
}