removes Rng field from WeightedShuffle struct (#22850) (#22868)

(cherry picked from commit 45e09664b8fa09bab83f23f7c725c66ce645d931)

Co-authored-by: behzad nouri <behzadnouri@gmail.com>
This commit is contained in:
mergify[bot] 2022-02-01 17:33:52 +00:00 committed by GitHub
parent 3aa3cd8852
commit c715bc93cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 44 additions and 46 deletions

View File

@ -411,8 +411,9 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> { fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
// Nodes are sorted by (stake, pubkey) in descending order. // Nodes are sorted by (stake, pubkey) in descending order.
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect(); let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
WeightedShuffle::new(rng, &stakes) WeightedShuffle::new(&stakes)
.unwrap() .unwrap()
.shuffle(rng)
.map(|i| nodes[i]) .map(|i| nodes[i])
.collect() .collect()
} }

View File

@ -32,8 +32,9 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) {
let weights = make_weights(&mut rng); let weights = make_weights(&mut rng);
bencher.iter(|| { bencher.iter(|| {
rng.fill(&mut seed[..]); rng.fill(&mut seed[..]);
WeightedShuffle::new(&mut ChaChaRng::from_seed(seed), &weights) let shuffle = WeightedShuffle::new(&weights).unwrap();
.unwrap() shuffle
.shuffle(&mut ChaChaRng::from_seed(seed))
.collect::<Vec<_>>() .collect::<Vec<_>>()
}); });
} }

View File

@ -2010,7 +2010,7 @@ impl ClusterInfo {
return packet_batch; return packet_batch;
} }
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new(&mut rng, &scores).unwrap(); let shuffle = WeightedShuffle::new(&scores).unwrap().shuffle(&mut rng);
let mut total_bytes = 0; let mut total_bytes = 0;
let mut sent = 0; let mut sent = 0;
for (addr, response) in shuffle.map(|i| &responses[i]) { for (addr, response) in shuffle.map(|i| &responses[i]) {

View File

@ -246,8 +246,9 @@ impl CrdsGossipPull {
return Err(CrdsGossipError::NoPeers); return Err(CrdsGossipError::NoPeers);
} }
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let mut peers = WeightedShuffle::new(&mut rng, &weights) let mut peers = WeightedShuffle::new(&weights)
.unwrap() .unwrap()
.shuffle(&mut rng)
.map(|i| peers[i]); .map(|i| peers[i]);
let peer = { let peer = {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();

View File

@ -169,8 +169,9 @@ impl CrdsGossipPush {
.filter(|(_, stake)| *stake > 0) .filter(|(_, stake)| *stake > 0)
.collect(); .collect();
let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect(); let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect();
WeightedShuffle::new(&mut rng, &weights) WeightedShuffle::new(&weights)
.unwrap() .unwrap()
.shuffle(&mut rng)
.map(move |i| peers[i]) .map(move |i| peers[i])
}; };
let mut keep = HashSet::new(); let mut keep = HashSet::new();
@ -369,7 +370,7 @@ impl CrdsGossipPush {
return; return;
} }
let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size); let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size);
let shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); let shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng);
let mut active_set = self.active_set.write().unwrap(); let mut active_set = self.active_set.write().unwrap();
let need = Self::compute_need(self.num_active, active_set.len(), ratio); let need = Self::compute_need(self.num_active, active_set.len(), ratio);
for peer in shuffle.map(|i| peers[i]) { for peer in shuffle.map(|i| peers[i]) {

View File

@ -28,25 +28,24 @@ pub enum WeightedShuffleError<T> {
/// weight. /// weight.
/// - Zero weighted indices are shuffled and appear only at the end, after /// - Zero weighted indices are shuffled and appear only at the end, after
/// non-zero weighted indices. /// non-zero weighted indices.
pub struct WeightedShuffle<'a, R, T> { pub struct WeightedShuffle<T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree. arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices. sum: T, // Current sum of weights, excluding already selected indices.
zeros: Vec<usize>, // Indices of zero weighted entries. zeros: Vec<usize>, // Indices of zero weighted entries.
rng: &'a mut R, // Random number generator.
} }
// The implementation uses binary indexed tree: // The implementation uses binary indexed tree:
// https://en.wikipedia.org/wiki/Fenwick_tree // https://en.wikipedia.org/wiki/Fenwick_tree
// to maintain cumulative sum of weights excluding already selected indices // to maintain cumulative sum of weights excluding already selected indices
// over self.arr. // over self.arr.
impl<'a, R: Rng, T> WeightedShuffle<'a, R, T> impl<T> WeightedShuffle<T>
where where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
{ {
/// Returns error if: /// Returns error if:
/// - any of the weights are negative. /// - any of the weights are negative.
/// - sum of weights overflows. /// - sum of weights overflows.
pub fn new(rng: &'a mut R, weights: &[T]) -> Result<Self, WeightedShuffleError<T>> { pub fn new(weights: &[T]) -> Result<Self, WeightedShuffleError<T>> {
let size = weights.len() + 1; let size = weights.len() + 1;
let zero = <T as Default>::default(); let zero = <T as Default>::default();
let mut arr = vec![zero; size]; let mut arr = vec![zero; size];
@ -70,16 +69,11 @@ where
k += k & k.wrapping_neg(); k += k & k.wrapping_neg();
} }
} }
Ok(Self { Ok(Self { arr, sum, zeros })
arr,
sum,
rng,
zeros,
})
} }
} }
impl<'a, R, T> WeightedShuffle<'a, R, T> impl<T> WeightedShuffle<T>
where where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>, T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{ {
@ -126,30 +120,26 @@ where
} }
} }
impl<'a, R: Rng, T> Iterator for WeightedShuffle<'a, R, T> impl<'a, T: 'a> WeightedShuffle<T>
where where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>, T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
{ {
type Item = usize; pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
std::iter::from_fn(move || {
fn next(&mut self) -> Option<Self::Item> { let zero = <T as Default>::default();
let zero = <T as Default>::default(); if self.sum > zero {
if self.sum > zero { let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
let sample = let (index, weight) = WeightedShuffle::search(&self, sample);
<T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng); self.remove(index, weight);
let (index, weight) = WeightedShuffle::search(self, sample); return Some(index - 1);
self.remove(index, weight); }
return Some(index - 1); if self.zeros.is_empty() {
} return None;
if self.zeros.is_empty() { }
return None; let index =
} <usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
let index = <usize as SampleUniform>::Sampler::sample_single( Some(self.zeros.swap_remove(index))
0usize, })
self.zeros.len(),
&mut self.rng,
);
Some(self.zeros.swap_remove(index))
} }
} }
@ -355,8 +345,8 @@ mod tests {
fn test_weighted_shuffle_empty_weights() { fn test_weighted_shuffle_empty_weights() {
let weights = Vec::<u64>::new(); let weights = Vec::<u64>::new();
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new(&mut rng, &weights); let shuffle = WeightedShuffle::new(&weights).unwrap();
assert!(shuffle.unwrap().next().is_none()); assert!(shuffle.shuffle(&mut rng).next().is_none());
assert_eq!(weighted_sample_single(&mut rng, &weights), None); assert_eq!(weighted_sample_single(&mut rng, &weights), None);
} }
@ -366,7 +356,8 @@ mod tests {
let weights = vec![0u64; 5]; let weights = vec![0u64; 5];
let seed = [37u8; 32]; let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed); let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
assert_eq!(shuffle, [1, 4, 2, 3, 0]); assert_eq!(shuffle, [1, 4, 2, 3, 0]);
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1)); assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
} }
@ -380,7 +371,7 @@ mod tests {
let weights = [1, 0, 1000, 0, 0, 10, 100, 0]; let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
let mut counts = [0; 8]; let mut counts = [0; 8];
for _ in 0..100000 { for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); let mut shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng);
counts[shuffle.next().unwrap()] += 1; counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest. let _ = shuffle.count(); // consume the rest.
} }
@ -401,7 +392,8 @@ mod tests {
.collect(); .collect();
let seed = [48u8; 32]; let seed = [48u8; 32];
let mut rng = ChaChaRng::from_seed(seed); let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
assert_eq!( assert_eq!(
shuffle, shuffle,
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
@ -413,7 +405,8 @@ mod tests {
); );
let seed = [37u8; 32]; let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed); let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
assert_eq!( assert_eq!(
shuffle, shuffle,
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
@ -440,7 +433,8 @@ mod tests {
let mut seed = [0u8; 32]; let mut seed = [0u8; 32];
rng.fill(&mut seed[..]); rng.fill(&mut seed[..]);
let mut rng = ChaChaRng::from_seed(seed); let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
let mut rng = ChaChaRng::from_seed(seed); let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
assert_eq!(shuffle, shuffle_slow); assert_eq!(shuffle, shuffle_slow);