caches WeightedShuffle struct in ClusterNodes (#22877)

Instead of reconstructing WeightedShuffle struct for each shred
broadcast or retransmit, we can use the same struct with minimal
mutations.
This commit is contained in:
behzad nouri
2022-02-02 15:12:26 +00:00
committed by GitHub
parent 2fda90e414
commit e3b137066d
2 changed files with 107 additions and 142 deletions

View File

@ -28,6 +28,7 @@ pub enum WeightedShuffleError<T> {
/// weight.
/// - Zero weighted indices are shuffled and appear only at the end, after
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices.
@ -118,6 +119,36 @@ where
debug_assert!(hi.1 > val);
(hi.0, hi.1 - lo.1)
}
pub fn remove_index(&mut self, index: usize) {
let zero = <T as Default>::default();
let weight = self.cumsum(index + 1) - self.cumsum(index);
if weight != zero {
self.remove(index + 1, weight);
} else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) {
self.zeros.remove(index);
}
}
}
impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
{
// Equivalent to weighted_shuffle.shuffle(&mut rng).next()
pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
let zero = <T as Default>::default();
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
let (index, _weight) = WeightedShuffle::search(self, sample);
return Some(index - 1);
}
if self.zeros.is_empty() {
return None;
}
let index = <usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
self.zeros.get(index).copied()
}
}
impl<'a, T: 'a> WeightedShuffle<T>
@ -143,39 +174,6 @@ where
}
}
// Equivalent to WeightedShuffle(rng, weights).unwrap().next().
pub fn weighted_sample_single<R: Rng, T>(rng: &mut R, cumulative_weights: &[T]) -> Option<usize>
where
T: Copy + Default + PartialOrd + SampleUniform,
{
let zero = <T as Default>::default();
let high = cumulative_weights.last().copied().unwrap_or_default();
if high == zero {
if cumulative_weights.is_empty() {
return None;
}
let index =
<usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng);
return Some(index);
}
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
let mut lo = 0usize;
let mut hi = cumulative_weights.len() - 1;
while lo + 1 < hi {
let k = lo + (hi - lo) / 2;
if cumulative_weights[k] <= sample {
lo = k;
} else {
hi = k;
}
}
if cumulative_weights[lo] > sample {
Some(lo)
} else {
Some(hi)
}
}
/// Returns a list of indexes shuffled based on the input weights
/// Note - The sum of all weights must not exceed `u64::MAX`
pub fn weighted_shuffle<T, B, F>(weights: F, seed: [u8; 32]) -> Vec<usize>
@ -346,8 +344,8 @@ mod tests {
let weights = Vec::<u64>::new();
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new(&weights).unwrap();
assert!(shuffle.shuffle(&mut rng).next().is_none());
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
assert!(shuffle.clone().shuffle(&mut rng).next().is_none());
assert!(shuffle.first(&mut rng).is_none());
}
// Asserts that zero weights will be shuffled.
@ -357,9 +355,12 @@ mod tests {
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
assert_eq!(shuffle, [1, 4, 2, 3, 0]);
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
assert_eq!(
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[1, 4, 2, 3, 0]
);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(shuffle.first(&mut rng), Some(1));
}
// Asserts that each index is selected proportional to its weight.
@ -376,6 +377,17 @@ mod tests {
let _ = shuffle.count(); // consume the rest.
}
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
let mut counts = [0; 8];
for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new(&weights).unwrap();
shuffle.remove_index(5);
shuffle.remove_index(3);
shuffle.remove_index(1);
let mut shuffle = shuffle.shuffle(&mut rng);
counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest.
}
assert_eq!(counts, [97, 0, 90862, 0, 0, 0, 9041, 0]);
}
#[test]
@ -383,52 +395,54 @@ mod tests {
let weights = [
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
];
let cumulative_weights: Vec<_> = weights
.iter()
.scan(0, |acc, w| {
*acc += w;
Some(*acc)
})
.collect();
let seed = [48u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
let mut shuffle = WeightedShuffle::new(&weights).unwrap();
assert_eq!(
shuffle,
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(shuffle.first(&mut rng), Some(2));
let mut rng = ChaChaRng::from_seed(seed);
shuffle.remove_index(11);
shuffle.remove_index(3);
shuffle.remove_index(15);
shuffle.remove_index(0);
assert_eq!(
weighted_sample_single(&mut rng, &cumulative_weights),
Some(2),
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[4, 6, 1, 12, 19, 14, 17, 20, 2, 9, 10, 8, 7, 18, 13, 5, 16]
);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(shuffle.first(&mut rng), Some(4));
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
let mut shuffle = WeightedShuffle::new(&weights).unwrap();
assert_eq!(
shuffle,
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(shuffle.first(&mut rng), Some(19));
shuffle.remove_index(16);
shuffle.remove_index(8);
shuffle.remove_index(20);
shuffle.remove_index(5);
shuffle.remove_index(19);
shuffle.remove_index(4);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(
weighted_sample_single(&mut rng, &cumulative_weights),
Some(19),
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[17, 2, 9, 14, 6, 10, 12, 1, 15, 13, 7, 0, 18, 3, 11]
);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(shuffle.first(&mut rng), Some(17));
}
#[test]
fn test_weighted_shuffle_match_slow() {
let mut rng = rand::thread_rng();
let weights: Vec<u64> = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect();
let cumulative_weights: Vec<_> = weights
.iter()
.scan(0, |acc, w| {
*acc += w;
Some(*acc)
})
.collect();
for _ in 0..10 {
let mut seed = [0u8; 32];
rng.fill(&mut seed[..]);
@ -439,10 +453,8 @@ mod tests {
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
assert_eq!(shuffle, shuffle_slow);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(
weighted_sample_single(&mut rng, &cumulative_weights),
Some(shuffle[0]),
);
let shuffle = WeightedShuffle::new(&weights).unwrap();
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
}
}
}