includes zero weighted entries in WeightedShuffle (#22829)

The current WeightedShuffle implementation excludes zero-weighted entries
from the shuffle:
https://github.com/solana-labs/solana/blob/13e631dcf/gossip/src/weighted_shuffle.rs#L29-L30

Though this might make more sense mathematically, for our use cases
(turbine specifically) it results in less efficient code:
https://github.com/solana-labs/solana/blob/13e631dcf/core/src/cluster_nodes.rs#L409-L430

This commit changes the implementation so that zero-weighted indices are
also included in the shuffle, but appear only at the end (uniformly
shuffled among themselves), after all non-zero-weighted indices.
This commit is contained in:
behzad nouri
2022-01-31 16:23:50 +00:00
committed by GitHub
parent 17b4563a6f
commit 604ca9316c
2 changed files with 66 additions and 46 deletions

View File

@ -410,23 +410,11 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
// Unstaked nodes will always appear at the very end. // Unstaked nodes will always appear at the very end.
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> { fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
// Nodes are sorted by (stake, pubkey) in descending order. // Nodes are sorted by (stake, pubkey) in descending order.
let stakes: Vec<u64> = nodes let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
.iter() WeightedShuffle::new(rng, &stakes)
.map(|node| node.stake)
.take_while(|stake| *stake > 0)
.collect();
let num_staked = stakes.len();
let mut out: Vec<_> = WeightedShuffle::new(rng, &stakes)
.unwrap() .unwrap()
.map(|i| nodes[i]) .map(|i| nodes[i])
.collect(); .collect()
let weights = vec![1; nodes.len() - num_staked];
out.extend(
WeightedShuffle::new(rng, &weights)
.unwrap()
.map(|i| nodes[i + num_staked]),
);
out
} }
impl<T> ClusterNodesCache<T> { impl<T> ClusterNodesCache<T> {

View File

@ -26,11 +26,12 @@ pub enum WeightedShuffleError<T> {
/// - Returned indices are unique in the range [0, weights.len()). /// - Returned indices are unique in the range [0, weights.len()).
/// - Higher weighted indices tend to appear earlier proportional to their /// - Higher weighted indices tend to appear earlier proportional to their
/// weight. /// weight.
/// - Zero weighted indices are excluded. Therefore the iterator may have /// - Zero weighted indices are shuffled and appear only at the end, after
/// count less than weights.len(). /// non-zero weighted indices.
pub struct WeightedShuffle<'a, R, T> { pub struct WeightedShuffle<'a, R, T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree. arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices. sum: T, // Current sum of weights, excluding already selected indices.
zeros: Vec<usize>, // Indices of zero weighted entries.
rng: &'a mut R, // Random number generator. rng: &'a mut R, // Random number generator.
} }
@ -50,12 +51,17 @@ where
let zero = <T as Default>::default(); let zero = <T as Default>::default();
let mut arr = vec![zero; size]; let mut arr = vec![zero; size];
let mut sum = zero; let mut sum = zero;
let mut zeros = Vec::default();
for (mut k, &weight) in (1usize..).zip(weights) { for (mut k, &weight) in (1usize..).zip(weights) {
#[allow(clippy::neg_cmp_op_on_partial_ord)] #[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs. // weight < zero does not work for NaNs.
if !(weight >= zero) { if !(weight >= zero) {
return Err(WeightedShuffleError::NegativeWeight(weight)); return Err(WeightedShuffleError::NegativeWeight(weight));
} }
if weight == zero {
zeros.push(k - 1);
continue;
}
sum = sum sum = sum
.checked_add(&weight) .checked_add(&weight)
.ok_or(WeightedShuffleError::SumOverflow)?; .ok_or(WeightedShuffleError::SumOverflow)?;
@ -64,7 +70,12 @@ where
k += k & k.wrapping_neg(); k += k & k.wrapping_neg();
} }
} }
Ok(Self { arr, sum, rng }) Ok(Self {
arr,
sum,
rng,
zeros,
})
} }
} }
@ -123,15 +134,22 @@ where
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let zero = <T as Default>::default(); let zero = <T as Default>::default();
#[allow(clippy::neg_cmp_op_on_partial_ord)] if self.sum > zero {
// self.sum <= zero does not work for NaNs. let sample =
if !(self.sum > zero) { <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
return None;
}
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
let (index, weight) = WeightedShuffle::search(self, sample); let (index, weight) = WeightedShuffle::search(self, sample);
self.remove(index, weight); self.remove(index, weight);
Some(index - 1) return Some(index - 1);
}
if self.zeros.is_empty() {
return None;
}
let index = <usize as SampleUniform>::Sampler::sample_single(
0usize,
self.zeros.len(),
&mut self.rng,
);
Some(self.zeros.swap_remove(index))
} }
} }
@ -142,10 +160,14 @@ where
{ {
let zero = <T as Default>::default(); let zero = <T as Default>::default();
let high = cumulative_weights.last().copied().unwrap_or_default(); let high = cumulative_weights.last().copied().unwrap_or_default();
#[allow(clippy::neg_cmp_op_on_partial_ord)] if high == zero {
if !(high > zero) { if cumulative_weights.is_empty() {
return None; return None;
} }
let index =
<usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng);
return Some(index);
}
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng); let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
let mut lo = 0usize; let mut lo = 0usize;
let mut hi = cumulative_weights.len() - 1; let mut hi = cumulative_weights.len() - 1;
@ -234,11 +256,14 @@ mod tests {
R: Rng, R: Rng,
{ {
let mut shuffle = Vec::with_capacity(weights.len()); let mut shuffle = Vec::with_capacity(weights.len());
loop { let mut high: u64 = weights.iter().sum();
let high: u64 = weights.iter().sum(); let mut zeros: Vec<_> = weights
if high == 0 { .iter()
break shuffle; .enumerate()
} .filter(|(_, w)| **w == 0)
.map(|(i, _)| i)
.collect();
while high != 0 {
let sample = rng.gen_range(0, high); let sample = rng.gen_range(0, high);
let index = weights let index = weights
.iter() .iter()
@ -249,8 +274,14 @@ mod tests {
.position(|acc| sample < acc) .position(|acc| sample < acc)
.unwrap(); .unwrap();
shuffle.push(index); shuffle.push(index);
high -= weights[index];
weights[index] = 0; weights[index] = 0;
} }
while !zeros.is_empty() {
let index = <usize as SampleUniform>::Sampler::sample_single(0usize, zeros.len(), rng);
shuffle.push(zeros.swap_remove(index));
}
shuffle
} }
#[test] #[test]
@ -329,14 +360,15 @@ mod tests {
assert_eq!(weighted_sample_single(&mut rng, &weights), None); assert_eq!(weighted_sample_single(&mut rng, &weights), None);
} }
// Asserts that zero weights will return empty shuffle. // Asserts that zero weights will be shuffled.
#[test] #[test]
fn test_weighted_shuffle_zero_weights() { fn test_weighted_shuffle_zero_weights() {
let weights = vec![0u64; 5]; let weights = vec![0u64; 5];
let mut rng = rand::thread_rng(); let seed = [37u8; 32];
let shuffle = WeightedShuffle::new(&mut rng, &weights); let mut rng = ChaChaRng::from_seed(seed);
assert!(shuffle.unwrap().next().is_none()); let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!(weighted_sample_single(&mut rng, &weights), None); assert_eq!(shuffle, [1, 4, 2, 3, 0]);
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
} }
// Asserts that each index is selected proportional to its weight. // Asserts that each index is selected proportional to its weight.
@ -352,13 +384,13 @@ mod tests {
counts[shuffle.next().unwrap()] += 1; counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest. let _ = shuffle.count(); // consume the rest.
} }
assert_eq!(counts, [101, 0, 90113, 0, 0, 891, 8895, 0]); assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
} }
#[test] #[test]
fn test_weighted_shuffle_hard_coded() { fn test_weighted_shuffle_hard_coded() {
let weights = [ let weights = [
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72, 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
]; ];
let cumulative_weights: Vec<_> = weights let cumulative_weights: Vec<_> = weights
.iter() .iter()
@ -372,7 +404,7 @@ mod tests {
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!( assert_eq!(
shuffle, shuffle,
[2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8] [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
); );
let mut rng = ChaChaRng::from_seed(seed); let mut rng = ChaChaRng::from_seed(seed);
assert_eq!( assert_eq!(
@ -384,12 +416,12 @@ mod tests {
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!( assert_eq!(
shuffle, shuffle,
[17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12] [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
); );
let mut rng = ChaChaRng::from_seed(seed); let mut rng = ChaChaRng::from_seed(seed);
assert_eq!( assert_eq!(
weighted_sample_single(&mut rng, &cumulative_weights), weighted_sample_single(&mut rng, &cumulative_weights),
Some(17), Some(19),
); );
} }