diff --git a/core/src/cluster_nodes.rs b/core/src/cluster_nodes.rs index 2d1684fb75..a02bbdbfe1 100644 --- a/core/src/cluster_nodes.rs +++ b/core/src/cluster_nodes.rs @@ -411,8 +411,9 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> { // Nodes are sorted by (stake, pubkey) in descending order. let stakes: Vec = nodes.iter().map(|node| node.stake).collect(); - WeightedShuffle::new(rng, &stakes) + WeightedShuffle::new(&stakes) .unwrap() + .shuffle(rng) .map(|i| nodes[i]) .collect() } diff --git a/gossip/benches/weighted_shuffle.rs b/gossip/benches/weighted_shuffle.rs index d2b15f97d0..72f3b6dbcc 100644 --- a/gossip/benches/weighted_shuffle.rs +++ b/gossip/benches/weighted_shuffle.rs @@ -32,8 +32,9 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) { let weights = make_weights(&mut rng); bencher.iter(|| { rng.fill(&mut seed[..]); - WeightedShuffle::new(&mut ChaChaRng::from_seed(seed), &weights) - .unwrap() + let shuffle = WeightedShuffle::new(&weights).unwrap(); + shuffle + .shuffle(&mut ChaChaRng::from_seed(seed)) .collect::>() }); } diff --git a/gossip/src/cluster_info.rs b/gossip/src/cluster_info.rs index fc83ece801..069d970054 100644 --- a/gossip/src/cluster_info.rs +++ b/gossip/src/cluster_info.rs @@ -2010,7 +2010,7 @@ impl ClusterInfo { return packet_batch; } let mut rng = rand::thread_rng(); - let shuffle = WeightedShuffle::new(&mut rng, &scores).unwrap(); + let shuffle = WeightedShuffle::new(&scores).unwrap().shuffle(&mut rng); let mut total_bytes = 0; let mut sent = 0; for (addr, response) in shuffle.map(|i| &responses[i]) { diff --git a/gossip/src/crds_gossip_pull.rs b/gossip/src/crds_gossip_pull.rs index 4e0d8bcd61..30ebe48328 100644 --- a/gossip/src/crds_gossip_pull.rs +++ b/gossip/src/crds_gossip_pull.rs @@ -246,8 +246,9 @@ impl CrdsGossipPull { return Err(CrdsGossipError::NoPeers); } let mut rng = rand::thread_rng(); - let mut peers = WeightedShuffle::new(&mut rng, &weights) + let mut peers = WeightedShuffle::new(&weights) .unwrap() + .shuffle(&mut rng) .map(|i| peers[i]); let peer = { let mut rng = rand::thread_rng(); diff --git a/gossip/src/crds_gossip_push.rs b/gossip/src/crds_gossip_push.rs index 98a618b178..cde6ff442f 100644 --- a/gossip/src/crds_gossip_push.rs +++ b/gossip/src/crds_gossip_push.rs @@ -169,8 +169,9 @@ impl CrdsGossipPush { .filter(|(_, stake)| *stake > 0) .collect(); let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect(); - WeightedShuffle::new(&mut rng, &weights) + WeightedShuffle::new(&weights) .unwrap() + .shuffle(&mut rng) .map(move |i| peers[i]) }; let mut keep = HashSet::new(); @@ -369,7 +370,7 @@ impl CrdsGossipPush { return; } let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size); - let shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); + let shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng); let mut active_set = self.active_set.write().unwrap(); let need = Self::compute_need(self.num_active, active_set.len(), ratio); for peer in shuffle.map(|i| peers[i]) { diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 3a87197c0c..f1d73e4168 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -28,25 +28,24 @@ pub enum WeightedShuffleError { /// weight. /// - Zero weighted indices are shuffled and appear only at the end, after /// non-zero weighted indices. -pub struct WeightedShuffle<'a, R, T> { +pub struct WeightedShuffle { arr: Vec, // Underlying array implementing binary indexed tree. sum: T, // Current sum of weights, excluding already selected indices. zeros: Vec, // Indices of zero weighted entries. - rng: &'a mut R, // Random number generator. } // The implementation uses binary indexed tree: // https://en.wikipedia.org/wiki/Fenwick_tree // to maintain cumulative sum of weights excluding already selected indices // over self.arr. -impl<'a, R: Rng, T> WeightedShuffle<'a, R, T> +impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, { /// Returns error if: /// - any of the weights are negative. /// - sum of weights overflows. - pub fn new(rng: &'a mut R, weights: &[T]) -> Result> { + pub fn new(weights: &[T]) -> Result> { let size = weights.len() + 1; let zero = ::default(); let mut arr = vec![zero; size]; @@ -70,16 +69,11 @@ where k += k & k.wrapping_neg(); } } - Ok(Self { - arr, - sum, - rng, - zeros, - }) + Ok(Self { arr, sum, zeros }) } } -impl<'a, R, T> WeightedShuffle<'a, R, T> +impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, { @@ -126,30 +120,26 @@ where } } -impl<'a, R: Rng, T> Iterator for WeightedShuffle<'a, R, T> +impl<'a, T: 'a> WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, { - type Item = usize; - - fn next(&mut self) -> Option { - let zero = ::default(); - if self.sum > zero { - let sample = - ::Sampler::sample_single(zero, self.sum, &mut self.rng); - let (index, weight) = WeightedShuffle::search(self, sample); - self.remove(index, weight); - return Some(index - 1); - } - if self.zeros.is_empty() { - return None; - } - let index = ::Sampler::sample_single( - 0usize, - self.zeros.len(), - &mut self.rng, - ); - Some(self.zeros.swap_remove(index)) + pub fn shuffle(mut self, rng: &'a mut R) -> impl Iterator + 'a { + std::iter::from_fn(move || { + let zero = ::default(); + if self.sum > zero { + let sample = ::Sampler::sample_single(zero, self.sum, rng); + let (index, weight) = WeightedShuffle::search(&self, sample); + self.remove(index, weight); + return Some(index - 1); + } + if self.zeros.is_empty() { + return None; + } + let index = + ::Sampler::sample_single(0usize, self.zeros.len(), rng); + Some(self.zeros.swap_remove(index)) + }) } } @@ -355,8 +345,8 @@ mod tests { fn test_weighted_shuffle_empty_weights() { let weights = Vec::::new(); let mut rng = rand::thread_rng(); - let shuffle = WeightedShuffle::new(&mut rng, &weights); - assert!(shuffle.unwrap().next().is_none()); + let shuffle = WeightedShuffle::new(&weights).unwrap(); + assert!(shuffle.shuffle(&mut rng).next().is_none()); assert_eq!(weighted_sample_single(&mut rng, &weights), None); } @@ -366,7 +356,8 @@ mod tests { let weights = vec![0u64; 5]; let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); assert_eq!(shuffle, [1, 4, 2, 3, 0]); assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1)); } @@ -380,7 +371,7 @@ mod tests { let weights = [1, 0, 1000, 0, 0, 10, 100, 0]; let mut counts = [0; 8]; for _ in 0..100000 { - let mut shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); + let mut shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng); counts[shuffle.next().unwrap()] += 1; let _ = shuffle.count(); // consume the rest. } @@ -401,7 +392,8 @@ mod tests { .collect(); let seed = [48u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); assert_eq!( shuffle, [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] @@ -413,7 +405,8 @@ mod tests { ); let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); assert_eq!( shuffle, [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] @@ -440,7 +433,8 @@ mod tests { let mut seed = [0u8; 32]; rng.fill(&mut seed[..]); let mut rng = ChaChaRng::from_seed(seed); - let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); let mut rng = ChaChaRng::from_seed(seed); let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); assert_eq!(shuffle, shuffle_slow);