From d668a7694f47bd2a011fb0bab805449ee7afff89 Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Wed, 7 Jul 2021 16:07:41 +0000 Subject: [PATCH] implements an unbiased weighted shuffle using binary indexed tree (#18343) (#18485) Current implementation of weighted_shuffle: https://github.com/solana-labs/solana/blob/b08f8bd1b/gossip/src/weighted_shuffle.rs#L11-L37 uses a heuristic which results in biased samples. For example, if the weights are [1, 10, 100], then the 3rd index should come first 100 times more often than the 1st index. However, weighted_shuffle is picking the 3rd index 200+ times more often than the 1st index, showing a disproportional bias in favor of higher weights. This commit implements weighted shuffle using binary indexed tree to maintain cumulative sum of weights while sampling. The resulting samples are demonstrably unbiased and precisely proportional to the weights. Additionally the iterator interface allows to skip computations when not all indices are processed. Of the use cases of weighted_shuffle, changing turbine code requires feature-gating to keep the cluster in sync. That is not updated in this commit, but can be done together with future updates to turbine. (cherry picked from commit dba42c57b458f8e620897ac05011ea5f11b84740) Co-authored-by: behzad nouri --- gossip/benches/weighted_shuffle.rs | 39 ++++++ gossip/src/cluster_info.rs | 9 +- gossip/src/crds_gossip_pull.rs | 19 ++- gossip/src/crds_gossip_push.rs | 21 ++- gossip/src/weighted_shuffle.rs | 209 ++++++++++++++++++++++++++++- 5 files changed, 266 insertions(+), 31 deletions(-) create mode 100644 gossip/benches/weighted_shuffle.rs diff --git a/gossip/benches/weighted_shuffle.rs b/gossip/benches/weighted_shuffle.rs new file mode 100644 index 0000000000..37097f3e62 --- /dev/null +++ b/gossip/benches/weighted_shuffle.rs @@ -0,0 +1,39 @@ +#![feature(test)] + +extern crate test; + +use { + rand::{Rng, SeedableRng}, + rand_chacha::ChaChaRng, + solana_gossip::weighted_shuffle::{weighted_shuffle, WeightedShuffle}, + std::iter::repeat_with, + test::Bencher, +}; + +fn make_weights(rng: &mut R) -> Vec { + repeat_with(|| rng.gen_range(1, 100)).take(1000).collect() +} + +#[bench] +fn bench_weighted_shuffle_old(bencher: &mut Bencher) { + let mut seed = [0u8; 32]; + let mut rng = rand::thread_rng(); + let weights = make_weights(&mut rng); + bencher.iter(|| { + rng.fill(&mut seed[..]); + weighted_shuffle(&weights, seed); + }); +} + +#[bench] +fn bench_weighted_shuffle_new(bencher: &mut Bencher) { + let mut seed = [0u8; 32]; + let mut rng = rand::thread_rng(); + let weights = make_weights(&mut rng); + bencher.iter(|| { + rng.fill(&mut seed[..]); + WeightedShuffle::new(&mut ChaChaRng::from_seed(seed), &weights) + .unwrap() + .collect::>() + }); +} diff --git a/gossip/src/cluster_info.rs b/gossip/src/cluster_info.rs index cf6eda727e..98534d474a 100644 --- a/gossip/src/cluster_info.rs +++ b/gossip/src/cluster_info.rs @@ -29,7 +29,7 @@ use { gossip_error::GossipError, ping_pong::{self, PingCache, Pong}, socketaddr, socketaddr_any, - weighted_shuffle::weighted_shuffle, + weighted_shuffle::{weighted_shuffle, WeightedShuffle}, }, bincode::{serialize, serialized_size}, itertools::Itertools, @@ -2017,11 +2017,8 @@ impl ClusterInfo { if responses.is_empty() { return packets; } - let shuffle = { - let mut seed = [0; 32]; - rand::thread_rng().fill(&mut seed[..]); - weighted_shuffle(&scores, seed).into_iter() - }; + let mut rng = rand::thread_rng(); + let shuffle = WeightedShuffle::new(&mut rng, &scores).unwrap(); let mut total_bytes = 0; let mut sent = 0; for (addr, response) in shuffle.map(|i| &responses[i]) { diff --git a/gossip/src/crds_gossip_pull.rs b/gossip/src/crds_gossip_pull.rs index 225ca5948a..d2d9266ab9 100644 --- a/gossip/src/crds_gossip_pull.rs +++ b/gossip/src/crds_gossip_pull.rs @@ -18,7 +18,7 @@ use { crds_gossip_error::CrdsGossipError, crds_value::CrdsValue, ping_pong::PingCache, - weighted_shuffle::weighted_shuffle, + weighted_shuffle::WeightedShuffle, }, itertools::Itertools, lru::LruCache, @@ -235,13 +235,10 @@ impl CrdsGossipPull { if peers.is_empty() { return Err(CrdsGossipError::NoPeers); } - let mut peers = { - let mut rng = rand::thread_rng(); - let mut seed = [0u8; 32]; - rng.fill(&mut seed[..]); - let index = weighted_shuffle(&weights, seed); - index.into_iter().map(|i| peers[i]) - }; + let mut rng = rand::thread_rng(); + let mut peers = WeightedShuffle::new(&mut rng, &weights) + .unwrap() + .map(|i| peers[i]); let peer = { let mut rng = rand::thread_rng(); let mut ping_cache = ping_cache.lock().unwrap(); @@ -273,7 +270,7 @@ impl CrdsGossipPull { now: u64, gossip_validators: Option<&HashSet>, stakes: &HashMap, - ) -> Vec<(f32, &'a ContactInfo)> { + ) -> Vec<(u64, &'a ContactInfo)> { let mut rng = rand::thread_rng(); let active_cutoff = now.saturating_sub(PULL_ACTIVE_TIMEOUT_MS); crds.get_nodes() @@ -307,7 +304,9 @@ impl CrdsGossipPull { let since = (now.saturating_sub(req_time).min(3600 * 1000) / 1024) as u32; let stake = get_stake(&item.id, stakes); let weight = get_weight(max_weight, since, stake); - (weight, item) + // Weights are bounded by max_weight defined above. + // So this type-cast should be safe. + ((weight * 100.0) as u64, item) }) .collect() } diff --git a/gossip/src/crds_gossip_push.rs b/gossip/src/crds_gossip_push.rs index 6e7319a747..ca836546b6 100644 --- a/gossip/src/crds_gossip_push.rs +++ b/gossip/src/crds_gossip_push.rs @@ -16,7 +16,7 @@ use { crds_gossip::{get_stake, get_weight}, crds_gossip_error::CrdsGossipError, crds_value::CrdsValue, - weighted_shuffle::weighted_shuffle, + weighted_shuffle::WeightedShuffle, }, bincode::serialized_size, indexmap::map::IndexMap, @@ -119,6 +119,7 @@ impl CrdsGossipPush { if peer_stake_total < prune_stake_threshold { return Vec::new(); } + let mut rng = rand::thread_rng(); let shuffled_staked_peers = { let peers: Vec<_> = peers .iter() @@ -126,11 +127,9 @@ impl CrdsGossipPush { .filter_map(|(peer, _)| Some((*peer, *stakes.get(peer)?))) .filter(|(_, stake)| *stake > 0) .collect(); - let mut seed = [0; 32]; - rand::thread_rng().fill(&mut seed[..]); let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect(); - weighted_shuffle(&weights, seed) - .into_iter() + WeightedShuffle::new(&mut rng, &weights) + .unwrap() .map(move |i| peers[i]) }; let mut keep = HashSet::new(); @@ -282,11 +281,7 @@ impl CrdsGossipPush { return; } let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size); - let shuffle = { - let mut seed = [0; 32]; - rng.fill(&mut seed[..]); - weighted_shuffle(&weights, seed).into_iter() - }; + let shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); for peer in shuffle.map(|i| peers[i].id) { if new_items.len() >= need { break; @@ -320,7 +315,7 @@ impl CrdsGossipPush { self_shred_version: u16, stakes: &HashMap, gossip_validators: Option<&HashSet>, - ) -> Vec<(f32, &'a ContactInfo)> { + ) -> Vec<(u64, &'a ContactInfo)> { let now = timestamp(); let mut rng = rand::thread_rng(); let max_weight = u16::MAX as f32 - 1.0; @@ -356,7 +351,9 @@ impl CrdsGossipPush { let since = (now.saturating_sub(last_pushed_to).min(3600 * 1000) / 1024) as u32; let stake = get_stake(&info.id, stakes); let weight = get_weight(max_weight, since, stake); - (weight, info) + // Weights are bounded by max_weight defined above. + // So this type-cast should be safe. + ((weight * 100.0) as u64, info) }) .collect() } diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 62e2d18eb6..63ba440f30 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -2,12 +2,138 @@ use { itertools::Itertools, - num_traits::{FromPrimitive, ToPrimitive}, - rand::{Rng, SeedableRng}, + num_traits::{CheckedAdd, FromPrimitive, ToPrimitive}, + rand::{ + distributions::uniform::{SampleUniform, UniformSampler}, + Rng, SeedableRng, + }, rand_chacha::ChaChaRng, - std::{iter, ops::Div}, + std::{ + iter, + ops::{AddAssign, Div, Sub, SubAssign}, + }, }; +#[derive(Debug)] +pub enum WeightedShuffleError { + NegativeWeight(T), + SumOverflow, +} + +/// Implements an iterator where indices are shuffled according to their +/// weights: +/// - Returned indices are unique in the range [0, weights.len()). +/// - Higher weighted indices tend to appear earlier proportional to their +/// weight. +/// - Zero weighted indices are excluded. Therefore the iterator may have +/// count less than weights.len(). +pub struct WeightedShuffle<'a, R, T> { + arr: Vec, // Underlying array implementing binary indexed tree. + sum: T, // Current sum of weights, excluding already selected indices. + rng: &'a mut R, // Random number generator. +} + +// The implementation uses binary indexed tree: +// https://en.wikipedia.org/wiki/Fenwick_tree +// to maintain cumulative sum of weights excluding already selected indices +// over self.arr. +impl<'a, R: Rng, T> WeightedShuffle<'a, R, T> +where + T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, +{ + /// Returns error if: + /// - any of the weights are negative. + /// - sum of weights overflows. + pub fn new(rng: &'a mut R, weights: &[T]) -> Result> { + let size = weights.len() + 1; + let zero = ::default(); + let mut arr = vec![zero; size]; + let mut sum = zero; + for (mut k, &weight) in (1usize..).zip(weights) { + #[allow(clippy::neg_cmp_op_on_partial_ord)] + // weight < zero does not work for NaNs. + if !(weight >= zero) { + return Err(WeightedShuffleError::NegativeWeight(weight)); + } + sum = sum + .checked_add(&weight) + .ok_or(WeightedShuffleError::SumOverflow)?; + while k < size { + arr[k] += weight; + k += k & k.wrapping_neg(); + } + } + Ok(Self { arr, sum, rng }) + } +} + +impl<'a, R, T> WeightedShuffle<'a, R, T> +where + T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, +{ + // Returns cumulative sum of current weights upto index k (inclusive). + fn cumsum(&self, mut k: usize) -> T { + let mut out = ::default(); + while k != 0 { + out += self.arr[k]; + k ^= k & k.wrapping_neg(); + } + out + } + + // Removes given weight at index k. + fn remove(&mut self, mut k: usize, weight: T) { + self.sum -= weight; + let size = self.arr.len(); + while k < size { + self.arr[k] -= weight; + k += k & k.wrapping_neg(); + } + } + + // Returns smallest index such that self.cumsum(k) > val, + // along with its respective weight. + fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { + let zero = ::default(); + debug_assert!(val >= zero); + debug_assert!(val < self.sum); + let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); + let mut hi = (self.arr.len() - 1, self.sum); + while lo.0 + 1 < hi.0 { + let k = lo.0 + (hi.0 - lo.0) / 2; + let sum = self.cumsum(k); + if sum <= val { + lo = (k, sum); + } else { + hi = (k, sum); + } + } + debug_assert!(lo.1 <= val); + debug_assert!(hi.1 > val); + (hi.0, hi.1 - lo.1) + } +} + +impl<'a, R: Rng, T> Iterator for WeightedShuffle<'a, R, T> +where + T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, +{ + type Item = usize; + + fn next(&mut self) -> Option { + let zero = ::default(); + #[allow(clippy::neg_cmp_op_on_partial_ord)] + // self.sum <= zero does not work for NaNs. + if !(self.sum > zero) { + return None; + } + let sample = ::Sampler::sample_single(zero, self.sum, &mut self.rng); + let (index, weight) = WeightedShuffle::search(self, sample); + self.remove(index, weight); + Some(index - 1) + } +} + /// Returns a list of indexes shuffled based on the input weights /// Note - The sum of all weights must not exceed `u64::MAX` pub fn weighted_shuffle(weights: &[T], seed: [u8; 32]) -> Vec @@ -67,6 +193,31 @@ pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> us #[cfg(test)] mod tests { use super::*; + use std::{convert::TryInto, iter::repeat_with}; + + fn weighted_shuffle_slow(rng: &mut R, mut weights: Vec) -> Vec + where + R: Rng, + { + let mut shuffle = Vec::with_capacity(weights.len()); + loop { + let high: u64 = weights.iter().sum(); + if high == 0 { + break shuffle; + } + let sample = rng.gen_range(0, high); + let index = weights + .iter() + .scan(0, |acc, &w| { + *acc += w; + Some(*acc) + }) + .position(|acc| sample < acc) + .unwrap(); + shuffle.push(index); + weights[index] = 0; + } + } #[test] fn test_weighted_shuffle_iterator() { @@ -133,4 +284,56 @@ mod tests { let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]); assert_eq!(best_index, 2); } + + // Asserts that each index is selected proportional to its weight. + #[test] + fn test_weighted_shuffle_sanity() { + let seed: Vec<_> = (1..).step_by(3).take(32).collect(); + let seed: [u8; 32] = seed.try_into().unwrap(); + let mut rng = ChaChaRng::from_seed(seed); + let weights = [1, 1000, 10, 100]; + let mut counts = [0; 4]; + for _ in 0..100000 { + let mut shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); + counts[shuffle.next().unwrap()] += 1; + let _ = shuffle.count(); // consume the rest. + } + assert_eq!(counts, [101, 90113, 891, 8895]); + } + + #[test] + fn test_weighted_shuffle_hard_coded() { + let weights = [ + 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72, + ]; + let seed = [48u8; 32]; + let mut rng = ChaChaRng::from_seed(seed); + let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + assert_eq!( + shuffle, + [2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8] + ); + let seed = [37u8; 32]; + let mut rng = ChaChaRng::from_seed(seed); + let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + assert_eq!( + shuffle, + [17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12] + ); + } + + #[test] + fn test_weighted_shuffle_match_slow() { + let mut rng = rand::thread_rng(); + let weights: Vec = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect(); + for _ in 0..10 { + let mut seed = [0u8; 32]; + rng.fill(&mut seed[..]); + let mut rng = ChaChaRng::from_seed(seed); + let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + let mut rng = ChaChaRng::from_seed(seed); + let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); + assert_eq!(shuffle, shuffle_slow,); + } + } }