From 5d17c2b58f20efb41102c24a45b4a480fc434702 Mon Sep 17 00:00:00 2001 From: Greg Fitzgerald Date: Mon, 2 Jul 2018 15:45:16 -0600 Subject: [PATCH] Return output receivers from each stage Reaching into the stages' structs for their receivers is, in hindsight, more awkward than returning multiple values from constructors. By returning the receiver, the caller can name the receiver whatever it wants (as you would with any return value), and doesn't need to reach into the struct for the field (which is super awkward in combination with move semantics). --- src/banking_stage.rs | 10 ++-------- src/blob_fetch_stage.rs | 14 +++++++------- src/fetch_stage.rs | 14 +++++++------- src/record_stage.rs | 38 +++++++++++++++++--------------------- src/request_stage.rs | 13 +++++++------ src/rpu.rs | 4 ++-- src/server.rs | 4 ++-- src/sigverify_stage.rs | 11 +++++------ src/tpu.rs | 33 ++++++++++++++------------------- src/tvu.rs | 8 ++++---- src/window_stage.rs | 8 ++------ src/write_stage.rs | 8 ++------ 12 files changed, 71 insertions(+), 94 deletions(-) diff --git a/src/banking_stage.rs b/src/banking_stage.rs index fddcf184f8..a2c8c7b8ce 100644 --- a/src/banking_stage.rs +++ b/src/banking_stage.rs @@ -23,9 +23,6 @@ use transaction::Transaction; pub struct BankingStage { /// Handle to the stage's thread. pub thread_hdl: JoinHandle<()>, - - /// Output receiver for the following stage. - pub signal_receiver: Receiver, } impl BankingStage { @@ -38,7 +35,7 @@ impl BankingStage { exit: Arc, verified_receiver: Receiver)>>, packet_recycler: PacketRecycler, - ) -> Self { + ) -> (Self, Receiver) { let (signal_sender, signal_receiver) = channel(); let thread_hdl = Builder::new() .name("solana-banking-stage".to_string()) @@ -56,10 +53,7 @@ impl BankingStage { } }) .unwrap(); - BankingStage { - thread_hdl, - signal_receiver, - } + (BankingStage { thread_hdl }, signal_receiver) } /// Convert the transactions from a blob of binary data to a vector of transactions and diff --git a/src/blob_fetch_stage.rs b/src/blob_fetch_stage.rs index f67839b985..550780f207 100644 --- a/src/blob_fetch_stage.rs +++ b/src/blob_fetch_stage.rs @@ -9,19 +9,22 @@ use std::thread::JoinHandle; use streamer::{self, BlobReceiver}; pub struct BlobFetchStage { - pub blob_receiver: BlobReceiver, pub thread_hdls: Vec>, } impl BlobFetchStage { - pub fn new(socket: UdpSocket, exit: Arc, blob_recycler: BlobRecycler) -> Self { + pub fn new( + socket: UdpSocket, + exit: Arc, + blob_recycler: BlobRecycler, + ) -> (Self, BlobReceiver) { Self::new_multi_socket(vec![socket], exit, blob_recycler) } pub fn new_multi_socket( sockets: Vec, exit: Arc, blob_recycler: BlobRecycler, - ) -> Self { + ) -> (Self, BlobReceiver) { let (blob_sender, blob_receiver) = channel(); let thread_hdls: Vec<_> = sockets .into_iter() @@ -35,9 +38,6 @@ impl BlobFetchStage { }) .collect(); - BlobFetchStage { - blob_receiver, - thread_hdls, - } + (BlobFetchStage { thread_hdls }, blob_receiver) } } diff --git a/src/fetch_stage.rs b/src/fetch_stage.rs index c73962eaea..bf117c1e38 100644 --- a/src/fetch_stage.rs +++ b/src/fetch_stage.rs @@ -9,19 +9,22 @@ use std::thread::JoinHandle; use streamer::{self, PacketReceiver}; pub struct FetchStage { - pub packet_receiver: PacketReceiver, pub thread_hdls: Vec>, } impl FetchStage { - pub fn new(socket: UdpSocket, exit: Arc, packet_recycler: PacketRecycler) -> Self { + pub fn new( + socket: UdpSocket, + exit: Arc, + packet_recycler: PacketRecycler, + ) -> (Self, PacketReceiver) { Self::new_multi_socket(vec![socket], exit, packet_recycler) } pub fn new_multi_socket( sockets: Vec, exit: Arc, packet_recycler: PacketRecycler, - ) -> Self { + ) -> (Self, PacketReceiver) { let (packet_sender, packet_receiver) = channel(); let thread_hdls: Vec<_> = sockets .into_iter() @@ -35,9 +38,6 @@ impl FetchStage { }) .collect(); - FetchStage { - packet_receiver, - thread_hdls, - } + (FetchStage { thread_hdls }, packet_receiver) } } diff --git a/src/record_stage.rs b/src/record_stage.rs index 1d20328836..4d50f258c4 100644 --- a/src/record_stage.rs +++ b/src/record_stage.rs @@ -20,14 +20,16 @@ pub enum Signal { } pub struct RecordStage { - pub entry_receiver: Receiver>, pub thread_hdl: JoinHandle<()>, } impl RecordStage { /// A background thread that will continue tagging received Transaction messages and /// sending back Entry messages until either the receiver or sender channel is closed. - pub fn new(signal_receiver: Receiver, start_hash: &Hash) -> Self { + pub fn new( + signal_receiver: Receiver, + start_hash: &Hash, + ) -> (Self, Receiver>) { let (entry_sender, entry_receiver) = channel(); let start_hash = start_hash.clone(); @@ -39,10 +41,7 @@ impl RecordStage { }) .unwrap(); - RecordStage { - entry_receiver, - thread_hdl, - } + (RecordStage { thread_hdl }, entry_receiver) } /// Same as `RecordStage::new`, but will automatically produce entries every `tick_duration`. @@ -50,7 +49,7 @@ impl RecordStage { signal_receiver: Receiver, start_hash: &Hash, tick_duration: Duration, - ) -> Self { + ) -> (Self, Receiver>) { let (entry_sender, entry_receiver) = channel(); let start_hash = start_hash.clone(); @@ -74,10 +73,7 @@ impl RecordStage { }) .unwrap(); - RecordStage { - entry_receiver, - thread_hdl, - } + (RecordStage { thread_hdl }, entry_receiver) } fn process_signal( @@ -140,7 +136,7 @@ mod tests { fn test_historian() { let (tx_sender, tx_receiver) = channel(); let zero = Hash::default(); - let record_stage = RecordStage::new(tx_receiver, &zero); + let (record_stage, entry_receiver) = RecordStage::new(tx_receiver, &zero); tx_sender.send(Signal::Tick).unwrap(); sleep(Duration::new(0, 1_000_000)); @@ -148,9 +144,9 @@ mod tests { sleep(Duration::new(0, 1_000_000)); tx_sender.send(Signal::Tick).unwrap(); - let entry0 = record_stage.entry_receiver.recv().unwrap()[0].clone(); - let entry1 = record_stage.entry_receiver.recv().unwrap()[0].clone(); - let entry2 = record_stage.entry_receiver.recv().unwrap()[0].clone(); + let entry0 = entry_receiver.recv().unwrap()[0].clone(); + let entry1 = entry_receiver.recv().unwrap()[0].clone(); + let entry2 = entry_receiver.recv().unwrap()[0].clone(); assert_eq!(entry0.num_hashes, 0); assert_eq!(entry1.num_hashes, 0); @@ -166,8 +162,8 @@ mod tests { fn test_historian_closed_sender() { let (tx_sender, tx_receiver) = channel(); let zero = Hash::default(); - let record_stage = RecordStage::new(tx_receiver, &zero); - drop(record_stage.entry_receiver); + let (record_stage, entry_receiver) = RecordStage::new(tx_receiver, &zero); + drop(entry_receiver); tx_sender.send(Signal::Tick).unwrap(); assert_eq!(record_stage.thread_hdl.join().unwrap(), ()); } @@ -176,7 +172,7 @@ mod tests { fn test_transactions() { let (tx_sender, signal_receiver) = channel(); let zero = Hash::default(); - let record_stage = RecordStage::new(signal_receiver, &zero); + let (_record_stage, entry_receiver) = RecordStage::new(signal_receiver, &zero); let alice_keypair = KeyPair::new(); let bob_pubkey = KeyPair::new().pubkey(); let tx0 = Transaction::new(&alice_keypair, bob_pubkey, 1, zero); @@ -185,7 +181,7 @@ mod tests { .send(Signal::Transactions(vec![tx0, tx1])) .unwrap(); drop(tx_sender); - let entries: Vec<_> = record_stage.entry_receiver.iter().collect(); + let entries: Vec<_> = entry_receiver.iter().collect(); assert_eq!(entries.len(), 1); } @@ -193,12 +189,12 @@ mod tests { fn test_clock() { let (tx_sender, tx_receiver) = channel(); let zero = Hash::default(); - let record_stage = + let (_record_stage, entry_receiver) = RecordStage::new_with_clock(tx_receiver, &zero, Duration::from_millis(20)); sleep(Duration::from_millis(900)); tx_sender.send(Signal::Tick).unwrap(); drop(tx_sender); - let entries: Vec<_> = record_stage.entry_receiver.iter().flat_map(|x| x).collect(); + let entries: Vec<_> = entry_receiver.iter().flat_map(|x| x).collect(); assert!(entries.len() > 1); // Ensure the ID is not the seed. diff --git a/src/request_stage.rs b/src/request_stage.rs index 878cc97fc6..c047a6da22 100644 --- a/src/request_stage.rs +++ b/src/request_stage.rs @@ -17,7 +17,6 @@ use timing; pub struct RequestStage { pub thread_hdl: JoinHandle<()>, - pub blob_receiver: BlobReceiver, pub request_processor: Arc, } @@ -85,7 +84,7 @@ impl RequestStage { packet_receiver: Receiver, packet_recycler: PacketRecycler, blob_recycler: BlobRecycler, - ) -> Self { + ) -> (Self, BlobReceiver) { let request_processor = Arc::new(request_processor); let request_processor_ = request_processor.clone(); let (blob_sender, blob_receiver) = channel(); @@ -106,10 +105,12 @@ impl RequestStage { } }) .unwrap(); - RequestStage { - thread_hdl, + ( + RequestStage { + thread_hdl, + request_processor, + }, blob_receiver, - request_processor, - } + ) } } diff --git a/src/rpu.rs b/src/rpu.rs index 507b91fea7..03f876226a 100644 --- a/src/rpu.rs +++ b/src/rpu.rs @@ -56,7 +56,7 @@ impl Rpu { let blob_recycler = BlobRecycler::default(); let request_processor = RequestProcessor::new(bank.clone()); - let request_stage = RequestStage::new( + let (request_stage, blob_receiver) = RequestStage::new( request_processor, exit.clone(), packet_receiver, @@ -68,7 +68,7 @@ impl Rpu { respond_socket, exit.clone(), blob_recycler.clone(), - request_stage.blob_receiver, + blob_receiver, ); let thread_hdls = vec![t_receiver, t_responder, request_stage.thread_hdl]; diff --git a/src/server.rs b/src/server.rs index dc6b9881ca..612f4be561 100644 --- a/src/server.rs +++ b/src/server.rs @@ -63,7 +63,7 @@ impl Server { thread_hdls.extend(rpu.thread_hdls); let blob_recycler = BlobRecycler::default(); - let tpu = Tpu::new( + let (tpu, blob_receiver) = Tpu::new( bank.clone(), tick_duration, transactions_socket, @@ -92,7 +92,7 @@ impl Server { window, entry_height, blob_recycler.clone(), - tpu.blob_receiver, + blob_receiver, ); thread_hdls.extend(vec![t_broadcast]); diff --git a/src/sigverify_stage.rs b/src/sigverify_stage.rs index 563b9dd146..7e2a74cb43 100644 --- a/src/sigverify_stage.rs +++ b/src/sigverify_stage.rs @@ -18,18 +18,17 @@ use streamer::{self, PacketReceiver}; use timing; pub struct SigVerifyStage { - pub verified_receiver: Receiver)>>, pub thread_hdls: Vec>, } impl SigVerifyStage { - pub fn new(exit: Arc, packet_receiver: Receiver) -> Self { + pub fn new( + exit: Arc, + packet_receiver: Receiver, + ) -> (Self, Receiver)>>) { let (verified_sender, verified_receiver) = channel(); let thread_hdls = Self::verifier_services(exit, packet_receiver, verified_sender); - SigVerifyStage { - thread_hdls, - verified_receiver, - } + (SigVerifyStage { thread_hdls }, verified_receiver) } fn verify_batch(batch: Vec) -> Vec<(SharedPackets, Vec)> { diff --git a/src/tpu.rs b/src/tpu.rs index 6db538d1f9..0d1d35e11d 100644 --- a/src/tpu.rs +++ b/src/tpu.rs @@ -41,7 +41,6 @@ use streamer::BlobReceiver; use write_stage::WriteStage; pub struct Tpu { - pub blob_receiver: BlobReceiver, pub thread_hdls: Vec>, } @@ -53,36 +52,35 @@ impl Tpu { blob_recycler: BlobRecycler, exit: Arc, writer: W, - ) -> Self { + ) -> (Self, BlobReceiver) { let packet_recycler = PacketRecycler::default(); - let fetch_stage = + let (fetch_stage, packet_receiver) = FetchStage::new(transactions_socket, exit.clone(), packet_recycler.clone()); - let sigverify_stage = SigVerifyStage::new(exit.clone(), fetch_stage.packet_receiver); + let (sigverify_stage, verified_receiver) = + SigVerifyStage::new(exit.clone(), packet_receiver); - let banking_stage = BankingStage::new( + let (banking_stage, signal_receiver) = BankingStage::new( bank.clone(), exit.clone(), - sigverify_stage.verified_receiver, + verified_receiver, packet_recycler.clone(), ); - let record_stage = match tick_duration { - Some(tick_duration) => RecordStage::new_with_clock( - banking_stage.signal_receiver, - &bank.last_id(), - tick_duration, - ), - None => RecordStage::new(banking_stage.signal_receiver, &bank.last_id()), + let (record_stage, entry_receiver) = match tick_duration { + Some(tick_duration) => { + RecordStage::new_with_clock(signal_receiver, &bank.last_id(), tick_duration) + } + None => RecordStage::new(signal_receiver, &bank.last_id()), }; - let write_stage = WriteStage::new( + let (write_stage, blob_receiver) = WriteStage::new( bank.clone(), exit.clone(), blob_recycler.clone(), writer, - record_stage.entry_receiver, + entry_receiver, ); let mut thread_hdls = vec![ banking_stage.thread_hdl, @@ -91,9 +89,6 @@ impl Tpu { ]; thread_hdls.extend(fetch_stage.thread_hdls.into_iter()); thread_hdls.extend(sigverify_stage.thread_hdls.into_iter()); - Tpu { - blob_receiver: write_stage.blob_receiver, - thread_hdls, - } + (Tpu { thread_hdls }, blob_receiver) } } diff --git a/src/tvu.rs b/src/tvu.rs index 025a8bfa47..a7aec358ba 100644 --- a/src/tvu.rs +++ b/src/tvu.rs @@ -73,7 +73,7 @@ impl Tvu { exit: Arc, ) -> Self { let blob_recycler = BlobRecycler::default(); - let fetch_stage = BlobFetchStage::new_multi_socket( + let (fetch_stage, blob_receiver) = BlobFetchStage::new_multi_socket( vec![replicate_socket, repair_socket], exit.clone(), blob_recycler.clone(), @@ -81,17 +81,17 @@ impl Tvu { //TODO //the packets coming out of blob_receiver need to be sent to the GPU and verified //then sent to the window, which does the erasure coding reconstruction - let window_stage = WindowStage::new( + let (window_stage, blob_receiver) = WindowStage::new( crdt, window, entry_height, retransmit_socket, exit.clone(), blob_recycler.clone(), - fetch_stage.blob_receiver, + blob_receiver, ); - let replicate_stage = ReplicateStage::new(bank, exit, window_stage.blob_receiver); + let replicate_stage = ReplicateStage::new(bank, exit, blob_receiver); let mut threads = vec![replicate_stage.thread_hdl]; threads.extend(fetch_stage.thread_hdls.into_iter()); diff --git a/src/window_stage.rs b/src/window_stage.rs index 46f6eb2b98..9529f540e8 100644 --- a/src/window_stage.rs +++ b/src/window_stage.rs @@ -10,7 +10,6 @@ use std::thread::JoinHandle; use streamer::{self, BlobReceiver, Window}; pub struct WindowStage { - pub blob_receiver: BlobReceiver, pub thread_hdls: Vec>, } @@ -23,7 +22,7 @@ impl WindowStage { exit: Arc, blob_recycler: BlobRecycler, fetch_stage_receiver: BlobReceiver, - ) -> Self { + ) -> (Self, BlobReceiver) { let (retransmit_sender, retransmit_receiver) = channel(); let t_retransmit = streamer::retransmitter( @@ -46,9 +45,6 @@ impl WindowStage { ); let thread_hdls = vec![t_retransmit, t_window]; - WindowStage { - blob_receiver, - thread_hdls, - } + (WindowStage { thread_hdls }, blob_receiver) } } diff --git a/src/write_stage.rs b/src/write_stage.rs index ecdf773d3c..f5ef96256a 100644 --- a/src/write_stage.rs +++ b/src/write_stage.rs @@ -19,7 +19,6 @@ use streamer::{BlobReceiver, BlobSender}; pub struct WriteStage { pub thread_hdl: JoinHandle<()>, - pub blob_receiver: BlobReceiver, } impl WriteStage { @@ -50,7 +49,7 @@ impl WriteStage { blob_recycler: BlobRecycler, writer: W, entry_receiver: Receiver>, - ) -> Self { + ) -> (Self, BlobReceiver) { let (blob_sender, blob_receiver) = channel(); let thread_hdl = Builder::new() .name("solana-writer".to_string()) @@ -71,9 +70,6 @@ impl WriteStage { }) .unwrap(); - WriteStage { - thread_hdl, - blob_receiver, - } + (WriteStage { thread_hdl }, blob_receiver) } }