diff --git a/gossip/benches/weighted_shuffle.rs b/gossip/benches/weighted_shuffle.rs new file mode 100644 index 0000000000..37097f3e62 --- /dev/null +++ b/gossip/benches/weighted_shuffle.rs @@ -0,0 +1,39 @@ +#![feature(test)] + +extern crate test; + +use { + rand::{Rng, SeedableRng}, + rand_chacha::ChaChaRng, + solana_gossip::weighted_shuffle::{weighted_shuffle, WeightedShuffle}, + std::iter::repeat_with, + test::Bencher, +}; + +fn make_weights(rng: &mut R) -> Vec { + repeat_with(|| rng.gen_range(1, 100)).take(1000).collect() +} + +#[bench] +fn bench_weighted_shuffle_old(bencher: &mut Bencher) { + let mut seed = [0u8; 32]; + let mut rng = rand::thread_rng(); + let weights = make_weights(&mut rng); + bencher.iter(|| { + rng.fill(&mut seed[..]); + weighted_shuffle(&weights, seed); + }); +} + +#[bench] +fn bench_weighted_shuffle_new(bencher: &mut Bencher) { + let mut seed = [0u8; 32]; + let mut rng = rand::thread_rng(); + let weights = make_weights(&mut rng); + bencher.iter(|| { + rng.fill(&mut seed[..]); + WeightedShuffle::new(&mut ChaChaRng::from_seed(seed), &weights) + .unwrap() + .collect::>() + }); +} diff --git a/gossip/src/cluster_info.rs b/gossip/src/cluster_info.rs index a748eaf69f..c160f070c0 100644 --- a/gossip/src/cluster_info.rs +++ b/gossip/src/cluster_info.rs @@ -29,7 +29,7 @@ use { gossip_error::GossipError, ping_pong::{self, PingCache, Pong}, socketaddr, socketaddr_any, - weighted_shuffle::weighted_shuffle, + weighted_shuffle::{weighted_shuffle, WeightedShuffle}, }, bincode::{serialize, serialized_size}, itertools::Itertools, @@ -2043,11 +2043,8 @@ impl ClusterInfo { if responses.is_empty() { return packets; } - let shuffle = { - let mut seed = [0; 32]; - rand::thread_rng().fill(&mut seed[..]); - weighted_shuffle(&scores, seed).into_iter() - }; + let mut rng = rand::thread_rng(); + let shuffle = WeightedShuffle::new(&mut rng, &scores).unwrap(); let mut total_bytes = 0; let mut sent = 0; for (addr, response) in shuffle.map(|i| &responses[i]) { diff --git a/gossip/src/crds_gossip_pull.rs b/gossip/src/crds_gossip_pull.rs index 225ca5948a..d2d9266ab9 100644 --- a/gossip/src/crds_gossip_pull.rs +++ b/gossip/src/crds_gossip_pull.rs @@ -18,7 +18,7 @@ use { crds_gossip_error::CrdsGossipError, crds_value::CrdsValue, ping_pong::PingCache, - weighted_shuffle::weighted_shuffle, + weighted_shuffle::WeightedShuffle, }, itertools::Itertools, lru::LruCache, @@ -235,13 +235,10 @@ impl CrdsGossipPull { if peers.is_empty() { return Err(CrdsGossipError::NoPeers); } - let mut peers = { - let mut rng = rand::thread_rng(); - let mut seed = [0u8; 32]; - rng.fill(&mut seed[..]); - let index = weighted_shuffle(&weights, seed); - index.into_iter().map(|i| peers[i]) - }; + let mut rng = rand::thread_rng(); + let mut peers = WeightedShuffle::new(&mut rng, &weights) + .unwrap() + .map(|i| peers[i]); let peer = { let mut rng = rand::thread_rng(); let mut ping_cache = ping_cache.lock().unwrap(); @@ -273,7 +270,7 @@ impl CrdsGossipPull { now: u64, gossip_validators: Option<&HashSet>, stakes: &HashMap, - ) -> Vec<(f32, &'a ContactInfo)> { + ) -> Vec<(u64, &'a ContactInfo)> { let mut rng = rand::thread_rng(); let active_cutoff = now.saturating_sub(PULL_ACTIVE_TIMEOUT_MS); crds.get_nodes() @@ -307,7 +304,9 @@ impl CrdsGossipPull { let since = (now.saturating_sub(req_time).min(3600 * 1000) / 1024) as u32; let stake = get_stake(&item.id, stakes); let weight = get_weight(max_weight, since, stake); - (weight, item) + // Weights are bounded by max_weight defined above. + // So this type-cast should be safe. + ((weight * 100.0) as u64, item) }) .collect() } diff --git a/gossip/src/crds_gossip_push.rs b/gossip/src/crds_gossip_push.rs index 6e7319a747..ca836546b6 100644 --- a/gossip/src/crds_gossip_push.rs +++ b/gossip/src/crds_gossip_push.rs @@ -16,7 +16,7 @@ use { crds_gossip::{get_stake, get_weight}, crds_gossip_error::CrdsGossipError, crds_value::CrdsValue, - weighted_shuffle::weighted_shuffle, + weighted_shuffle::WeightedShuffle, }, bincode::serialized_size, indexmap::map::IndexMap, @@ -119,6 +119,7 @@ impl CrdsGossipPush { if peer_stake_total < prune_stake_threshold { return Vec::new(); } + let mut rng = rand::thread_rng(); let shuffled_staked_peers = { let peers: Vec<_> = peers .iter() @@ -126,11 +127,9 @@ impl CrdsGossipPush { .filter_map(|(peer, _)| Some((*peer, *stakes.get(peer)?))) .filter(|(_, stake)| *stake > 0) .collect(); - let mut seed = [0; 32]; - rand::thread_rng().fill(&mut seed[..]); let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect(); - weighted_shuffle(&weights, seed) - .into_iter() + WeightedShuffle::new(&mut rng, &weights) + .unwrap() .map(move |i| peers[i]) }; let mut keep = HashSet::new(); @@ -282,11 +281,7 @@ impl CrdsGossipPush { return; } let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size); - let shuffle = { - let mut seed = [0; 32]; - rng.fill(&mut seed[..]); - weighted_shuffle(&weights, seed).into_iter() - }; + let shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); for peer in shuffle.map(|i| peers[i].id) { if new_items.len() >= need { break; @@ -320,7 +315,7 @@ impl CrdsGossipPush { self_shred_version: u16, stakes: &HashMap, gossip_validators: Option<&HashSet>, - ) -> Vec<(f32, &'a ContactInfo)> { + ) -> Vec<(u64, &'a ContactInfo)> { let now = timestamp(); let mut rng = rand::thread_rng(); let max_weight = u16::MAX as f32 - 1.0; @@ -356,7 +351,9 @@ impl CrdsGossipPush { let since = (now.saturating_sub(last_pushed_to).min(3600 * 1000) / 1024) as u32; let stake = get_stake(&info.id, stakes); let weight = get_weight(max_weight, since, stake); - (weight, info) + // Weights are bounded by max_weight defined above. + // So this type-cast should be safe. + ((weight * 100.0) as u64, info) }) .collect() } diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 62e2d18eb6..63ba440f30 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -2,12 +2,138 @@ use { itertools::Itertools, - num_traits::{FromPrimitive, ToPrimitive}, - rand::{Rng, SeedableRng}, + num_traits::{CheckedAdd, FromPrimitive, ToPrimitive}, + rand::{ + distributions::uniform::{SampleUniform, UniformSampler}, + Rng, SeedableRng, + }, rand_chacha::ChaChaRng, - std::{iter, ops::Div}, + std::{ + iter, + ops::{AddAssign, Div, Sub, SubAssign}, + }, }; +#[derive(Debug)] +pub enum WeightedShuffleError { + NegativeWeight(T), + SumOverflow, +} + +/// Implements an iterator where indices are shuffled according to their +/// weights: +/// - Returned indices are unique in the range [0, weights.len()). +/// - Higher weighted indices tend to appear earlier proportional to their +/// weight. +/// - Zero weighted indices are excluded. Therefore the iterator may have +/// count less than weights.len(). +pub struct WeightedShuffle<'a, R, T> { + arr: Vec, // Underlying array implementing binary indexed tree. + sum: T, // Current sum of weights, excluding already selected indices. + rng: &'a mut R, // Random number generator. +} + +// The implementation uses binary indexed tree: +// https://en.wikipedia.org/wiki/Fenwick_tree +// to maintain cumulative sum of weights excluding already selected indices +// over self.arr. +impl<'a, R: Rng, T> WeightedShuffle<'a, R, T> +where + T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, +{ + /// Returns error if: + /// - any of the weights are negative. + /// - sum of weights overflows. + pub fn new(rng: &'a mut R, weights: &[T]) -> Result> { + let size = weights.len() + 1; + let zero = ::default(); + let mut arr = vec![zero; size]; + let mut sum = zero; + for (mut k, &weight) in (1usize..).zip(weights) { + #[allow(clippy::neg_cmp_op_on_partial_ord)] + // weight < zero does not work for NaNs. + if !(weight >= zero) { + return Err(WeightedShuffleError::NegativeWeight(weight)); + } + sum = sum + .checked_add(&weight) + .ok_or(WeightedShuffleError::SumOverflow)?; + while k < size { + arr[k] += weight; + k += k & k.wrapping_neg(); + } + } + Ok(Self { arr, sum, rng }) + } +} + +impl<'a, R, T> WeightedShuffle<'a, R, T> +where + T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, +{ + // Returns cumulative sum of current weights upto index k (inclusive). + fn cumsum(&self, mut k: usize) -> T { + let mut out = ::default(); + while k != 0 { + out += self.arr[k]; + k ^= k & k.wrapping_neg(); + } + out + } + + // Removes given weight at index k. + fn remove(&mut self, mut k: usize, weight: T) { + self.sum -= weight; + let size = self.arr.len(); + while k < size { + self.arr[k] -= weight; + k += k & k.wrapping_neg(); + } + } + + // Returns smallest index such that self.cumsum(k) > val, + // along with its respective weight. + fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { + let zero = ::default(); + debug_assert!(val >= zero); + debug_assert!(val < self.sum); + let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); + let mut hi = (self.arr.len() - 1, self.sum); + while lo.0 + 1 < hi.0 { + let k = lo.0 + (hi.0 - lo.0) / 2; + let sum = self.cumsum(k); + if sum <= val { + lo = (k, sum); + } else { + hi = (k, sum); + } + } + debug_assert!(lo.1 <= val); + debug_assert!(hi.1 > val); + (hi.0, hi.1 - lo.1) + } +} + +impl<'a, R: Rng, T> Iterator for WeightedShuffle<'a, R, T> +where + T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, +{ + type Item = usize; + + fn next(&mut self) -> Option { + let zero = ::default(); + #[allow(clippy::neg_cmp_op_on_partial_ord)] + // self.sum <= zero does not work for NaNs. + if !(self.sum > zero) { + return None; + } + let sample = ::Sampler::sample_single(zero, self.sum, &mut self.rng); + let (index, weight) = WeightedShuffle::search(self, sample); + self.remove(index, weight); + Some(index - 1) + } +} + /// Returns a list of indexes shuffled based on the input weights /// Note - The sum of all weights must not exceed `u64::MAX` pub fn weighted_shuffle(weights: &[T], seed: [u8; 32]) -> Vec @@ -67,6 +193,31 @@ pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> us #[cfg(test)] mod tests { use super::*; + use std::{convert::TryInto, iter::repeat_with}; + + fn weighted_shuffle_slow(rng: &mut R, mut weights: Vec) -> Vec + where + R: Rng, + { + let mut shuffle = Vec::with_capacity(weights.len()); + loop { + let high: u64 = weights.iter().sum(); + if high == 0 { + break shuffle; + } + let sample = rng.gen_range(0, high); + let index = weights + .iter() + .scan(0, |acc, &w| { + *acc += w; + Some(*acc) + }) + .position(|acc| sample < acc) + .unwrap(); + shuffle.push(index); + weights[index] = 0; + } + } #[test] fn test_weighted_shuffle_iterator() { @@ -133,4 +284,56 @@ mod tests { let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]); assert_eq!(best_index, 2); } + + // Asserts that each index is selected proportional to its weight. + #[test] + fn test_weighted_shuffle_sanity() { + let seed: Vec<_> = (1..).step_by(3).take(32).collect(); + let seed: [u8; 32] = seed.try_into().unwrap(); + let mut rng = ChaChaRng::from_seed(seed); + let weights = [1, 1000, 10, 100]; + let mut counts = [0; 4]; + for _ in 0..100000 { + let mut shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap(); + counts[shuffle.next().unwrap()] += 1; + let _ = shuffle.count(); // consume the rest. + } + assert_eq!(counts, [101, 90113, 891, 8895]); + } + + #[test] + fn test_weighted_shuffle_hard_coded() { + let weights = [ + 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72, + ]; + let seed = [48u8; 32]; + let mut rng = ChaChaRng::from_seed(seed); + let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + assert_eq!( + shuffle, + [2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8] + ); + let seed = [37u8; 32]; + let mut rng = ChaChaRng::from_seed(seed); + let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + assert_eq!( + shuffle, + [17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12] + ); + } + + #[test] + fn test_weighted_shuffle_match_slow() { + let mut rng = rand::thread_rng(); + let weights: Vec = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect(); + for _ in 0..10 { + let mut seed = [0u8; 32]; + rng.fill(&mut seed[..]); + let mut rng = ChaChaRng::from_seed(seed); + let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect(); + let mut rng = ChaChaRng::from_seed(seed); + let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); + assert_eq!(shuffle, shuffle_slow,); + } + } }