Include zero-weighted entries in WeightedShuffle (#22829)
The current WeightedShuffle implementation excludes zero-weighted entries from the shuffle: https://github.com/solana-labs/solana/blob/13e631dcf/gossip/src/weighted_shuffle.rs#L29-L30 Though mathematically this might make more sense, for our use cases (turbine specifically) it results in less efficient code: https://github.com/solana-labs/solana/blob/13e631dcf/core/src/cluster_nodes.rs#L409-L430 This commit changes the implementation so that zero-weighted indices are also included in the shuffle, but appear only at the end, after all non-zero-weighted indices.
This commit is contained in:
@ -410,23 +410,11 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
|
|||||||
// Unstaked nodes will always appear at the very end.
|
// Unstaked nodes will always appear at the very end.
|
||||||
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
|
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
|
||||||
// Nodes are sorted by (stake, pubkey) in descending order.
|
// Nodes are sorted by (stake, pubkey) in descending order.
|
||||||
let stakes: Vec<u64> = nodes
|
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
|
||||||
.iter()
|
WeightedShuffle::new(rng, &stakes)
|
||||||
.map(|node| node.stake)
|
|
||||||
.take_while(|stake| *stake > 0)
|
|
||||||
.collect();
|
|
||||||
let num_staked = stakes.len();
|
|
||||||
let mut out: Vec<_> = WeightedShuffle::new(rng, &stakes)
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.map(|i| nodes[i])
|
.map(|i| nodes[i])
|
||||||
.collect();
|
.collect()
|
||||||
let weights = vec![1; nodes.len() - num_staked];
|
|
||||||
out.extend(
|
|
||||||
WeightedShuffle::new(rng, &weights)
|
|
||||||
.unwrap()
|
|
||||||
.map(|i| nodes[i + num_staked]),
|
|
||||||
);
|
|
||||||
out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> ClusterNodesCache<T> {
|
impl<T> ClusterNodesCache<T> {
|
||||||
|
@ -26,12 +26,13 @@ pub enum WeightedShuffleError<T> {
|
|||||||
/// - Returned indices are unique in the range [0, weights.len()).
|
/// - Returned indices are unique in the range [0, weights.len()).
|
||||||
/// - Higher weighted indices tend to appear earlier proportional to their
|
/// - Higher weighted indices tend to appear earlier proportional to their
|
||||||
/// weight.
|
/// weight.
|
||||||
/// - Zero weighted indices are excluded. Therefore the iterator may have
|
/// - Zero weighted indices are shuffled and appear only at the end, after
|
||||||
/// count less than weights.len().
|
/// non-zero weighted indices.
|
||||||
pub struct WeightedShuffle<'a, R, T> {
|
pub struct WeightedShuffle<'a, R, T> {
|
||||||
arr: Vec<T>, // Underlying array implementing binary indexed tree.
|
arr: Vec<T>, // Underlying array implementing binary indexed tree.
|
||||||
sum: T, // Current sum of weights, excluding already selected indices.
|
sum: T, // Current sum of weights, excluding already selected indices.
|
||||||
rng: &'a mut R, // Random number generator.
|
zeros: Vec<usize>, // Indices of zero weighted entries.
|
||||||
|
rng: &'a mut R, // Random number generator.
|
||||||
}
|
}
|
||||||
|
|
||||||
// The implementation uses binary indexed tree:
|
// The implementation uses binary indexed tree:
|
||||||
@ -50,12 +51,17 @@ where
|
|||||||
let zero = <T as Default>::default();
|
let zero = <T as Default>::default();
|
||||||
let mut arr = vec![zero; size];
|
let mut arr = vec![zero; size];
|
||||||
let mut sum = zero;
|
let mut sum = zero;
|
||||||
|
let mut zeros = Vec::default();
|
||||||
for (mut k, &weight) in (1usize..).zip(weights) {
|
for (mut k, &weight) in (1usize..).zip(weights) {
|
||||||
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
||||||
// weight < zero does not work for NaNs.
|
// weight < zero does not work for NaNs.
|
||||||
if !(weight >= zero) {
|
if !(weight >= zero) {
|
||||||
return Err(WeightedShuffleError::NegativeWeight(weight));
|
return Err(WeightedShuffleError::NegativeWeight(weight));
|
||||||
}
|
}
|
||||||
|
if weight == zero {
|
||||||
|
zeros.push(k - 1);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
sum = sum
|
sum = sum
|
||||||
.checked_add(&weight)
|
.checked_add(&weight)
|
||||||
.ok_or(WeightedShuffleError::SumOverflow)?;
|
.ok_or(WeightedShuffleError::SumOverflow)?;
|
||||||
@ -64,7 +70,12 @@ where
|
|||||||
k += k & k.wrapping_neg();
|
k += k & k.wrapping_neg();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(Self { arr, sum, rng })
|
Ok(Self {
|
||||||
|
arr,
|
||||||
|
sum,
|
||||||
|
rng,
|
||||||
|
zeros,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,15 +134,22 @@ where
|
|||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let zero = <T as Default>::default();
|
let zero = <T as Default>::default();
|
||||||
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
if self.sum > zero {
|
||||||
// self.sum <= zero does not work for NaNs.
|
let sample =
|
||||||
if !(self.sum > zero) {
|
<T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
|
||||||
|
let (index, weight) = WeightedShuffle::search(self, sample);
|
||||||
|
self.remove(index, weight);
|
||||||
|
return Some(index - 1);
|
||||||
|
}
|
||||||
|
if self.zeros.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
|
let index = <usize as SampleUniform>::Sampler::sample_single(
|
||||||
let (index, weight) = WeightedShuffle::search(self, sample);
|
0usize,
|
||||||
self.remove(index, weight);
|
self.zeros.len(),
|
||||||
Some(index - 1)
|
&mut self.rng,
|
||||||
|
);
|
||||||
|
Some(self.zeros.swap_remove(index))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,9 +160,13 @@ where
|
|||||||
{
|
{
|
||||||
let zero = <T as Default>::default();
|
let zero = <T as Default>::default();
|
||||||
let high = cumulative_weights.last().copied().unwrap_or_default();
|
let high = cumulative_weights.last().copied().unwrap_or_default();
|
||||||
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
if high == zero {
|
||||||
if !(high > zero) {
|
if cumulative_weights.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
|
}
|
||||||
|
let index =
|
||||||
|
<usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng);
|
||||||
|
return Some(index);
|
||||||
}
|
}
|
||||||
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
|
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
|
||||||
let mut lo = 0usize;
|
let mut lo = 0usize;
|
||||||
@ -234,11 +256,14 @@ mod tests {
|
|||||||
R: Rng,
|
R: Rng,
|
||||||
{
|
{
|
||||||
let mut shuffle = Vec::with_capacity(weights.len());
|
let mut shuffle = Vec::with_capacity(weights.len());
|
||||||
loop {
|
let mut high: u64 = weights.iter().sum();
|
||||||
let high: u64 = weights.iter().sum();
|
let mut zeros: Vec<_> = weights
|
||||||
if high == 0 {
|
.iter()
|
||||||
break shuffle;
|
.enumerate()
|
||||||
}
|
.filter(|(_, w)| **w == 0)
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.collect();
|
||||||
|
while high != 0 {
|
||||||
let sample = rng.gen_range(0, high);
|
let sample = rng.gen_range(0, high);
|
||||||
let index = weights
|
let index = weights
|
||||||
.iter()
|
.iter()
|
||||||
@ -249,8 +274,14 @@ mod tests {
|
|||||||
.position(|acc| sample < acc)
|
.position(|acc| sample < acc)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
shuffle.push(index);
|
shuffle.push(index);
|
||||||
|
high -= weights[index];
|
||||||
weights[index] = 0;
|
weights[index] = 0;
|
||||||
}
|
}
|
||||||
|
while !zeros.is_empty() {
|
||||||
|
let index = <usize as SampleUniform>::Sampler::sample_single(0usize, zeros.len(), rng);
|
||||||
|
shuffle.push(zeros.swap_remove(index));
|
||||||
|
}
|
||||||
|
shuffle
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -329,14 +360,15 @@ mod tests {
|
|||||||
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
|
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Asserts that zero weights will return empty shuffle.
|
// Asserts that zero weights will be shuffled.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_weighted_shuffle_zero_weights() {
|
fn test_weighted_shuffle_zero_weights() {
|
||||||
let weights = vec![0u64; 5];
|
let weights = vec![0u64; 5];
|
||||||
let mut rng = rand::thread_rng();
|
let seed = [37u8; 32];
|
||||||
let shuffle = WeightedShuffle::new(&mut rng, &weights);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
assert!(shuffle.unwrap().next().is_none());
|
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||||
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
|
assert_eq!(shuffle, [1, 4, 2, 3, 0]);
|
||||||
|
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Asserts that each index is selected proportional to its weight.
|
// Asserts that each index is selected proportional to its weight.
|
||||||
@ -352,13 +384,13 @@ mod tests {
|
|||||||
counts[shuffle.next().unwrap()] += 1;
|
counts[shuffle.next().unwrap()] += 1;
|
||||||
let _ = shuffle.count(); // consume the rest.
|
let _ = shuffle.count(); // consume the rest.
|
||||||
}
|
}
|
||||||
assert_eq!(counts, [101, 0, 90113, 0, 0, 891, 8895, 0]);
|
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_weighted_shuffle_hard_coded() {
|
fn test_weighted_shuffle_hard_coded() {
|
||||||
let weights = [
|
let weights = [
|
||||||
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72,
|
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
|
||||||
];
|
];
|
||||||
let cumulative_weights: Vec<_> = weights
|
let cumulative_weights: Vec<_> = weights
|
||||||
.iter()
|
.iter()
|
||||||
@ -372,7 +404,7 @@ mod tests {
|
|||||||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
shuffle,
|
shuffle,
|
||||||
[2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8]
|
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
|
||||||
);
|
);
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -384,12 +416,12 @@ mod tests {
|
|||||||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
shuffle,
|
shuffle,
|
||||||
[17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12]
|
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
|
||||||
);
|
);
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
weighted_sample_single(&mut rng, &cumulative_weights),
|
weighted_sample_single(&mut rng, &cumulative_weights),
|
||||||
Some(17),
|
Some(19),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user