selection/
weighted_random.rs

1extern crate alloc;
2
3use crate::*;
4use alloc::vec::Vec;
5use rand::Rng;
6use rand_chacha::{ChaCha20Rng, rand_core::SeedableRng};
7
8/// Simple random weighted selection
9///
10/// When selecting out of `n` candidates with weights `w_1`, `w_2`, ..., `w_n`, independently assigns each
11/// committee seat to the k-th candidate with probability `w_k / (w_1 + w_2 + ... + w_n)`.
12pub fn select_authorities<T: Clone>(
13	weighted_candidates: Vec<(T, Weight)>,
14	seed: <ChaCha20Rng as SeedableRng>::Seed,
15	size: u16,
16) -> Option<Vec<T>> {
17	let size = usize::from(size);
18	let total_weight: Weight = weighted_candidates.iter().map(|(_, weight)| weight).sum();
19
20	let mut committee: Vec<T> = alloc::vec![];
21
22	let mut rng = ChaCha20Rng::from_seed(seed);
23
24	while committee.len() < size && !weighted_candidates.is_empty() {
25		let selected_index = select_with_weight(&weighted_candidates, total_weight, &mut rng);
26		let selected = weighted_candidates[selected_index].0.clone();
27		committee.push(selected);
28	}
29
30	if size <= committee.len() { Some(committee) } else { None }
31}
32
33fn select_with_weight<T>(
34	candidates: &[(T, Weight)],
35	total_weight: Weight,
36	rand: &mut ChaCha20Rng,
37) -> usize {
38	let random_number: u128 = rand.random_range(0..total_weight);
39
40	let mut cumulative_weight: Weight = 0;
41	for (index, (_, weight)) in candidates.iter().enumerate() {
42		cumulative_weight += weight;
43		if cumulative_weight > random_number {
44			return index;
45		}
46	}
47
48	panic!("Did not select any candidate");
49}
50
#[cfg(test)]
mod tests {
	use super::*;
	use crate::tests::*;
	use quickcheck_macros::*;

	type CandidatesWithWeights = Vec<(String, Weight)>;

	/// Test input: weighted candidates paired with the 32-byte RNG seed.
	#[derive(Clone)]
	struct TestWeightedCandidates(CandidatesWithWeights, [u8; 32]);

	/// Runs [`select_authorities`] for a committee of `COMMITTEE_SIZE` seats.
	fn select<const COMMITTEE_SIZE: u16>(
		candidates: TestWeightedCandidates,
	) -> Option<Vec<String>> {
		let TestWeightedCandidates(weighted, seed) = candidates;
		select_authorities(weighted, seed, COMMITTEE_SIZE)
	}

	/// Builds `n` candidates named `candidate_0`, `candidate_1`, ..., each with weight 1.
	/// Returns the bare names alongside the weighted list.
	fn uniform_weight_candidates(n: u16) -> (Vec<String>, CandidatesWithWeights) {
		let names: Vec<String> = (0..n).map(|c| format!("candidate_{c}")).collect();
		let weighted = names.iter().map(|name| (name.clone(), 1)).collect();
		(names, weighted)
	}

	const MAX_CANDIDATE_NUMBER: u16 = 1000;

	#[quickcheck]
	fn random_selection_with_repetition(candidate_number: u16, nonce: TestNonce) {
		const COMMITTEE_SIZE: u16 = 2;
		// Fold the arbitrary input into [COMMITTEE_SIZE, MAX_CANDIDATE_NUMBER).
		let candidate_number =
			COMMITTEE_SIZE + candidate_number % (MAX_CANDIDATE_NUMBER - COMMITTEE_SIZE);

		let (candidates, candidates_with_weights) = uniform_weight_candidates(candidate_number);

		let input = TestWeightedCandidates(candidates_with_weights, nonce.0);
		let committee = select::<COMMITTEE_SIZE>(input).expect("select returned a None");

		assert_eq!(committee.len(), COMMITTEE_SIZE as usize);
		assert_subset!(String, committee, candidates);
	}

	#[quickcheck]
	fn random_selection_zero_weight(nonce: TestNonce) {
		let non_zero_1 = "non_zero_weight_1".to_string();
		let non_zero_2 = "non_zero_weight_2".to_string();
		let weighted = vec![
			("zero_weight".to_string(), 0),
			(non_zero_1.clone(), 1),
			(non_zero_2.clone(), 2),
		];

		let committee = select::<1>(TestWeightedCandidates(weighted, nonce.0)).unwrap();

		// The zero-weight candidate must never win the single seat.
		assert!(committee == vec![non_zero_1] || committee == vec![non_zero_2]);
	}

	#[quickcheck]
	fn random_selection_cannot_select_from_empty_candidates(nonce: TestNonce) {
		let committee = select::<1>(TestWeightedCandidates(Vec::new(), nonce.0));

		assert_eq!(committee, None)
	}

	/// Two equal, very large weights must each win roughly half of 1000 seeded draws.
	#[test]
	fn etcm_5304_random_selection_should_not_be_skewed() {
		let candidates = vec![("a".to_string(), u128::MAX / 3), ("b".to_string(), u128::MAX / 3)];
		let a = "a".to_string();

		let a_count = (0..1000u16)
			.filter(|i| {
				// Seed varies only in its first two bytes (big-endian `i`).
				let mut nonce = [0u8; 32];
				nonce[..2].copy_from_slice(&i.to_be_bytes());
				let input = TestWeightedCandidates(candidates.clone(), nonce);
				select::<1>(input).unwrap().contains(&a)
			})
			.count();

		assert!(a_count > 470 && a_count < 530)
	}
}