includes zero weighted entries in WeightedShuffle (#22829)
Current WeightedShuffle implementation excludes zero weighted entries from the shuffle: https://github.com/solana-labs/solana/blob/13e631dcf/gossip/src/weighted_shuffle.rs#L29-L30 Though mathematically this might make more sense, for our use-cases (turbine specifically), this results in less efficient code: https://github.com/solana-labs/solana/blob/13e631dcf/core/src/cluster_nodes.rs#L409-L430 This commit changes the implementation so that zero weighted indices are also included in the shuffle but appear only at the end after non-zero weighted indices.
This commit is contained in:
		| @@ -410,23 +410,11 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo | |||||||
| // Unstaked nodes will always appear at the very end. | // Unstaked nodes will always appear at the very end. | ||||||
| fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> { | fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> { | ||||||
|     // Nodes are sorted by (stake, pubkey) in descending order. |     // Nodes are sorted by (stake, pubkey) in descending order. | ||||||
|     let stakes: Vec<u64> = nodes |     let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect(); | ||||||
|         .iter() |     WeightedShuffle::new(rng, &stakes) | ||||||
|         .map(|node| node.stake) |  | ||||||
|         .take_while(|stake| *stake > 0) |  | ||||||
|         .collect(); |  | ||||||
|     let num_staked = stakes.len(); |  | ||||||
|     let mut out: Vec<_> = WeightedShuffle::new(rng, &stakes) |  | ||||||
|         .unwrap() |         .unwrap() | ||||||
|         .map(|i| nodes[i]) |         .map(|i| nodes[i]) | ||||||
|         .collect(); |         .collect() | ||||||
|     let weights = vec![1; nodes.len() - num_staked]; |  | ||||||
|     out.extend( |  | ||||||
|         WeightedShuffle::new(rng, &weights) |  | ||||||
|             .unwrap() |  | ||||||
|             .map(|i| nodes[i + num_staked]), |  | ||||||
|     ); |  | ||||||
|     out |  | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<T> ClusterNodesCache<T> { | impl<T> ClusterNodesCache<T> { | ||||||
|   | |||||||
| @@ -26,11 +26,12 @@ pub enum WeightedShuffleError<T> { | |||||||
| ///   - Returned indices are unique in the range [0, weights.len()). | ///   - Returned indices are unique in the range [0, weights.len()). | ||||||
| ///   - Higher weighted indices tend to appear earlier proportional to their | ///   - Higher weighted indices tend to appear earlier proportional to their | ||||||
| ///     weight. | ///     weight. | ||||||
| ///   - Zero weighted indices are excluded. Therefore the iterator may have | ///   - Zero weighted indices are shuffled and appear only at the end, after | ||||||
| ///     count less than weights.len(). | ///     non-zero weighted indices. | ||||||
| pub struct WeightedShuffle<'a, R, T> { | pub struct WeightedShuffle<'a, R, T> { | ||||||
|     arr: Vec<T>,       // Underlying array implementing binary indexed tree. |     arr: Vec<T>,       // Underlying array implementing binary indexed tree. | ||||||
|     sum: T,            // Current sum of weights, excluding already selected indices. |     sum: T,            // Current sum of weights, excluding already selected indices. | ||||||
|  |     zeros: Vec<usize>, // Indices of zero weighted entries. | ||||||
|     rng: &'a mut R,    // Random number generator. |     rng: &'a mut R,    // Random number generator. | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -50,12 +51,17 @@ where | |||||||
|         let zero = <T as Default>::default(); |         let zero = <T as Default>::default(); | ||||||
|         let mut arr = vec![zero; size]; |         let mut arr = vec![zero; size]; | ||||||
|         let mut sum = zero; |         let mut sum = zero; | ||||||
|  |         let mut zeros = Vec::default(); | ||||||
|         for (mut k, &weight) in (1usize..).zip(weights) { |         for (mut k, &weight) in (1usize..).zip(weights) { | ||||||
|             #[allow(clippy::neg_cmp_op_on_partial_ord)] |             #[allow(clippy::neg_cmp_op_on_partial_ord)] | ||||||
|             // weight < zero does not work for NaNs. |             // weight < zero does not work for NaNs. | ||||||
|             if !(weight >= zero) { |             if !(weight >= zero) { | ||||||
|                 return Err(WeightedShuffleError::NegativeWeight(weight)); |                 return Err(WeightedShuffleError::NegativeWeight(weight)); | ||||||
|             } |             } | ||||||
|  |             if weight == zero { | ||||||
|  |                 zeros.push(k - 1); | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|             sum = sum |             sum = sum | ||||||
|                 .checked_add(&weight) |                 .checked_add(&weight) | ||||||
|                 .ok_or(WeightedShuffleError::SumOverflow)?; |                 .ok_or(WeightedShuffleError::SumOverflow)?; | ||||||
| @@ -64,7 +70,12 @@ where | |||||||
|                 k += k & k.wrapping_neg(); |                 k += k & k.wrapping_neg(); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         Ok(Self { arr, sum, rng }) |         Ok(Self { | ||||||
|  |             arr, | ||||||
|  |             sum, | ||||||
|  |             rng, | ||||||
|  |             zeros, | ||||||
|  |         }) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -123,15 +134,22 @@ where | |||||||
|  |  | ||||||
|     fn next(&mut self) -> Option<Self::Item> { |     fn next(&mut self) -> Option<Self::Item> { | ||||||
|         let zero = <T as Default>::default(); |         let zero = <T as Default>::default(); | ||||||
|         #[allow(clippy::neg_cmp_op_on_partial_ord)] |         if self.sum > zero { | ||||||
|         // self.sum <= zero does not work for NaNs. |             let sample = | ||||||
|         if !(self.sum > zero) { |                 <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng); | ||||||
|             return None; |  | ||||||
|         } |  | ||||||
|         let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng); |  | ||||||
|             let (index, weight) = WeightedShuffle::search(self, sample); |             let (index, weight) = WeightedShuffle::search(self, sample); | ||||||
|             self.remove(index, weight); |             self.remove(index, weight); | ||||||
|         Some(index - 1) |             return Some(index - 1); | ||||||
|  |         } | ||||||
|  |         if self.zeros.is_empty() { | ||||||
|  |             return None; | ||||||
|  |         } | ||||||
|  |         let index = <usize as SampleUniform>::Sampler::sample_single( | ||||||
|  |             0usize, | ||||||
|  |             self.zeros.len(), | ||||||
|  |             &mut self.rng, | ||||||
|  |         ); | ||||||
|  |         Some(self.zeros.swap_remove(index)) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -142,10 +160,14 @@ where | |||||||
| { | { | ||||||
|     let zero = <T as Default>::default(); |     let zero = <T as Default>::default(); | ||||||
|     let high = cumulative_weights.last().copied().unwrap_or_default(); |     let high = cumulative_weights.last().copied().unwrap_or_default(); | ||||||
|     #[allow(clippy::neg_cmp_op_on_partial_ord)] |     if high == zero { | ||||||
|     if !(high > zero) { |         if cumulative_weights.is_empty() { | ||||||
|             return None; |             return None; | ||||||
|         } |         } | ||||||
|  |         let index = | ||||||
|  |             <usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng); | ||||||
|  |         return Some(index); | ||||||
|  |     } | ||||||
|     let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng); |     let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng); | ||||||
|     let mut lo = 0usize; |     let mut lo = 0usize; | ||||||
|     let mut hi = cumulative_weights.len() - 1; |     let mut hi = cumulative_weights.len() - 1; | ||||||
| @@ -234,11 +256,14 @@ mod tests { | |||||||
|         R: Rng, |         R: Rng, | ||||||
|     { |     { | ||||||
|         let mut shuffle = Vec::with_capacity(weights.len()); |         let mut shuffle = Vec::with_capacity(weights.len()); | ||||||
|         loop { |         let mut high: u64 = weights.iter().sum(); | ||||||
|             let high: u64 = weights.iter().sum(); |         let mut zeros: Vec<_> = weights | ||||||
|             if high == 0 { |             .iter() | ||||||
|                 break shuffle; |             .enumerate() | ||||||
|             } |             .filter(|(_, w)| **w == 0) | ||||||
|  |             .map(|(i, _)| i) | ||||||
|  |             .collect(); | ||||||
|  |         while high != 0 { | ||||||
|             let sample = rng.gen_range(0, high); |             let sample = rng.gen_range(0, high); | ||||||
|             let index = weights |             let index = weights | ||||||
|                 .iter() |                 .iter() | ||||||
| @@ -249,8 +274,14 @@ mod tests { | |||||||
|                 .position(|acc| sample < acc) |                 .position(|acc| sample < acc) | ||||||
|                 .unwrap(); |                 .unwrap(); | ||||||
|             shuffle.push(index); |             shuffle.push(index); | ||||||
|  |             high -= weights[index]; | ||||||
|             weights[index] = 0; |             weights[index] = 0; | ||||||
|         } |         } | ||||||
|  |         while !zeros.is_empty() { | ||||||
|  |             let index = <usize as SampleUniform>::Sampler::sample_single(0usize, zeros.len(), rng); | ||||||
|  |             shuffle.push(zeros.swap_remove(index)); | ||||||
|  |         } | ||||||
|  |         shuffle | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
| @@ -329,14 +360,15 @@ mod tests { | |||||||
|         assert_eq!(weighted_sample_single(&mut rng, &weights), None); |         assert_eq!(weighted_sample_single(&mut rng, &weights), None); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // Asserts that zero weights will return empty shuffle. |     // Asserts that zero weights will be shuffled. | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_weighted_shuffle_zero_weights() { |     fn test_weighted_shuffle_zero_weights() { | ||||||
|         let weights = vec![0u64; 5]; |         let weights = vec![0u64; 5]; | ||||||
|         let mut rng = rand::thread_rng(); |         let seed = [37u8; 32]; | ||||||
|         let shuffle = WeightedShuffle::new(&mut rng, &weights); |         let mut rng = ChaChaRng::from_seed(seed); | ||||||
|         assert!(shuffle.unwrap().next().is_none()); |         let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); | ||||||
|         assert_eq!(weighted_sample_single(&mut rng, &weights), None); |         assert_eq!(shuffle, [1, 4, 2, 3, 0]); | ||||||
|  |         assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // Asserts that each index is selected proportional to its weight. |     // Asserts that each index is selected proportional to its weight. | ||||||
| @@ -352,13 +384,13 @@ mod tests { | |||||||
|             counts[shuffle.next().unwrap()] += 1; |             counts[shuffle.next().unwrap()] += 1; | ||||||
|             let _ = shuffle.count(); // consume the rest. |             let _ = shuffle.count(); // consume the rest. | ||||||
|         } |         } | ||||||
|         assert_eq!(counts, [101, 0, 90113, 0, 0, 891, 8895, 0]); |         assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_weighted_shuffle_hard_coded() { |     fn test_weighted_shuffle_hard_coded() { | ||||||
|         let weights = [ |         let weights = [ | ||||||
|             78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72, |             78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72, | ||||||
|         ]; |         ]; | ||||||
|         let cumulative_weights: Vec<_> = weights |         let cumulative_weights: Vec<_> = weights | ||||||
|             .iter() |             .iter() | ||||||
| @@ -372,7 +404,7 @@ mod tests { | |||||||
|         let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); |         let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); | ||||||
|         assert_eq!( |         assert_eq!( | ||||||
|             shuffle, |             shuffle, | ||||||
|             [2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8] |             [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] | ||||||
|         ); |         ); | ||||||
|         let mut rng = ChaChaRng::from_seed(seed); |         let mut rng = ChaChaRng::from_seed(seed); | ||||||
|         assert_eq!( |         assert_eq!( | ||||||
| @@ -384,12 +416,12 @@ mod tests { | |||||||
|         let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); |         let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); | ||||||
|         assert_eq!( |         assert_eq!( | ||||||
|             shuffle, |             shuffle, | ||||||
|             [17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12] |             [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] | ||||||
|         ); |         ); | ||||||
|         let mut rng = ChaChaRng::from_seed(seed); |         let mut rng = ChaChaRng::from_seed(seed); | ||||||
|         assert_eq!( |         assert_eq!( | ||||||
|             weighted_sample_single(&mut rng, &cumulative_weights), |             weighted_sample_single(&mut rng, &cumulative_weights), | ||||||
|             Some(17), |             Some(19), | ||||||
|         ); |         ); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user