diff --git a/core/src/unprocessed_packet_batches.rs b/core/src/unprocessed_packet_batches.rs index 2a94475408..c891b1185f 100644 --- a/core/src/unprocessed_packet_batches.rs +++ b/core/src/unprocessed_packet_batches.rs @@ -37,6 +37,17 @@ pub struct DeserializedPacketBatch { pub unprocessed_packets: HashMap, } +/// References to a packet in `UnprocessedPacketBatches`, where +/// - batch_index references to `DeserializedPacketBatch`, +/// - packet_index references to `packet` within `DeserializedPacketBatch.packet_batch` +#[derive(Debug, Default)] +pub struct PacketLocator { + #[allow(dead_code)] + batch_index: usize, + #[allow(dead_code)] + packet_index: usize, +} + /// Currently each banking_stage thread has a `UnprocessedPacketBatches` buffer to store /// PacketBatch's received from sigverify. Banking thread continuously scans the buffer /// to pick proper packets to add to the block. @@ -79,6 +90,50 @@ impl UnprocessedPacketBatches { pub fn with_capacity(capacity: usize) -> Self { UnprocessedPacketBatches(VecDeque::with_capacity(capacity)) } + + /// Returns total number of all packets (including unprocessed and processed) in buffer + #[allow(dead_code)] + fn get_packets_count(&self) -> usize { + self.iter() + .map(|deserialized_packet_batch| deserialized_packet_batch.packet_batch.packets.len()) + .sum() + } + + /// Returns total number of unprocessed packets in buffer + #[allow(dead_code)] + fn get_unprocessed_packets_count(&self) -> usize { + self.iter() + .map(|deserialized_packet_batch| deserialized_packet_batch.unprocessed_packets.len()) + .sum() + } + + /// Iterates packets in buffered packet_batches, returns all unprocessed packet's stake, + /// and its locator + #[allow(dead_code)] + fn get_stakes_and_locators(&self) -> (Vec, Vec) { + let num_unprocessed_packets = self.get_unprocessed_packets_count(); + let mut stakes = Vec::::with_capacity(num_unprocessed_packets); + let mut locators = Vec::::with_capacity(num_unprocessed_packets); + + self.iter() + .enumerate() + .for_each(|(batch_index, deserialized_packet_batch)| { + let packet_batch = &deserialized_packet_batch.packet_batch; + deserialized_packet_batch + .unprocessed_packets + .keys() + .for_each(|packet_index| { + let p = &packet_batch.packets[*packet_index]; + stakes.push(p.meta.sender_stake); + locators.push(PacketLocator { + batch_index, + packet_index: *packet_index, + }); + }); + }); + + (stakes, locators) + } } impl DeserializedPacketBatch { @@ -135,8 +190,8 @@ impl DeserializedPacketBatch { Some(&packet.data[msg_start..msg_end]) } - // Returns whether the given `PacketBatch` has any more remaining unprocessed - // transactions + /// Returns whether the given `PacketBatch` has any more remaining unprocessed + /// transactions pub fn update_buffered_packets_with_new_unprocessed( &mut self, _original_unprocessed_indexes: &[usize], @@ -159,8 +214,24 @@ mod tests { use { super::*, solana_sdk::{signature::Keypair, system_transaction}, + std::net::IpAddr, }; + fn packet_with_sender_stake(sender_stake: u64, ip: Option) -> Packet { + let tx = system_transaction::transfer( + &Keypair::new(), + &solana_sdk::pubkey::new_rand(), + 1, + Hash::new_unique(), + ); + let mut packet = Packet::from_data(None, &tx).unwrap(); + packet.meta.sender_stake = sender_stake; + if let Some(ip) = ip { + packet.meta.addr = ip; + } + packet + } + #[test] fn test_packet_message() { let keypair = Keypair::new(); @@ -175,4 +246,92 @@ mod tests { transaction.message_data() ); } + + #[test] + fn test_get_packets_count() { + // create a buffer with 3 batches, each has 2 packets but only first one is valid + let batch_size = 2usize; + let batch_count = 3usize; + let unprocessed_packet_batches: UnprocessedPacketBatches = (0..batch_count) + .map(|_batch_index| { + DeserializedPacketBatch::new( + PacketBatch::new( + (0..batch_size) + .map(|packet_index| packet_with_sender_stake(packet_index as u64, None)) + .collect(), + ), + vec![0], + false, + ) + }) + .collect(); + + // Assert total packets count, and unprocessed packets count + assert_eq!( + batch_size * batch_count, + unprocessed_packet_batches.get_packets_count() + ); + assert_eq!( + batch_count, + unprocessed_packet_batches.get_unprocessed_packets_count() + ); + } + + #[test] + fn test_get_stakes_and_locators_from_empty_buffer() { + let unprocessed_packet_batches = UnprocessedPacketBatches::default(); + let (stakes, locators) = unprocessed_packet_batches.get_stakes_and_locators(); + + assert!(stakes.is_empty()); + assert!(locators.is_empty()); + } + + #[test] + fn test_get_stakes_and_locators() { + solana_logger::setup(); + + // setup senders' address and stake + let senders: Vec<(IpAddr, u64)> = vec![ + (IpAddr::from([127, 0, 0, 1]), 1), + (IpAddr::from([127, 0, 0, 2]), 2), + (IpAddr::from([127, 0, 0, 3]), 3), + ]; + // create a buffer with 3 batches, each has 2 packet from above sender. + // buffer looks like: + // [127.0.0.1, 127.0.0.2] + // [127.0.0.3, 127.0.0.1] + // [127.0.0.2, 127.0.0.3] + let batch_size = 2usize; + let batch_count = 3usize; + let unprocessed_packet_batches: UnprocessedPacketBatches = (0..batch_count) + .map(|batch_index| { + DeserializedPacketBatch::new( + PacketBatch::new( + (0..batch_size) + .map(|packet_index| { + let n = (batch_index * batch_size + packet_index) % senders.len(); + packet_with_sender_stake(senders[n].1, Some(senders[n].0)) + }) + .collect(), + ), + (0..batch_size).collect(), + false, + ) + }) + .collect(); + + let (stakes, locators) = unprocessed_packet_batches.get_stakes_and_locators(); + + // Produced stakes and locators should both have "batch_size * batch_count" entries; + assert_eq!(batch_size * batch_count, stakes.len()); + assert_eq!(batch_size * batch_count, locators.len()); + // Assert stakes and locators are in good order + locators.iter().enumerate().for_each(|(index, locator)| { + assert_eq!( + stakes[index], + senders[(locator.batch_index * batch_size + locator.packet_index) % senders.len()] + .1 + ); + }); + } }