pub mod alt_name;
pub mod extension;
use std::fmt::Debug;
use asn1_rs::{oid, Oid};
use extension::{Extension, ExtensionValue};
use minicbor::{encode::Write, Decode, Decoder, Encode, Encoder};
use serde::{Deserialize, Serialize};
/// OID 2.5.29.15 (`id-ce-keyUsage`). A lone extension with this OID is written in, and
/// read back from, the compact single-integer CBOR form by the `Encode`/`Decode` impls.
static KEY_USAGE_OID: Oid<'static> = oid!(2.5.29 .15);
/// An ordered list of certificate extensions.
///
/// In CBOR this is either an array of the individual extensions or, when the list holds
/// exactly one KeyUsage extension, a single integer (see the `Encode` impl below).
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct Extensions(Vec<Extension>);
impl Default for Extensions {
fn default() -> Self {
Self::new()
}
}
impl Extensions {
#[must_use]
pub fn new() -> Self {
Self(Vec::new())
}
pub fn add_ext(&mut self, extension: Extension) {
self.0.push(extension);
}
#[must_use]
pub fn get_inner(&self) -> &Vec<Extension> {
&self.0
}
}
impl Encode<()> for Extensions {
fn encode<W: Write>(
&self, e: &mut Encoder<W>, ctx: &mut (),
) -> Result<(), minicbor::encode::Error<W::Error>> {
if let Some(extension) = self.0.first() {
if self.0.len() == 1
&& extension.get_registered_oid().get_c509_oid().get_oid() == KEY_USAGE_OID
{
match extension.get_value() {
ExtensionValue::Int(value) => {
let ku_value = if extension.get_critical() {
-value
} else {
*value
};
e.i64(ku_value)?;
return Ok(());
},
_ => {
return Err(minicbor::encode::Error::message(
"KeyUsage extension value should be an integer",
));
},
}
}
}
e.array(self.0.len() as u64)?;
for extension in &self.0 {
extension.encode(e, ctx)?;
}
Ok(())
}
}
impl Decode<'_, ()> for Extensions {
fn decode(d: &mut Decoder<'_>, _ctx: &mut ()) -> Result<Self, minicbor::decode::Error> {
if d.datatype()? == minicbor::data::Type::U8 || d.datatype()? == minicbor::data::Type::I8 {
let critical = d.datatype()? == minicbor::data::Type::I8;
let value = d.i64()?.abs();
let extension_value = ExtensionValue::Int(value);
let mut extensions = Extensions::new();
extensions.add_ext(Extension::new(
KEY_USAGE_OID.clone(),
extension_value,
critical,
));
return Ok(extensions);
}
let len = d
.array()?
.ok_or_else(|| minicbor::decode::Error::message("Failed to get array length"))?;
let mut extensions = Extensions::new();
for _ in 0..len {
let extension = Extension::decode(d, &mut ())?;
extensions.add_ext(extension);
}
Ok(extensions)
}
}
#[cfg(test)]
mod test_extensions {
    use super::*;

    /// Encodes `exts` and checks that the CBOR bytes match `expected_hex`. It then
    /// decodes those bytes and checks that the result equals `exts`.
    fn assert_round_trip(exts: &Extensions, expected_hex: &str) {
        let mut buffer = Vec::new();
        let mut encoder = Encoder::new(&mut buffer);
        exts.encode(&mut encoder, &mut ())
            .expect("Failed to encode Extensions");
        assert_eq!(hex::encode(&buffer), expected_hex);

        let mut decoder = Decoder::new(&buffer);
        let decoded =
            Extensions::decode(&mut decoder, &mut ()).expect("Failed to decode Extensions");
        assert_eq!(&decoded, exts);
    }

    #[test]
    fn one_extension_key_usage() {
        let mut exts = Extensions::new();
        exts.add_ext(Extension::new(
            oid!(2.5.29 .15),
            ExtensionValue::Int(2),
            false,
        ));
        assert_round_trip(&exts, "02");
    }

    #[test]
    fn one_extension_key_usage_set_critical() {
        let mut exts = Extensions::new();
        exts.add_ext(Extension::new(
            oid!(2.5.29 .15),
            ExtensionValue::Int(2),
            true,
        ));
        assert_round_trip(&exts, "21");
    }

    #[test]
    fn multiple_extensions() {
        let mut exts = Extensions::new();
        exts.add_ext(Extension::new(
            oid!(2.5.29 .15),
            ExtensionValue::Int(2),
            false,
        ));
        exts.add_ext(Extension::new(
            oid!(2.5.29 .14),
            ExtensionValue::Bytes([1, 2, 3, 4].to_vec()),
            false,
        ));
        assert_round_trip(&exts, "820202014401020304");
    }

    #[test]
    fn zero_extensions() {
        assert_round_trip(&Extensions::new(), "80");
    }
}