diff --git a/core/src/cluster_nodes.rs b/core/src/cluster_nodes.rs index a02bbdbfe1..8e712bf24a 100644 --- a/core/src/cluster_nodes.rs +++ b/core/src/cluster_nodes.rs @@ -2,15 +2,13 @@ use { crate::{broadcast_stage::BroadcastStage, retransmit_stage::RetransmitStage}, itertools::Itertools, lru::LruCache, - rand::{Rng, SeedableRng}, + rand::SeedableRng, rand_chacha::ChaChaRng, solana_gossip::{ cluster_info::{compute_retransmit_peers, ClusterInfo}, contact_info::ContactInfo, crds_gossip_pull::CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS, - weighted_shuffle::{ - weighted_best, weighted_sample_single, weighted_shuffle, WeightedShuffle, - }, + weighted_shuffle::{weighted_best, weighted_shuffle, WeightedShuffle}, }, solana_ledger::shred::Shred, solana_runtime::bank::Bank, @@ -51,9 +49,7 @@ pub struct ClusterNodes { // All staked nodes + other known tvu-peers + the node itself; // sorted by (stake, pubkey) in descending order. nodes: Vec, - // Cumulative stakes (excluding the node itself), used for sampling - // broadcast peers. - cumulative_weights: Vec, + weighted_shuffle: WeightedShuffle, // Weights and indices for sampling peers. weighted_{shuffle,best} expect // weights >= 1. For backward compatibility we use max(1, stake) for // weights and exclude nodes with no contact-info. @@ -133,7 +129,7 @@ impl ClusterNodes { return Vec::default(); } let mut rng = ChaChaRng::from_seed(shred_seed); - let index = match weighted_sample_single(&mut rng, &self.cumulative_weights) { + let index = match self.weighted_shuffle.first(&mut rng) { None => return Vec::default(), Some(index) => index, }; @@ -146,16 +142,16 @@ impl ClusterNodes { return vec![node.tvu]; } } - let nodes: Vec<_> = self - .nodes - .iter() - .filter(|node| node.pubkey() != self.pubkey) + let mut rng = ChaChaRng::from_seed(shred_seed); + let nodes: Vec<&Node> = self + .weighted_shuffle + .clone() + .shuffle(&mut rng) + .map(|index| &self.nodes[index]) .collect(); if nodes.is_empty() { return Vec::default(); } - let mut rng = ChaChaRng::from_seed(shred_seed); - let nodes = shuffle_nodes(&mut rng, &nodes); let (neighbors, children) = compute_retransmit_peers(fanout, 0, &nodes); neighbors[..1] .iter() @@ -235,18 +231,22 @@ impl ClusterNodes { if !enable_turbine_peers_shuffle_patch(shred.slot(), root_bank) { return self.get_retransmit_peers_compat(shred_seed, fanout, slot_leader); } + let mut weighted_shuffle = self.weighted_shuffle.clone(); // Exclude slot leader from list of nodes. - let nodes: Vec<_> = if slot_leader == self.pubkey { + if slot_leader == self.pubkey { error!("retransmit from slot leader: {}", slot_leader); - self.nodes.iter().collect() - } else { - self.nodes - .iter() - .filter(|node| node.pubkey() != slot_leader) - .collect() + } else if let Some(index) = self + .nodes + .iter() + .position(|node| node.pubkey() == slot_leader) + { + weighted_shuffle.remove_index(index); }; let mut rng = ChaChaRng::from_seed(shred_seed); - let nodes = shuffle_nodes(&mut rng, &nodes); + let nodes: Vec<_> = weighted_shuffle + .shuffle(&mut rng) + .map(|index| &self.nodes[index]) + .collect(); let self_index = nodes .iter() .position(|node| node.pubkey() == self.pubkey) @@ -299,30 +299,6 @@ impl ClusterNodes { } } -fn build_cumulative_weights(self_pubkey: Pubkey, nodes: &[Node]) -> Vec { - let cumulative_stakes: Vec<_> = nodes - .iter() - .scan(0, |acc, node| { - if node.pubkey() != self_pubkey { - *acc += node.stake; - } - Some(*acc) - }) - .collect(); - if cumulative_stakes.last() != Some(&0) { - return cumulative_stakes; - } - nodes - .iter() - .scan(0, |acc, node| { - if node.pubkey() != self_pubkey { - *acc += 1; - } - Some(*acc) - }) - .collect() -} - fn new_cluster_nodes( cluster_info: &ClusterInfo, stakes: &HashMap, @@ -330,11 +306,12 @@ fn new_cluster_nodes( let self_pubkey = cluster_info.id(); let nodes = get_nodes(cluster_info, stakes); let broadcast = TypeId::of::() == TypeId::of::(); - let cumulative_weights = if broadcast { - build_cumulative_weights(self_pubkey, &nodes) - } else { - Vec::default() - }; + let stakes: Vec = nodes.iter().map(|node| node.stake).collect(); + let mut weighted_shuffle = WeightedShuffle::new(&stakes).unwrap(); + if broadcast { + let index = nodes.iter().position(|node| node.pubkey() == self_pubkey); + weighted_shuffle.remove_index(index.unwrap()); + } // For backward compatibility: // * nodes which do not have contact-info are excluded. // * stakes are floored at 1. @@ -352,7 +329,7 @@ fn new_cluster_nodes( ClusterNodes { pubkey: self_pubkey, nodes, - cumulative_weights, + weighted_shuffle, index, _phantom: PhantomData::default(), } @@ -406,18 +383,6 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo } } -// Shuffles nodes w.r.t their stakes. -// Unstaked nodes will always appear at the very end. -fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> { - // Nodes are sorted by (stake, pubkey) in descending order. - let stakes: Vec = nodes.iter().map(|node| node.stake).collect(); - WeightedShuffle::new(&stakes) - .unwrap() - .shuffle(rng) - .map(|i| nodes[i]) - .collect() -} - impl ClusterNodesCache { pub fn new( // Capacity of underlying LRU-cache in terms of number of epochs. @@ -494,18 +459,6 @@ impl From for NodeId { } } -impl Default for ClusterNodes { - fn default() -> Self { - Self { - pubkey: Pubkey::default(), - nodes: Vec::default(), - cumulative_weights: Vec::default(), - index: Vec::default(), - _phantom: PhantomData::default(), - } - } -} - #[cfg(test)] mod tests { use { diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index f1d73e4168..52da902fbc 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -28,6 +28,7 @@ pub enum WeightedShuffleError { /// weight. /// - Zero weighted indices are shuffled and appear only at the end, after /// non-zero weighted indices. +#[derive(Clone)] pub struct WeightedShuffle { arr: Vec, // Underlying array implementing binary indexed tree. sum: T, // Current sum of weights, excluding already selected indices. @@ -118,6 +119,36 @@ where debug_assert!(hi.1 > val); (hi.0, hi.1 - lo.1) } + + pub fn remove_index(&mut self, index: usize) { + let zero = ::default(); + let weight = self.cumsum(index + 1) - self.cumsum(index); + if weight != zero { + self.remove(index + 1, weight); + } else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) { + self.zeros.remove(index); + } + } +} + +impl WeightedShuffle +where + T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, +{ + // Equivalent to weighted_shuffle.shuffle(&mut rng).next() + pub fn first(&self, rng: &mut R) -> Option { + let zero = ::default(); + if self.sum > zero { + let sample = ::Sampler::sample_single(zero, self.sum, rng); + let (index, _weight) = WeightedShuffle::search(self, sample); + return Some(index - 1); + } + if self.zeros.is_empty() { + return None; + } + let index = ::Sampler::sample_single(0usize, self.zeros.len(), rng); + self.zeros.get(index).copied() + } } impl<'a, T: 'a> WeightedShuffle @@ -143,39 +174,6 @@ where } } -// Equivalent to WeightedShuffle(rng, weights).unwrap().next(). -pub fn weighted_sample_single(rng: &mut R, cumulative_weights: &[T]) -> Option -where - T: Copy + Default + PartialOrd + SampleUniform, -{ - let zero = ::default(); - let high = cumulative_weights.last().copied().unwrap_or_default(); - if high == zero { - if cumulative_weights.is_empty() { - return None; - } - let index = - ::Sampler::sample_single(0usize, cumulative_weights.len(), rng); - return Some(index); - } - let sample = ::Sampler::sample_single(zero, high, rng); - let mut lo = 0usize; - let mut hi = cumulative_weights.len() - 1; - while lo + 1 < hi { - let k = lo + (hi - lo) / 2; - if cumulative_weights[k] <= sample { - lo = k; - } else { - hi = k; - } - } - if cumulative_weights[lo] > sample { - Some(lo) - } else { - Some(hi) - } -} - /// Returns a list of indexes shuffled based on the input weights /// Note - The sum of all weights must not exceed `u64::MAX` pub fn weighted_shuffle(weights: F, seed: [u8; 32]) -> Vec @@ -346,8 +344,8 @@ mod tests { let weights = Vec::::new(); let mut rng = rand::thread_rng(); let shuffle = WeightedShuffle::new(&weights).unwrap(); - assert!(shuffle.shuffle(&mut rng).next().is_none()); - assert_eq!(weighted_sample_single(&mut rng, &weights), None); + assert!(shuffle.clone().shuffle(&mut rng).next().is_none()); + assert!(shuffle.first(&mut rng).is_none()); } // Asserts that zero weights will be shuffled. @@ -357,9 +355,12 @@ mod tests { let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); let shuffle = WeightedShuffle::new(&weights).unwrap(); - let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); - assert_eq!(shuffle, [1, 4, 2, 3, 0]); - assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1)); + assert_eq!( + shuffle.clone().shuffle(&mut rng).collect::>(), + [1, 4, 2, 3, 0] + ); + let mut rng = ChaChaRng::from_seed(seed); + assert_eq!(shuffle.first(&mut rng), Some(1)); } // Asserts that each index is selected proportional to its weight. @@ -376,6 +377,17 @@ mod tests { let _ = shuffle.count(); // consume the rest. } assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]); + let mut counts = [0; 8]; + for _ in 0..100000 { + let mut shuffle = WeightedShuffle::new(&weights).unwrap(); + shuffle.remove_index(5); + shuffle.remove_index(3); + shuffle.remove_index(1); + let mut shuffle = shuffle.shuffle(&mut rng); + counts[shuffle.next().unwrap()] += 1; + let _ = shuffle.count(); // consume the rest. + } + assert_eq!(counts, [97, 0, 90862, 0, 0, 0, 9041, 0]); } #[test] @@ -383,52 +395,54 @@ mod tests { let weights = [ 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72, ]; - let cumulative_weights: Vec<_> = weights - .iter() - .scan(0, |acc, w| { - *acc += w; - Some(*acc) - }) - .collect(); let seed = [48u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let shuffle = WeightedShuffle::new(&weights).unwrap(); - let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); + let mut shuffle = WeightedShuffle::new(&weights).unwrap(); assert_eq!( - shuffle, + shuffle.clone().shuffle(&mut rng).collect::>(), [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] ); let mut rng = ChaChaRng::from_seed(seed); + assert_eq!(shuffle.first(&mut rng), Some(2)); + let mut rng = ChaChaRng::from_seed(seed); + shuffle.remove_index(11); + shuffle.remove_index(3); + shuffle.remove_index(15); + shuffle.remove_index(0); assert_eq!( - weighted_sample_single(&mut rng, &cumulative_weights), - Some(2), + shuffle.clone().shuffle(&mut rng).collect::>(), + [4, 6, 1, 12, 19, 14, 17, 20, 2, 9, 10, 8, 7, 18, 13, 5, 16] ); + let mut rng = ChaChaRng::from_seed(seed); + assert_eq!(shuffle.first(&mut rng), Some(4)); let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let shuffle = WeightedShuffle::new(&weights).unwrap(); - let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); + let mut shuffle = WeightedShuffle::new(&weights).unwrap(); assert_eq!( - shuffle, + shuffle.clone().shuffle(&mut rng).collect::>(), [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] ); let mut rng = ChaChaRng::from_seed(seed); + assert_eq!(shuffle.first(&mut rng), Some(19)); + shuffle.remove_index(16); + shuffle.remove_index(8); + shuffle.remove_index(20); + shuffle.remove_index(5); + shuffle.remove_index(19); + shuffle.remove_index(4); + let mut rng = ChaChaRng::from_seed(seed); assert_eq!( - weighted_sample_single(&mut rng, &cumulative_weights), - Some(19), + shuffle.clone().shuffle(&mut rng).collect::>(), + [17, 2, 9, 14, 6, 10, 12, 1, 15, 13, 7, 0, 18, 3, 11] ); + let mut rng = ChaChaRng::from_seed(seed); + assert_eq!(shuffle.first(&mut rng), Some(17)); } #[test] fn test_weighted_shuffle_match_slow() { let mut rng = rand::thread_rng(); let weights: Vec = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect(); - let cumulative_weights: Vec<_> = weights - .iter() - .scan(0, |acc, w| { - *acc += w; - Some(*acc) - }) - .collect(); for _ in 0..10 { let mut seed = [0u8; 32]; rng.fill(&mut seed[..]); @@ -439,10 +453,8 @@ mod tests { let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); assert_eq!(shuffle, shuffle_slow); let mut rng = ChaChaRng::from_seed(seed); - assert_eq!( - weighted_sample_single(&mut rng, &cumulative_weights), - Some(shuffle[0]), - ); + let shuffle = WeightedShuffle::new(&weights).unwrap(); + assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0])); } } }