caches WeightedShuffle struct in ClusterNodes (#22877)
Instead of reconstructing WeightedShuffle struct for each shred broadcast or retransmit, we can use the same struct with minimal mutations.
This commit is contained in:
@ -2,15 +2,13 @@ use {
|
|||||||
crate::{broadcast_stage::BroadcastStage, retransmit_stage::RetransmitStage},
|
crate::{broadcast_stage::BroadcastStage, retransmit_stage::RetransmitStage},
|
||||||
itertools::Itertools,
|
itertools::Itertools,
|
||||||
lru::LruCache,
|
lru::LruCache,
|
||||||
rand::{Rng, SeedableRng},
|
rand::SeedableRng,
|
||||||
rand_chacha::ChaChaRng,
|
rand_chacha::ChaChaRng,
|
||||||
solana_gossip::{
|
solana_gossip::{
|
||||||
cluster_info::{compute_retransmit_peers, ClusterInfo},
|
cluster_info::{compute_retransmit_peers, ClusterInfo},
|
||||||
contact_info::ContactInfo,
|
contact_info::ContactInfo,
|
||||||
crds_gossip_pull::CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS,
|
crds_gossip_pull::CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS,
|
||||||
weighted_shuffle::{
|
weighted_shuffle::{weighted_best, weighted_shuffle, WeightedShuffle},
|
||||||
weighted_best, weighted_sample_single, weighted_shuffle, WeightedShuffle,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
solana_ledger::shred::Shred,
|
solana_ledger::shred::Shred,
|
||||||
solana_runtime::bank::Bank,
|
solana_runtime::bank::Bank,
|
||||||
@ -51,9 +49,7 @@ pub struct ClusterNodes<T> {
|
|||||||
// All staked nodes + other known tvu-peers + the node itself;
|
// All staked nodes + other known tvu-peers + the node itself;
|
||||||
// sorted by (stake, pubkey) in descending order.
|
// sorted by (stake, pubkey) in descending order.
|
||||||
nodes: Vec<Node>,
|
nodes: Vec<Node>,
|
||||||
// Cumulative stakes (excluding the node itself), used for sampling
|
weighted_shuffle: WeightedShuffle</*stake:*/ u64>,
|
||||||
// broadcast peers.
|
|
||||||
cumulative_weights: Vec<u64>,
|
|
||||||
// Weights and indices for sampling peers. weighted_{shuffle,best} expect
|
// Weights and indices for sampling peers. weighted_{shuffle,best} expect
|
||||||
// weights >= 1. For backward compatibility we use max(1, stake) for
|
// weights >= 1. For backward compatibility we use max(1, stake) for
|
||||||
// weights and exclude nodes with no contact-info.
|
// weights and exclude nodes with no contact-info.
|
||||||
@ -133,7 +129,7 @@ impl ClusterNodes<BroadcastStage> {
|
|||||||
return Vec::default();
|
return Vec::default();
|
||||||
}
|
}
|
||||||
let mut rng = ChaChaRng::from_seed(shred_seed);
|
let mut rng = ChaChaRng::from_seed(shred_seed);
|
||||||
let index = match weighted_sample_single(&mut rng, &self.cumulative_weights) {
|
let index = match self.weighted_shuffle.first(&mut rng) {
|
||||||
None => return Vec::default(),
|
None => return Vec::default(),
|
||||||
Some(index) => index,
|
Some(index) => index,
|
||||||
};
|
};
|
||||||
@ -146,16 +142,16 @@ impl ClusterNodes<BroadcastStage> {
|
|||||||
return vec![node.tvu];
|
return vec![node.tvu];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let nodes: Vec<_> = self
|
let mut rng = ChaChaRng::from_seed(shred_seed);
|
||||||
.nodes
|
let nodes: Vec<&Node> = self
|
||||||
.iter()
|
.weighted_shuffle
|
||||||
.filter(|node| node.pubkey() != self.pubkey)
|
.clone()
|
||||||
|
.shuffle(&mut rng)
|
||||||
|
.map(|index| &self.nodes[index])
|
||||||
.collect();
|
.collect();
|
||||||
if nodes.is_empty() {
|
if nodes.is_empty() {
|
||||||
return Vec::default();
|
return Vec::default();
|
||||||
}
|
}
|
||||||
let mut rng = ChaChaRng::from_seed(shred_seed);
|
|
||||||
let nodes = shuffle_nodes(&mut rng, &nodes);
|
|
||||||
let (neighbors, children) = compute_retransmit_peers(fanout, 0, &nodes);
|
let (neighbors, children) = compute_retransmit_peers(fanout, 0, &nodes);
|
||||||
neighbors[..1]
|
neighbors[..1]
|
||||||
.iter()
|
.iter()
|
||||||
@ -235,18 +231,22 @@ impl ClusterNodes<RetransmitStage> {
|
|||||||
if !enable_turbine_peers_shuffle_patch(shred.slot(), root_bank) {
|
if !enable_turbine_peers_shuffle_patch(shred.slot(), root_bank) {
|
||||||
return self.get_retransmit_peers_compat(shred_seed, fanout, slot_leader);
|
return self.get_retransmit_peers_compat(shred_seed, fanout, slot_leader);
|
||||||
}
|
}
|
||||||
|
let mut weighted_shuffle = self.weighted_shuffle.clone();
|
||||||
// Exclude slot leader from list of nodes.
|
// Exclude slot leader from list of nodes.
|
||||||
let nodes: Vec<_> = if slot_leader == self.pubkey {
|
if slot_leader == self.pubkey {
|
||||||
error!("retransmit from slot leader: {}", slot_leader);
|
error!("retransmit from slot leader: {}", slot_leader);
|
||||||
self.nodes.iter().collect()
|
} else if let Some(index) = self
|
||||||
} else {
|
.nodes
|
||||||
self.nodes
|
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|node| node.pubkey() != slot_leader)
|
.position(|node| node.pubkey() == slot_leader)
|
||||||
.collect()
|
{
|
||||||
|
weighted_shuffle.remove_index(index);
|
||||||
};
|
};
|
||||||
let mut rng = ChaChaRng::from_seed(shred_seed);
|
let mut rng = ChaChaRng::from_seed(shred_seed);
|
||||||
let nodes = shuffle_nodes(&mut rng, &nodes);
|
let nodes: Vec<_> = weighted_shuffle
|
||||||
|
.shuffle(&mut rng)
|
||||||
|
.map(|index| &self.nodes[index])
|
||||||
|
.collect();
|
||||||
let self_index = nodes
|
let self_index = nodes
|
||||||
.iter()
|
.iter()
|
||||||
.position(|node| node.pubkey() == self.pubkey)
|
.position(|node| node.pubkey() == self.pubkey)
|
||||||
@ -299,30 +299,6 @@ impl ClusterNodes<RetransmitStage> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_cumulative_weights(self_pubkey: Pubkey, nodes: &[Node]) -> Vec<u64> {
|
|
||||||
let cumulative_stakes: Vec<_> = nodes
|
|
||||||
.iter()
|
|
||||||
.scan(0, |acc, node| {
|
|
||||||
if node.pubkey() != self_pubkey {
|
|
||||||
*acc += node.stake;
|
|
||||||
}
|
|
||||||
Some(*acc)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
if cumulative_stakes.last() != Some(&0) {
|
|
||||||
return cumulative_stakes;
|
|
||||||
}
|
|
||||||
nodes
|
|
||||||
.iter()
|
|
||||||
.scan(0, |acc, node| {
|
|
||||||
if node.pubkey() != self_pubkey {
|
|
||||||
*acc += 1;
|
|
||||||
}
|
|
||||||
Some(*acc)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_cluster_nodes<T: 'static>(
|
fn new_cluster_nodes<T: 'static>(
|
||||||
cluster_info: &ClusterInfo,
|
cluster_info: &ClusterInfo,
|
||||||
stakes: &HashMap<Pubkey, u64>,
|
stakes: &HashMap<Pubkey, u64>,
|
||||||
@ -330,11 +306,12 @@ fn new_cluster_nodes<T: 'static>(
|
|||||||
let self_pubkey = cluster_info.id();
|
let self_pubkey = cluster_info.id();
|
||||||
let nodes = get_nodes(cluster_info, stakes);
|
let nodes = get_nodes(cluster_info, stakes);
|
||||||
let broadcast = TypeId::of::<T>() == TypeId::of::<BroadcastStage>();
|
let broadcast = TypeId::of::<T>() == TypeId::of::<BroadcastStage>();
|
||||||
let cumulative_weights = if broadcast {
|
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
|
||||||
build_cumulative_weights(self_pubkey, &nodes)
|
let mut weighted_shuffle = WeightedShuffle::new(&stakes).unwrap();
|
||||||
} else {
|
if broadcast {
|
||||||
Vec::default()
|
let index = nodes.iter().position(|node| node.pubkey() == self_pubkey);
|
||||||
};
|
weighted_shuffle.remove_index(index.unwrap());
|
||||||
|
}
|
||||||
// For backward compatibility:
|
// For backward compatibility:
|
||||||
// * nodes which do not have contact-info are excluded.
|
// * nodes which do not have contact-info are excluded.
|
||||||
// * stakes are floored at 1.
|
// * stakes are floored at 1.
|
||||||
@ -352,7 +329,7 @@ fn new_cluster_nodes<T: 'static>(
|
|||||||
ClusterNodes {
|
ClusterNodes {
|
||||||
pubkey: self_pubkey,
|
pubkey: self_pubkey,
|
||||||
nodes,
|
nodes,
|
||||||
cumulative_weights,
|
weighted_shuffle,
|
||||||
index,
|
index,
|
||||||
_phantom: PhantomData::default(),
|
_phantom: PhantomData::default(),
|
||||||
}
|
}
|
||||||
@ -406,18 +383,6 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shuffles nodes w.r.t their stakes.
|
|
||||||
// Unstaked nodes will always appear at the very end.
|
|
||||||
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
|
|
||||||
// Nodes are sorted by (stake, pubkey) in descending order.
|
|
||||||
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
|
|
||||||
WeightedShuffle::new(&stakes)
|
|
||||||
.unwrap()
|
|
||||||
.shuffle(rng)
|
|
||||||
.map(|i| nodes[i])
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> ClusterNodesCache<T> {
|
impl<T> ClusterNodesCache<T> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
// Capacity of underlying LRU-cache in terms of number of epochs.
|
// Capacity of underlying LRU-cache in terms of number of epochs.
|
||||||
@ -494,18 +459,6 @@ impl From<Pubkey> for NodeId {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Default for ClusterNodes<T> {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
pubkey: Pubkey::default(),
|
|
||||||
nodes: Vec::default(),
|
|
||||||
cumulative_weights: Vec::default(),
|
|
||||||
index: Vec::default(),
|
|
||||||
_phantom: PhantomData::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use {
|
use {
|
||||||
|
@ -28,6 +28,7 @@ pub enum WeightedShuffleError<T> {
|
|||||||
/// weight.
|
/// weight.
|
||||||
/// - Zero weighted indices are shuffled and appear only at the end, after
|
/// - Zero weighted indices are shuffled and appear only at the end, after
|
||||||
/// non-zero weighted indices.
|
/// non-zero weighted indices.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct WeightedShuffle<T> {
|
pub struct WeightedShuffle<T> {
|
||||||
arr: Vec<T>, // Underlying array implementing binary indexed tree.
|
arr: Vec<T>, // Underlying array implementing binary indexed tree.
|
||||||
sum: T, // Current sum of weights, excluding already selected indices.
|
sum: T, // Current sum of weights, excluding already selected indices.
|
||||||
@ -118,6 +119,36 @@ where
|
|||||||
debug_assert!(hi.1 > val);
|
debug_assert!(hi.1 > val);
|
||||||
(hi.0, hi.1 - lo.1)
|
(hi.0, hi.1 - lo.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn remove_index(&mut self, index: usize) {
|
||||||
|
let zero = <T as Default>::default();
|
||||||
|
let weight = self.cumsum(index + 1) - self.cumsum(index);
|
||||||
|
if weight != zero {
|
||||||
|
self.remove(index + 1, weight);
|
||||||
|
} else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) {
|
||||||
|
self.zeros.remove(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> WeightedShuffle<T>
|
||||||
|
where
|
||||||
|
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
|
||||||
|
{
|
||||||
|
// Equivalent to weighted_shuffle.shuffle(&mut rng).next()
|
||||||
|
pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
|
||||||
|
let zero = <T as Default>::default();
|
||||||
|
if self.sum > zero {
|
||||||
|
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
|
||||||
|
let (index, _weight) = WeightedShuffle::search(self, sample);
|
||||||
|
return Some(index - 1);
|
||||||
|
}
|
||||||
|
if self.zeros.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let index = <usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
|
||||||
|
self.zeros.get(index).copied()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: 'a> WeightedShuffle<T>
|
impl<'a, T: 'a> WeightedShuffle<T>
|
||||||
@ -143,39 +174,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Equivalent to WeightedShuffle(rng, weights).unwrap().next().
|
|
||||||
pub fn weighted_sample_single<R: Rng, T>(rng: &mut R, cumulative_weights: &[T]) -> Option<usize>
|
|
||||||
where
|
|
||||||
T: Copy + Default + PartialOrd + SampleUniform,
|
|
||||||
{
|
|
||||||
let zero = <T as Default>::default();
|
|
||||||
let high = cumulative_weights.last().copied().unwrap_or_default();
|
|
||||||
if high == zero {
|
|
||||||
if cumulative_weights.is_empty() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let index =
|
|
||||||
<usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng);
|
|
||||||
return Some(index);
|
|
||||||
}
|
|
||||||
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
|
|
||||||
let mut lo = 0usize;
|
|
||||||
let mut hi = cumulative_weights.len() - 1;
|
|
||||||
while lo + 1 < hi {
|
|
||||||
let k = lo + (hi - lo) / 2;
|
|
||||||
if cumulative_weights[k] <= sample {
|
|
||||||
lo = k;
|
|
||||||
} else {
|
|
||||||
hi = k;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if cumulative_weights[lo] > sample {
|
|
||||||
Some(lo)
|
|
||||||
} else {
|
|
||||||
Some(hi)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a list of indexes shuffled based on the input weights
|
/// Returns a list of indexes shuffled based on the input weights
|
||||||
/// Note - The sum of all weights must not exceed `u64::MAX`
|
/// Note - The sum of all weights must not exceed `u64::MAX`
|
||||||
pub fn weighted_shuffle<T, B, F>(weights: F, seed: [u8; 32]) -> Vec<usize>
|
pub fn weighted_shuffle<T, B, F>(weights: F, seed: [u8; 32]) -> Vec<usize>
|
||||||
@ -346,8 +344,8 @@ mod tests {
|
|||||||
let weights = Vec::<u64>::new();
|
let weights = Vec::<u64>::new();
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||||
assert!(shuffle.shuffle(&mut rng).next().is_none());
|
assert!(shuffle.clone().shuffle(&mut rng).next().is_none());
|
||||||
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
|
assert!(shuffle.first(&mut rng).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Asserts that zero weights will be shuffled.
|
// Asserts that zero weights will be shuffled.
|
||||||
@ -357,9 +355,12 @@ mod tests {
|
|||||||
let seed = [37u8; 32];
|
let seed = [37u8; 32];
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||||
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
assert_eq!(
|
||||||
assert_eq!(shuffle, [1, 4, 2, 3, 0]);
|
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
||||||
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
|
[1, 4, 2, 3, 0]
|
||||||
|
);
|
||||||
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
|
assert_eq!(shuffle.first(&mut rng), Some(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Asserts that each index is selected proportional to its weight.
|
// Asserts that each index is selected proportional to its weight.
|
||||||
@ -376,6 +377,17 @@ mod tests {
|
|||||||
let _ = shuffle.count(); // consume the rest.
|
let _ = shuffle.count(); // consume the rest.
|
||||||
}
|
}
|
||||||
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
|
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
|
||||||
|
let mut counts = [0; 8];
|
||||||
|
for _ in 0..100000 {
|
||||||
|
let mut shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||||
|
shuffle.remove_index(5);
|
||||||
|
shuffle.remove_index(3);
|
||||||
|
shuffle.remove_index(1);
|
||||||
|
let mut shuffle = shuffle.shuffle(&mut rng);
|
||||||
|
counts[shuffle.next().unwrap()] += 1;
|
||||||
|
let _ = shuffle.count(); // consume the rest.
|
||||||
|
}
|
||||||
|
assert_eq!(counts, [97, 0, 90862, 0, 0, 0, 9041, 0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -383,52 +395,54 @@ mod tests {
|
|||||||
let weights = [
|
let weights = [
|
||||||
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
|
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
|
||||||
];
|
];
|
||||||
let cumulative_weights: Vec<_> = weights
|
|
||||||
.iter()
|
|
||||||
.scan(0, |acc, w| {
|
|
||||||
*acc += w;
|
|
||||||
Some(*acc)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let seed = [48u8; 32];
|
let seed = [48u8; 32];
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
let mut shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||||
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
shuffle,
|
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
||||||
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
|
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
|
||||||
);
|
);
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
|
assert_eq!(shuffle.first(&mut rng), Some(2));
|
||||||
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
|
shuffle.remove_index(11);
|
||||||
|
shuffle.remove_index(3);
|
||||||
|
shuffle.remove_index(15);
|
||||||
|
shuffle.remove_index(0);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
weighted_sample_single(&mut rng, &cumulative_weights),
|
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
||||||
Some(2),
|
[4, 6, 1, 12, 19, 14, 17, 20, 2, 9, 10, 8, 7, 18, 13, 5, 16]
|
||||||
);
|
);
|
||||||
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
|
assert_eq!(shuffle.first(&mut rng), Some(4));
|
||||||
let seed = [37u8; 32];
|
let seed = [37u8; 32];
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
let mut shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||||
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
shuffle,
|
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
||||||
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
|
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
|
||||||
);
|
);
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
|
assert_eq!(shuffle.first(&mut rng), Some(19));
|
||||||
|
shuffle.remove_index(16);
|
||||||
|
shuffle.remove_index(8);
|
||||||
|
shuffle.remove_index(20);
|
||||||
|
shuffle.remove_index(5);
|
||||||
|
shuffle.remove_index(19);
|
||||||
|
shuffle.remove_index(4);
|
||||||
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
weighted_sample_single(&mut rng, &cumulative_weights),
|
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
||||||
Some(19),
|
[17, 2, 9, 14, 6, 10, 12, 1, 15, 13, 7, 0, 18, 3, 11]
|
||||||
);
|
);
|
||||||
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
|
assert_eq!(shuffle.first(&mut rng), Some(17));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_weighted_shuffle_match_slow() {
|
fn test_weighted_shuffle_match_slow() {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let weights: Vec<u64> = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect();
|
let weights: Vec<u64> = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect();
|
||||||
let cumulative_weights: Vec<_> = weights
|
|
||||||
.iter()
|
|
||||||
.scan(0, |acc, w| {
|
|
||||||
*acc += w;
|
|
||||||
Some(*acc)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
let mut seed = [0u8; 32];
|
let mut seed = [0u8; 32];
|
||||||
rng.fill(&mut seed[..]);
|
rng.fill(&mut seed[..]);
|
||||||
@ -439,10 +453,8 @@ mod tests {
|
|||||||
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
|
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
|
||||||
assert_eq!(shuffle, shuffle_slow);
|
assert_eq!(shuffle, shuffle_slow);
|
||||||
let mut rng = ChaChaRng::from_seed(seed);
|
let mut rng = ChaChaRng::from_seed(seed);
|
||||||
assert_eq!(
|
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||||
weighted_sample_single(&mut rng, &cumulative_weights),
|
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
|
||||||
Some(shuffle[0]),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user