Refactor Weighted Shuffle (#6614)

automerge
This commit is contained in:
Sagar Dhawan
2019-10-29 21:02:11 -07:00
committed by Grimes
parent 4ec95043d7
commit 801337a422
4 changed files with 20 additions and 24 deletions

View File

@ -26,9 +26,7 @@ use crate::weighted_shuffle::{weighted_best, weighted_shuffle};
use bincode::{deserialize, serialize, serialized_size}; use bincode::{deserialize, serialize, serialized_size};
use core::cmp; use core::cmp;
use itertools::Itertools; use itertools::Itertools;
use rand::SeedableRng;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use rand_chacha::ChaChaRng;
use solana_ledger::bank_forks::BankForks; use solana_ledger::bank_forks::BankForks;
use solana_ledger::blocktree::Blocktree; use solana_ledger::blocktree::Blocktree;
use solana_ledger::staking_utils; use solana_ledger::staking_utils;
@ -510,11 +508,11 @@ impl ClusterInfo {
fn stake_weighted_shuffle( fn stake_weighted_shuffle(
stakes_and_index: &[(u64, usize)], stakes_and_index: &[(u64, usize)],
rng: ChaChaRng, seed: [u8; 32],
) -> Vec<(u64, usize)> { ) -> Vec<(u64, usize)> {
let stake_weights = stakes_and_index.iter().map(|(w, _)| *w).collect(); let stake_weights = stakes_and_index.iter().map(|(w, _)| *w).collect();
let shuffle = weighted_shuffle(stake_weights, rng); let shuffle = weighted_shuffle(stake_weights, seed);
shuffle.iter().map(|x| stakes_and_index[*x]).collect() shuffle.iter().map(|x| stakes_and_index[*x]).collect()
} }
@ -536,9 +534,9 @@ impl ClusterInfo {
id: &Pubkey, id: &Pubkey,
peers: &[ContactInfo], peers: &[ContactInfo],
stakes_and_index: &[(u64, usize)], stakes_and_index: &[(u64, usize)],
rng: ChaChaRng, seed: [u8; 32],
) -> (usize, Vec<(u64, usize)>) { ) -> (usize, Vec<(u64, usize)>) {
let shuffled_stakes_and_index = ClusterInfo::stake_weighted_shuffle(stakes_and_index, rng); let shuffled_stakes_and_index = ClusterInfo::stake_weighted_shuffle(stakes_and_index, seed);
let mut self_index = 0; let mut self_index = 0;
shuffled_stakes_and_index shuffled_stakes_and_index
.iter() .iter()
@ -723,7 +721,7 @@ impl ClusterInfo {
.into_iter() .into_iter()
.zip(seeds) .zip(seeds)
.map(|(shred, seed)| { .map(|(shred, seed)| {
let broadcast_index = weighted_best(&peers_and_stakes, ChaChaRng::from_seed(*seed)); let broadcast_index = weighted_best(&peers_and_stakes, *seed);
(shred, &peers[broadcast_index].tvu) (shred, &peers[broadcast_index].tvu)
}) })

View File

@ -20,8 +20,7 @@ use indexmap::map::IndexMap;
use itertools::Itertools; use itertools::Itertools;
use rand; use rand;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::{thread_rng, RngCore, SeedableRng}; use rand::{thread_rng, RngCore};
use rand_chacha::ChaChaRng;
use solana_runtime::bloom::Bloom; use solana_runtime::bloom::Bloom;
use solana_sdk::hash::Hash; use solana_sdk::hash::Hash;
use solana_sdk::pubkey::Pubkey; use solana_sdk::pubkey::Pubkey;
@ -106,7 +105,7 @@ impl CrdsGossipPush {
seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes()); seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes());
let shuffle = weighted_shuffle( let shuffle = weighted_shuffle(
staked_peers.iter().map(|(_, stake)| *stake).collect_vec(), staked_peers.iter().map(|(_, stake)| *stake).collect_vec(),
ChaChaRng::from_seed(seed), seed,
); );
let mut keep = HashSet::new(); let mut keep = HashSet::new();
@ -244,7 +243,7 @@ impl CrdsGossipPush {
seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes()); seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes());
let mut shuffle = weighted_shuffle( let mut shuffle = weighted_shuffle(
options.iter().map(|weighted| weighted.0).collect_vec(), options.iter().map(|weighted| weighted.0).collect_vec(),
ChaChaRng::from_seed(seed), seed,
) )
.into_iter(); .into_iter();

View File

@ -10,8 +10,6 @@ use crate::{
window_service::{should_retransmit_and_persist, WindowService}, window_service::{should_retransmit_and_persist, WindowService},
}; };
use crossbeam_channel::Receiver as CrossbeamReceiver; use crossbeam_channel::Receiver as CrossbeamReceiver;
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use solana_ledger::{ use solana_ledger::{
bank_forks::BankForks, bank_forks::BankForks,
blocktree::{Blocktree, CompletedSlotsReceiver}, blocktree::{Blocktree, CompletedSlotsReceiver},
@ -92,7 +90,7 @@ fn retransmit(
&me.id, &me.id,
&peers, &peers,
&stakes_and_index, &stakes_and_index,
ChaChaRng::from_seed(packet.meta.seed), packet.meta.seed,
); );
peers_len = cmp::max(peers_len, shuffled_stakes_and_index.len()); peers_len = cmp::max(peers_len, shuffled_stakes_and_index.len());
shuffled_stakes_and_index.remove(my_index); shuffled_stakes_and_index.remove(my_index);

View File

@ -2,18 +2,19 @@
use itertools::Itertools; use itertools::Itertools;
use num_traits::{FromPrimitive, ToPrimitive}; use num_traits::{FromPrimitive, ToPrimitive};
use rand::Rng; use rand::{Rng, SeedableRng};
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;
use std::iter; use std::iter;
use std::ops::Div; use std::ops::Div;
/// Returns a list of indexes shuffled based on the input weights /// Returns a list of indexes shuffled based on the input weights
/// Note - The sum of all weights must not exceed `u64::MAX` /// Note - The sum of all weights must not exceed `u64::MAX`
pub fn weighted_shuffle<T>(weights: Vec<T>, mut rng: ChaChaRng) -> Vec<usize> pub fn weighted_shuffle<T>(weights: Vec<T>, seed: [u8; 32]) -> Vec<usize>
where where
T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive, T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive,
{ {
let total_weight: T = weights.clone().into_iter().sum(); let total_weight: T = weights.clone().into_iter().sum();
let mut rng = ChaChaRng::from_seed(seed);
weights weights
.into_iter() .into_iter()
.enumerate() .enumerate()
@ -36,10 +37,11 @@ where
/// Returns the highest index after computing a weighted shuffle. /// Returns the highest index after computing a weighted shuffle.
/// Saves doing any sorting for O(n) max calculation. /// Saves doing any sorting for O(n) max calculation.
pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) -> usize { pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize {
if weights_and_indexes.is_empty() { if weights_and_indexes.is_empty() {
return 0; return 0;
} }
let mut rng = ChaChaRng::from_seed(seed);
let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum(); let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum();
let mut lowest_weight = std::u128::MAX; let mut lowest_weight = std::u128::MAX;
let mut best_index = 0; let mut best_index = 0;
@ -63,13 +65,12 @@ pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) -
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use rand::SeedableRng;
#[test] #[test]
fn test_weighted_shuffle_iterator() { fn test_weighted_shuffle_iterator() {
let mut test_set = [0; 6]; let mut test_set = [0; 6];
let mut count = 0; let mut count = 0;
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32])); let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]);
shuffle.into_iter().for_each(|x| { shuffle.into_iter().for_each(|x| {
assert_eq!(test_set[x], 0); assert_eq!(test_set[x], 0);
test_set[x] = 1; test_set[x] = 1;
@ -84,7 +85,7 @@ mod tests {
let mut test_weights = vec![0; 100]; let mut test_weights = vec![0; 100];
(0..100).for_each(|i| test_weights[i] = (i + 1) as u64); (0..100).for_each(|i| test_weights[i] = (i + 1) as u64);
let mut count = 0; let mut count = 0;
let shuffle = weighted_shuffle(test_weights, ChaChaRng::from_seed([0xa5; 32])); let shuffle = weighted_shuffle(test_weights, [0xa5; 32]);
shuffle.into_iter().for_each(|x| { shuffle.into_iter().for_each(|x| {
assert_eq!(test_set[x], 0); assert_eq!(test_set[x], 0);
test_set[x] = 1; test_set[x] = 1;
@ -95,9 +96,9 @@ mod tests {
#[test] #[test]
fn test_weighted_shuffle_compare() { fn test_weighted_shuffle_compare() {
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32])); let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]);
let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32])); let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]);
shuffle1 shuffle1
.into_iter() .into_iter()
.zip(shuffle.into_iter()) .zip(shuffle.into_iter())
@ -110,7 +111,7 @@ mod tests {
fn test_weighted_shuffle_imbalanced() { fn test_weighted_shuffle_imbalanced() {
let mut weights = vec![std::u32::MAX as u64; 3]; let mut weights = vec![std::u32::MAX as u64; 3];
weights.push(1); weights.push(1);
let shuffle = weighted_shuffle(weights.clone(), ChaChaRng::from_seed([0x5a; 32])); let shuffle = weighted_shuffle(weights.clone(), [0x5a; 32]);
shuffle.into_iter().for_each(|x| { shuffle.into_iter().for_each(|x| {
if x == weights.len() - 1 { if x == weights.len() - 1 {
assert_eq!(weights[x], 1); assert_eq!(weights[x], 1);
@ -127,7 +128,7 @@ mod tests {
.enumerate() .enumerate()
.map(|(i, weight)| (weight, i)) .map(|(i, weight)| (weight, i))
.collect(); .collect();
let best_index = weighted_best(&weights_and_indexes, ChaChaRng::from_seed([0x5b; 32])); let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]);
assert_eq!(best_index, 2); assert_eq!(best_index, 2);
} }
} }