diff --git a/core/src/cluster_nodes.rs b/core/src/cluster_nodes.rs index 0445d60cdb..2d1684fb75 100644 --- a/core/src/cluster_nodes.rs +++ b/core/src/cluster_nodes.rs @@ -410,23 +410,11 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo // Unstaked nodes will always appear at the very end. fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> { // Nodes are sorted by (stake, pubkey) in descending order. - let stakes: Vec = nodes - .iter() - .map(|node| node.stake) - .take_while(|stake| *stake > 0) - .collect(); - let num_staked = stakes.len(); - let mut out: Vec<_> = WeightedShuffle::new(rng, &stakes) + let stakes: Vec = nodes.iter().map(|node| node.stake).collect(); + WeightedShuffle::new(rng, &stakes) .unwrap() .map(|i| nodes[i]) - .collect(); - let weights = vec![1; nodes.len() - num_staked]; - out.extend( - WeightedShuffle::new(rng, &weights) - .unwrap() - .map(|i| nodes[i + num_staked]), - ); - out + .collect() } impl ClusterNodesCache { diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index ce6aee2c4b..3a87197c0c 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -26,12 +26,13 @@ pub enum WeightedShuffleError { /// - Returned indices are unique in the range [0, weights.len()). /// - Higher weighted indices tend to appear earlier proportional to their /// weight. -/// - Zero weighted indices are excluded. Therefore the iterator may have -/// count less than weights.len(). +/// - Zero weighted indices are shuffled and appear only at the end, after +/// non-zero weighted indices. pub struct WeightedShuffle<'a, R, T> { - arr: Vec, // Underlying array implementing binary indexed tree. - sum: T, // Current sum of weights, excluding already selected indices. - rng: &'a mut R, // Random number generator. + arr: Vec, // Underlying array implementing binary indexed tree. + sum: T, // Current sum of weights, excluding already selected indices. + zeros: Vec, // Indices of zero weighted entries. + rng: &'a mut R, // Random number generator. } // The implementation uses binary indexed tree: @@ -50,12 +51,17 @@ where let zero = ::default(); let mut arr = vec![zero; size]; let mut sum = zero; + let mut zeros = Vec::default(); for (mut k, &weight) in (1usize..).zip(weights) { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. if !(weight >= zero) { return Err(WeightedShuffleError::NegativeWeight(weight)); } + if weight == zero { + zeros.push(k - 1); + continue; + } sum = sum .checked_add(&weight) .ok_or(WeightedShuffleError::SumOverflow)?; @@ -64,7 +70,12 @@ where k += k & k.wrapping_neg(); } } - Ok(Self { arr, sum, rng }) + Ok(Self { + arr, + sum, + rng, + zeros, + }) } } @@ -123,15 +134,22 @@ where fn next(&mut self) -> Option { let zero = ::default(); - #[allow(clippy::neg_cmp_op_on_partial_ord)] - // self.sum <= zero does not work for NaNs. - if !(self.sum > zero) { + if self.sum > zero { + let sample = + ::Sampler::sample_single(zero, self.sum, &mut self.rng); + let (index, weight) = WeightedShuffle::search(self, sample); + self.remove(index, weight); + return Some(index - 1); + } + if self.zeros.is_empty() { return None; } - let sample = ::Sampler::sample_single(zero, self.sum, &mut self.rng); - let (index, weight) = WeightedShuffle::search(self, sample); - self.remove(index, weight); - Some(index - 1) + let index = ::Sampler::sample_single( + 0usize, + self.zeros.len(), + &mut self.rng, + ); + Some(self.zeros.swap_remove(index)) } } @@ -142,9 +160,13 @@ where { let zero = ::default(); let high = cumulative_weights.last().copied().unwrap_or_default(); - #[allow(clippy::neg_cmp_op_on_partial_ord)] - if !(high > zero) { - return None; + if high == zero { + if cumulative_weights.is_empty() { + return None; + } + let index = + ::Sampler::sample_single(0usize, cumulative_weights.len(), rng); + return Some(index); } let sample = ::Sampler::sample_single(zero, high, rng); let mut lo = 0usize; @@ -234,11 +256,14 @@ mod tests { R: Rng, { let mut shuffle = Vec::with_capacity(weights.len()); - loop { - let high: u64 = weights.iter().sum(); - if high == 0 { - break shuffle; - } + let mut high: u64 = weights.iter().sum(); + let mut zeros: Vec<_> = weights + .iter() + .enumerate() + .filter(|(_, w)| **w == 0) + .map(|(i, _)| i) + .collect(); + while high != 0 { let sample = rng.gen_range(0, high); let index = weights .iter() @@ -249,8 +274,14 @@ mod tests { .position(|acc| sample < acc) .unwrap(); shuffle.push(index); + high -= weights[index]; weights[index] = 0; } + while !zeros.is_empty() { + let index = ::Sampler::sample_single(0usize, zeros.len(), rng); + shuffle.push(zeros.swap_remove(index)); + } + shuffle } #[test] @@ -329,14 +360,15 @@ mod tests { assert_eq!(weighted_sample_single(&mut rng, &weights), None); } - // Asserts that zero weights will return empty shuffle. + // Asserts that zero weights will be shuffled. #[test] fn test_weighted_shuffle_zero_weights() { let weights = vec![0u64; 5]; - let mut rng = rand::thread_rng(); - let shuffle = WeightedShuffle::new(&mut rng, &weights); - assert!(shuffle.unwrap().next().is_none()); - assert_eq!(weighted_sample_single(&mut rng, &weights), None); + let seed = [37u8; 32]; + let mut rng = ChaChaRng::from_seed(seed); + let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + assert_eq!(shuffle, [1, 4, 2, 3, 0]); + assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1)); } // Asserts that each index is selected proportional to its weight. @@ -352,13 +384,13 @@ mod tests { counts[shuffle.next().unwrap()] += 1; let _ = shuffle.count(); // consume the rest. } - assert_eq!(counts, [101, 0, 90113, 0, 0, 891, 8895, 0]); + assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]); } #[test] fn test_weighted_shuffle_hard_coded() { let weights = [ - 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72, + 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72, ]; let cumulative_weights: Vec<_> = weights .iter() @@ -372,7 +404,7 @@ mod tests { let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); assert_eq!( shuffle, - [2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8] + [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!( @@ -384,12 +416,12 @@ mod tests { let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); assert_eq!( shuffle, - [17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12] + [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!( weighted_sample_single(&mut rng, &cumulative_weights), - Some(17), + Some(19), ); }