mod data;
use std::{fmt::Debug, str::FromStr};
use asn1_rs::Oid;
use data::{get_extension_type_from_int, get_oid_from_int, EXTENSIONS_LOOKUP};
use minicbor::{encode::Write, Decode, Decoder, Encode, Encoder};
use serde::{Deserialize, Deserializer, Serialize};
use strum_macros::EnumDiscriminants;
use super::alt_name::AlternativeName;
use crate::oid::{C509oid, C509oidRegistered};
/// A single certificate extension as handled by the C509 encoding.
///
/// Stores the extension OID (bound to the extensions integer-to-OID lookup table, which
/// decides whether the OID has a registered integer ID), its criticality flag, and its value.
#[derive(Debug, Clone, PartialEq)]
pub struct Extension {
    /// Extension OID together with the lookup table it was registered against
    /// (built in [`Extension::new`] from `EXTENSIONS_LOOKUP`).
    registered_oid: C509oidRegistered,
    /// Whether the extension is marked critical.
    critical: bool,
    /// The extension's value.
    value: ExtensionValue,
}
impl Extension {
    /// Builds an extension from an OID, its value and the criticality flag.
    ///
    /// The OID is wrapped in a [`C509oidRegistered`] bound to the extensions lookup table
    /// (`EXTENSIONS_LOOKUP`), and `pen_encoded()` is applied to it (presumably enabling
    /// PEN-based OID encoding — see `C509oidRegistered`).
    #[must_use]
    pub fn new(oid: Oid<'static>, value: ExtensionValue, critical: bool) -> Self {
        let lookup_table = EXTENSIONS_LOOKUP.get_int_to_oid_table();
        let registered_oid = C509oidRegistered::new(oid, lookup_table).pen_encoded();
        Self {
            registered_oid,
            critical,
            value,
        }
    }

    /// Borrows the extension's value.
    #[must_use]
    pub fn get_value(&self) -> &ExtensionValue {
        &self.value
    }

    /// Reports whether the extension is flagged as critical.
    #[must_use]
    pub fn get_critical(&self) -> bool {
        self.critical
    }

    /// Borrows the extension's OID along with the lookup table it is registered against.
    #[must_use]
    pub fn get_registered_oid(&self) -> &C509oidRegistered {
        &self.registered_oid
    }
}
/// Serde wire shape of an [`Extension`]: the OID in string form (parsed back with
/// `Oid::from_str` on deserialization), the value, and the criticality flag.
///
/// The field names and their order define the serialized representation.
#[derive(Debug, Deserialize, Serialize)]
struct Helper {
    oid: String,
    value: ExtensionValue,
    critical: bool,
}
impl<'de> Deserialize<'de> for Extension {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> {
let helper = Helper::deserialize(deserializer)?;
let oid =
Oid::from_str(&helper.oid).map_err(|e| serde::de::Error::custom(format!("{e:?}")))?;
Ok(Extension::new(oid, helper.value, helper.critical))
}
}
impl Serialize for Extension {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: serde::Serializer {
let helper = Helper {
oid: self.registered_oid.get_c509_oid().get_oid().to_string(),
value: self.value.clone(),
critical: self.critical,
};
helper.serialize(serializer)
}
}
impl Encode<()> for Extension {
    /// Writes the extension as CBOR.
    ///
    /// If the OID has an integer ID in the lookup table, that ID is written as an `i16`,
    /// negated when the extension is critical. Otherwise the OID itself is written,
    /// followed by a `true` bool only when critical. In both cases the value follows.
    ///
    /// # Errors
    /// Propagates encoder errors and any error from encoding the value.
    fn encode<W: Write>(
        &self, e: &mut Encoder<W>, ctx: &mut (),
    ) -> Result<(), minicbor::encode::Error<W::Error>> {
        let c509_oid = self.registered_oid.get_c509_oid();
        let registered_id = self
            .registered_oid
            .get_table()
            .get_map()
            .get_by_right(&c509_oid.get_oid())
            .copied();

        match registered_id {
            Some(id) => {
                // Criticality is carried by the sign of the registered ID.
                let signed_id = if self.critical { -id } else { id };
                e.i16(signed_id)?;
            },
            None => {
                c509_oid.encode(e, ctx)?;
                if self.critical {
                    e.bool(true)?;
                }
            },
        }

        self.value.encode(e, ctx)
    }
}
impl Decode<'_, ()> for Extension {
fn decode(d: &mut Decoder<'_>, ctx: &mut ()) -> Result<Self, minicbor::decode::Error> {
match d.datatype()? {
minicbor::data::Type::U8
| minicbor::data::Type::U16
| minicbor::data::Type::I8
| minicbor::data::Type::I16 => {
let int_value = d.i16()?;
let abs_int_value = int_value.abs();
let oid =
get_oid_from_int(abs_int_value).map_err(minicbor::decode::Error::message)?;
let value_type = get_extension_type_from_int(abs_int_value)
.map_err(minicbor::decode::Error::message)?;
let extension_value = ExtensionValue::decode(d, &mut value_type.get_type())?;
Ok(Extension::new(
oid.to_owned(),
extension_value,
int_value.is_negative(),
))
},
_ => {
let c509_oid = C509oid::decode(d, ctx)?;
let critical = if d.datatype()? == minicbor::data::Type::Bool {
d.bool()?
} else {
false
};
let extension_value = ExtensionValue::Bytes(d.bytes()?.to_vec());
Ok(Extension::new(
c509_oid.get_oid(),
extension_value,
critical,
))
},
}
}
}
/// Decoding context for [`ExtensionValue`]: tells its `Decode` impl which
/// [`ExtensionValueType`] variant to read.
trait ExtensionValueTypeTrait {
    /// Returns the value type to decode.
    fn get_type(&self) -> ExtensionValueType;
}
/// The value carried by an [`Extension`].
///
/// `EnumDiscriminants` generates the fieldless companion enum [`ExtensionValueType`], which
/// is used as decoding context to pick the variant to read.
// Allowed because the type name repeats the module name (the condition this lint flags).
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, PartialEq, EnumDiscriminants, Deserialize, Serialize)]
#[strum_discriminants(name(ExtensionValueType))]
#[serde(rename_all = "snake_case")]
pub enum ExtensionValue {
    /// Integer value, encoded as a CBOR integer.
    Int(i64),
    /// Raw bytes, encoded as a CBOR byte string.
    Bytes(Vec<u8>),
    /// Alternative-name value, encoded/decoded by [`AlternativeName`].
    AlternativeName(AlternativeName),
    /// Placeholder for unsupported values; both encoding and decoding it return an error.
    Unsupported,
}
impl ExtensionValueTypeTrait for ExtensionValueType {
    /// A value type is its own decoding context, so this just hands back a copy.
    fn get_type(&self) -> ExtensionValueType {
        let &kind = self;
        kind
    }
}
impl Encode<()> for ExtensionValue {
fn encode<W: Write>(
&self, e: &mut Encoder<W>, ctx: &mut (),
) -> Result<(), minicbor::encode::Error<W::Error>> {
match self {
ExtensionValue::Int(value) => {
e.i64(*value)?;
},
ExtensionValue::Bytes(value) => {
e.bytes(value)?;
},
ExtensionValue::AlternativeName(value) => {
value.encode(e, ctx)?;
},
ExtensionValue::Unsupported => {
return Err(minicbor::encode::Error::message(
"Cannot encode unsupported Extension value",
));
},
}
Ok(())
}
}
impl<C> Decode<'_, C> for ExtensionValue
where C: ExtensionValueTypeTrait + Debug
{
fn decode(d: &mut Decoder<'_>, ctx: &mut C) -> Result<Self, minicbor::decode::Error> {
match ctx.get_type() {
ExtensionValueType::Int => {
let value = d.i64()?;
Ok(ExtensionValue::Int(value))
},
ExtensionValueType::Bytes => {
let value = d.bytes()?.to_vec();
Ok(ExtensionValue::Bytes(value))
},
ExtensionValueType::AlternativeName => {
let value = AlternativeName::decode(d, &mut ())?;
Ok(ExtensionValue::AlternativeName(value))
},
ExtensionValueType::Unsupported => {
Err(minicbor::decode::Error::message(
"Cannot decode Unsupported extension value",
))
},
}
}
}
#[cfg(test)]
mod test_extension {
    use asn1_rs::oid;

    use super::*;

    /// Encodes `ext` into a fresh buffer, panicking if encoding fails.
    fn encode_to_vec(ext: &Extension) -> Vec<u8> {
        let mut buffer = Vec::new();
        ext.encode(&mut Encoder::new(&mut buffer), &mut ())
            .expect("Failed to encode Extension");
        buffer
    }

    /// Decodes a single `Extension` from `bytes`.
    fn decode_from(bytes: &[u8]) -> Result<Extension, minicbor::decode::Error> {
        Extension::decode(&mut Decoder::new(bytes), &mut ())
    }

    /// Asserts that `ext` encodes to `expected_hex` and decodes back to an equal value.
    fn assert_round_trip(ext: &Extension, expected_hex: &str) {
        let encoded = encode_to_vec(ext);
        assert_eq!(hex::encode(&encoded), expected_hex);
        let decoded = decode_from(&encoded).expect("Failed to decode Extension");
        assert_eq!(&decoded, ext);
    }

    #[test]
    fn int_oid_inhibit_anypolicy_value_unsigned_int() {
        let ext = Extension::new(oid!(2.5.29 .54), ExtensionValue::Int(2), false);
        assert_round_trip(&ext, "181e02");
    }

    #[test]
    fn unwrapped_oid_critical_key_usage_value_int() {
        let ext = Extension::new(oid!(2.5.29 .15), ExtensionValue::Int(-1), true);
        assert_round_trip(&ext, "2120");
    }

    #[test]
    fn oid_unwrapped_value_bytes_string() {
        let ext = Extension::new(
            oid!(2.16.840 .1 .101 .3 .4 .2 .1),
            ExtensionValue::Bytes("test".as_bytes().to_vec()),
            false,
        );
        assert_round_trip(&ext, "496086480165030402014474657374");
    }

    #[test]
    fn encode_decode_mismatch_type() {
        // 2.5.29.14 is registered (encodes as ID 1) with, presumably, a non-Int value type,
        // so decoding the Int value that was encoded must fail.
        let ext = Extension::new(oid!(2.5.29 .14), ExtensionValue::Int(2), false);
        let encoded = encode_to_vec(&ext);
        assert_eq!(hex::encode(&encoded), "0102");
        decode_from(&encoded).expect_err("Failed to decode Extension");
    }
}