byte_string_derive/
lib.rs

1//! Proc macros for generating conversions for byte string wrapper types.
2//! See documentation of [macro@byte_string] for more information.
3use proc_macro::TokenStream;
4use quote::ToTokens;
5use quote::quote;
6use syn::Generics;
7
8extern crate alloc;
9
10/// Proc macro that can generate multiple helper functions and traits for types that wrap an
11/// array, [Vec], or [BoundedVec] of bytes. What code is generated by the macro is controlled
12/// by passing the following arguments:
13/// - `debug`:
14///   implements [Debug] that uses hex to encode the bytes
15/// - `hex_serialize`:
16///   implements `serde::Serialize` that encode the bytes to a hex string as intermediate format.
17///   This implementation is useful for saving bytes in Json and similar formats.
18/// - `hex_deserialize`:
19///   implements `serde::Deserialize` that decodes the type from a string containig hex-encoded
20///   bytes. This implementation is useful for decoding hex data from Json and similar formats.
21/// - `from_num`:
22///   implements [`From<u64>`]
23/// - `from_bytes`:
24///   implements either [`From<&[u8]>`] or [`TryFrom<&[u8]`] depending on whether the the type
25///   can be infallibly ([Vec]) or fallibly (array, [BoundedVec]) cast from a byte slice.
26/// - `decode_hex`:
27///   adds functions `decode_hex` and `decode_hex_unsafe` that decode the type from a [&str]
28///   and a [FromStr] implementation equivalent to calling `decode_hex`.
29/// - `to_hex_string`:
30///   adds a function `to_hex_string` that returns the inner bytes as a hex string.
31/// - `as_ref`: generates [`AsRef<[u8]>`] implementation
32///
33/// _Note_: As the code generated by this macro is meant to be `no_std`-compatible, some of the
34///         options require `alloc` crate to be available in the scope where the code is generated.
35///
36/// # Example
37///
38/// ```rust
39/// extern crate alloc;
40/// use byte_string_derive::byte_string;
41/// use sp_core::{ ConstU32, bounded_vec::BoundedVec };
42///
43/// #[byte_string(debug, hex_serialize, hex_deserialize, from_num, from_bytes, decode_hex, to_hex_string, as_ref)]
44/// pub struct MyArrayBytes([u8; 32]);
45///
46/// #[byte_string(debug, hex_serialize, hex_deserialize, from_num, from_bytes, decode_hex, to_hex_string, as_ref)]
47/// pub struct MyVecBytes(Vec<u8>);
48///
49/// #[byte_string(from_bytes)]
50/// pub struct MyBoundedVecBytes(BoundedVec<u8, ConstU32<32>>);
51/// ```
52///
53/// [BoundedVec]: https://paritytech.github.io/polkadot-sdk/master/sp_core/bounded_vec/struct.BoundedVec.html
54/// [FromStr]: std::str::FromStr
55#[proc_macro_attribute]
56pub fn byte_string(attr: TokenStream, input: TokenStream) -> TokenStream {
57	let ast = syn::parse(input).expect("Cannot parse source");
58
59	impl_byte_string_derive(attr, &ast)
60}
61
62enum SupportedType {
63	Array(syn::Type),
64	Vec,
65	BoundedVec(syn::Type),
66}
67
68fn impl_byte_string_derive(attr: TokenStream, ast: &syn::DeriveInput) -> TokenStream {
69	let name = &ast.ident;
70	let generics = &ast.generics;
71
72	let data = &ast.data;
73
74	let syn::Data::Struct(ds) = data else {
75		return quote! { compile_error!("byte_string is only defined for structs") }.into();
76	};
77
78	let ty = match &ds.fields {
79		syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
80			let field = fields.unnamed.first().unwrap();
81			field.ty.clone()
82		},
83		_ => return quote! { compile_error!("byte_string has to have one field") }.into(),
84	};
85
86	let supported_type = match &ty {
87		syn::Type::Array(_arr) => SupportedType::Array(ty),
88		syn::Type::Path(path) if !path.path.segments.is_empty() => {
89			let type_name = &path.path.segments.last().unwrap().ident;
90			if type_name == "Vec" {
91				SupportedType::Vec
92			} else if type_name == "BoundedVec" {
93				SupportedType::BoundedVec(ty)
94			} else {
95				return quote! { compile_error!("byte_string needs to wrap an array or (bounded) vec") }.into();
96			}
97		},
98		_ => {
99			return quote! { compile_error!("byte_string needs to wrap an array or (bounded) vec") }
100				.into();
101		},
102	};
103
104	let mut gen_token_stream = quote! {
105		#ast
106	};
107
108	for attr in attr.into_iter().map(|attr| attr.to_string()) {
109		let chunk: Box<dyn ToTokens> = match attr.as_str() {
110			"debug" => Box::from(gen_debug(name, generics)),
111			"hex_serialize" => Box::from(gen_hex_serialize(name, generics)),
112			"hex_deserialize" => Box::from(gen_hex_deserialize(name, &supported_type, generics)),
113			"from_num" => Box::from(gen_from_num(name, &supported_type, generics)),
114			"from_bytes" => Box::from(gen_from_bytes(name, &supported_type, generics)),
115			"decode_hex" => Box::from(gen_from_hex(name, &supported_type, generics)),
116			"to_hex_string" => Box::from(gen_to_hex(name, generics)),
117			"as_ref" => Box::from(gen_as_ref(name, generics)),
118			"," => continue,
119			_other => return quote! { compile_error!("Incorrect byte_string option") }.into(),
120		};
121
122		chunk.to_tokens(&mut gen_token_stream)
123	}
124
125	gen_token_stream.into()
126}
127
128fn gen_debug(name: &syn::Ident, generics: &Generics) -> impl ToTokens {
129	let format_str = format!("{}({{hex}})", name);
130	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
131	quote! {
132		impl #impl_generics core::fmt::Debug for #name #ty_generics #where_clause {
133			fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
134				let hex = sp_core::bytes::to_hex(&self.0, true);
135				return f.write_str(&alloc::format!(#format_str));
136			}
137		}
138	}
139}
140
141fn gen_hex_serialize(name: &syn::Ident, generics: &Generics) -> impl ToTokens {
142	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
143
144	quote! {
145		impl #impl_generics serde::Serialize for #name #ty_generics #where_clause {
146			fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
147			where
148				S: serde::Serializer,
149			{
150				let s = sp_core::bytes::to_hex(self.0.as_slice(), false);
151				serializer.serialize_str(s.as_str())
152			}
153		}
154	}
155}
156
157fn gen_hex_deserialize(
158	name: &syn::Ident,
159	ty: &SupportedType,
160	generics: &Generics,
161) -> impl ToTokens {
162	let type_params = generics.params.clone().into_iter();
163	let (_, ty_generics, where_clause) = generics.split_for_impl();
164
165	let created = match ty {
166		SupportedType::Array(ty) => {
167			quote! {
168				#name(<#ty>::try_from(inner)
169					  .map_err(|err| serde::de::Error::custom("Can't deserialize"))?)
170			}
171		},
172		SupportedType::BoundedVec(ty) => {
173			quote! {
174				#name(<#ty>::try_from(inner)
175					  .map_err(|_| serde::de::Error::custom("Invalid length"))?)
176			}
177		},
178		_ => quote! { #name(inner) },
179	};
180	quote! {
181		impl<'de, #(#type_params),* > serde::Deserialize < 'de > for #name #ty_generics #where_clause {
182			fn deserialize < D > (deserializer: D) -> Result < Self, D::Error >
183			where
184				D: serde::Deserializer < 'de >,
185			{
186				use alloc::string::ToString;
187				let str = <alloc::string::String>::deserialize(deserializer)?;
188				let inner = sp_core::bytes::from_hex(&str).map_err( | err | serde::de::Error::custom(err.to_string()))?;
189				Ok(#created)
190			}
191		}
192	}
193}
194
195fn gen_from_num(name: &syn::Ident, ty: &SupportedType, generics: &Generics) -> impl ToTokens {
196	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
197
198	match ty {
199		SupportedType::Array(ty) => quote! {
200			impl #impl_generics From<u64> for #name #ty_generics #where_clause {
201				fn from(n: u64) -> Self {
202					let mut ret = <#ty>::default();
203					let ret_len = ret.len();
204					let bytes = n.to_be_bytes();
205					ret[(ret_len-bytes.len())..].copy_from_slice(&bytes);
206					#name(ret)
207				}
208			}
209		},
210		_ => quote! {
211			impl #impl_generics From<u64> for #name #ty_generics #where_clause {
212				fn from(n: u64) -> Self {
213					#name(n.to_be_bytes().to_vec())
214				}
215			}
216		},
217	}
218}
219
220fn gen_from_bytes(name: &syn::Ident, ty: &SupportedType, generics: &Generics) -> impl ToTokens {
221	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
222
223	match ty {
224		SupportedType::Array(ty) => quote! {
225			impl<'a> TryFrom<&'a [u8]> for #name {
226				type Error = <#ty as TryFrom<&'a [u8]>>::Error;
227				fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
228					Ok(#name(bytes.try_into()?))
229				}
230			}
231		},
232		SupportedType::BoundedVec(ty) => quote! {
233			impl<'a> TryFrom<&'a [u8]> for #name {
234				type Error = <#ty as TryFrom<Vec<u8>>>::Error;
235				fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
236					Ok(#name(bytes.to_vec().try_into()?))
237				}
238			}
239		},
240		SupportedType::Vec => quote! {
241			impl #impl_generics From<&[u8]> for #name #ty_generics #where_clause {
242				fn from(bytes: &[u8]) -> Self {
243					#name(bytes.clone().to_vec())
244				}
245			}
246		},
247	}
248}
249
250fn gen_from_hex(name: &syn::Ident, ty: &SupportedType, generics: &Generics) -> impl ToTokens {
251	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
252
253	let decode_hex = match ty {
254		SupportedType::Array(ty) => quote! {
255			pub fn decode_hex(s: &str) -> Result<Self, &'static str> {
256				let value = <#ty>::try_from(sp_core::bytes::from_hex(s).map_err(|_| "Cannot decode bytes from hex string")?)
257					.map_err(|_| "Invalid length")?;
258				Ok(#name(value))
259			}
260		},
261		SupportedType::BoundedVec(ty) => quote! {
262			pub fn decode_hex(s: &str) -> Result<Self, &'static str> {
263				let bytes = sp_core::bytes::from_hex(s).map_err(|_| "Cannot decode bytes from hex string")?;
264				let value = <#ty>::try_from(bytes).map_err(|_| "Invalid length")?;
265				Ok(#name(value))
266			}
267
268		},
269		_ => quote! {
270			pub fn decode_hex(s: &str) -> Result<Self, &'static str> {
271				Ok(#name(sp_core::bytes::from_hex(s).map_err(|_| "Cannot decode bytes from hex string")?))
272			}
273		},
274	};
275
276	quote! {
277		#[allow(missing_docs)]
278		impl #impl_generics #name #ty_generics #where_clause {
279			#decode_hex
280
281			pub fn from_hex_unsafe(s: &str) -> Self {
282				Self::decode_hex(s).unwrap()
283			}
284
285		}
286
287		impl #impl_generics alloc::str::FromStr for #name #ty_generics #where_clause {
288			type Err = &'static str;
289			fn from_str(s: &str) -> Result<Self, Self::Err> {
290				Self::decode_hex(s)
291			}
292		}
293	}
294}
295
296fn gen_to_hex(name: &syn::Ident, generics: &Generics) -> impl ToTokens {
297	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
298
299	quote! {
300		#[allow(missing_docs)]
301		impl #impl_generics #name #ty_generics #where_clause {
302			pub fn to_hex_string(&self) -> alloc::string::String {
303				sp_core::bytes::to_hex(&self.0, false)
304			}
305		}
306	}
307}
308
309fn gen_as_ref(name: &syn::Ident, generics: &Generics) -> impl ToTokens {
310	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
311
312	quote! {
313		impl #impl_generics AsRef<[u8]> for #name #ty_generics #where_clause {
314			fn as_ref(&self) -> &[u8] {
315				&self.0
316			}
317		}
318	}
319}