selection/
weighted_random.rs1extern crate alloc;
2
3use crate::*;
4use alloc::vec::Vec;
5use rand::Rng;
6use rand_chacha::{ChaCha20Rng, rand_core::SeedableRng};
7
8pub fn select_authorities<T: Clone>(
13 weighted_candidates: Vec<(T, Weight)>,
14 seed: <ChaCha20Rng as SeedableRng>::Seed,
15 size: u16,
16) -> Option<Vec<T>> {
17 let size = usize::from(size);
18 let total_weight: Weight = weighted_candidates.iter().map(|(_, weight)| weight).sum();
19
20 let mut committee: Vec<T> = alloc::vec![];
21
22 let mut rng = ChaCha20Rng::from_seed(seed);
23
24 while committee.len() < size && !weighted_candidates.is_empty() {
25 let selected_index = select_with_weight(&weighted_candidates, total_weight, &mut rng);
26 let selected = weighted_candidates[selected_index].0.clone();
27 committee.push(selected);
28 }
29
30 if size <= committee.len() { Some(committee) } else { None }
31}
32
33fn select_with_weight<T>(
34 candidates: &[(T, Weight)],
35 total_weight: Weight,
36 rand: &mut ChaCha20Rng,
37) -> usize {
38 let random_number: u128 = rand.random_range(0..total_weight);
39
40 let mut cumulative_weight: Weight = 0;
41 for (index, (_, weight)) in candidates.iter().enumerate() {
42 cumulative_weight += weight;
43 if cumulative_weight > random_number {
44 return index;
45 }
46 }
47
48 panic!("Did not select any candidate");
49}
50
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::*;
    use quickcheck_macros::*;

    type CandidatesWithWeights = Vec<(String, Weight)>;

    /// Test input: weighted candidates paired with the 32-byte seed handed to
    /// `select_authorities`.
    #[derive(Clone)]
    struct TestWeightedCandidates(CandidatesWithWeights, [u8; 32]);

    /// Runs `select_authorities` on `input`, requesting a committee of `COMMITTEE_SIZE` seats.
    fn select<const COMMITTEE_SIZE: u16>(
        input: TestWeightedCandidates,
    ) -> Option<Vec<String>> {
        let TestWeightedCandidates(weighted, seed) = input;
        select_authorities(weighted, seed, COMMITTEE_SIZE)
    }

    /// Builds `n` candidates named `candidate_0`, `candidate_1`, ... and returns both the bare
    /// names and the same names each paired with weight 1.
    fn uniform_weight_candidates(n: u16) -> (Vec<String>, CandidatesWithWeights) {
        let names: Vec<String> = (0..n)
            .map(|idx| "candidate_".to_string() + &idx.to_string())
            .collect();
        let weighted: CandidatesWithWeights = names.iter().map(|name| (name.clone(), 1)).collect();
        (names, weighted)
    }

    const MAX_CANDIDATE_NUMBER: u16 = 1000;

    #[quickcheck]
    fn random_selection_with_repetition(candidate_number: u16, nonce: TestNonce) {
        const COMMITTEE_SIZE: u16 = 2;
        // Clamp the arbitrary count into [COMMITTEE_SIZE, MAX_CANDIDATE_NUMBER).
        let candidate_number =
            COMMITTEE_SIZE + candidate_number % (MAX_CANDIDATE_NUMBER - COMMITTEE_SIZE);

        let (candidates, candidates_with_weights) = uniform_weight_candidates(candidate_number);
        let input = TestWeightedCandidates(candidates_with_weights, nonce.0);

        let committee = match select::<COMMITTEE_SIZE>(input) {
            Some(committee) => committee,
            None => panic!("select returned a None"),
        };

        assert_eq!(committee.len(), usize::from(COMMITTEE_SIZE));
        assert_subset!(String, committee, candidates);
    }

    #[quickcheck]
    fn random_selection_zero_weight(nonce: TestNonce) {
        let zero = "zero_weight".to_string();
        let non_zero_1 = "non_zero_weight_1".to_string();
        let non_zero_2 = "non_zero_weight_2".to_string();
        let input = TestWeightedCandidates(
            vec![(zero, 0), (non_zero_1.clone(), 1), (non_zero_2.clone(), 2)],
            nonce.0,
        );

        let committee = select::<1>(input).unwrap();

        // Only the positive-weight candidates may ever win the single seat.
        assert!([vec![non_zero_1], vec![non_zero_2]].contains(&committee));
    }

    #[quickcheck]
    fn random_selection_cannot_select_from_empty_candidates(nonce: TestNonce) {
        let committee = select::<1>(TestWeightedCandidates(vec![], nonce.0));
        assert_eq!(committee, None)
    }

    #[test]
    fn etcm_5304_random_selection_should_not_be_skewed() {
        // Two equally weighted candidates: over 1000 distinct seeds, "a" should win the single
        // seat roughly half the time (accepted band: 471..=529 wins).
        let candidates = vec![("a".to_string(), u128::MAX / 3), ("b".to_string(), u128::MAX / 3)];
        let a = "a".to_string();

        let a_count = (0..1000u16)
            .filter(|round| {
                let mut nonce = [0u8; 32];
                nonce[..2].copy_from_slice(&round.to_be_bytes());
                let selected =
                    select::<1>(TestWeightedCandidates(candidates.clone(), nonce)).unwrap();
                selected.contains(&a)
            })
            .count();

        assert!(a_count > 470 && a_count < 530)
    }
}