diff --git a/src/accountant_skel.rs b/src/accountant_skel.rs index 5ef7f5471c..d1b4b3a77e 100644 --- a/src/accountant_skel.rs +++ b/src/accountant_skel.rs @@ -1,4 +1,3 @@ -use std::io; use accountant::Accountant; use transaction::Transaction; use signature::PublicKey; @@ -6,6 +5,13 @@ use hash::Hash; use entry::Entry; use std::net::UdpSocket; use bincode::{deserialize, serialize}; +use result::Result; +use streamer; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use std::sync::mpsc::channel; +use std::thread::{spawn, JoinHandle}; +use std::default::Default; pub struct AccountantSkel { pub acc: Accountant, @@ -55,21 +61,78 @@ impl AccountantSkel { }), } } + fn process( + &mut self, + r_reader: &streamer::Receiver, + s_sender: &streamer::Sender, + recycler: streamer::Recycler, + ) -> Result<()> { + let timer = Duration::new(1, 0); + let msgs = r_reader.recv_timeout(timer)?; + let msgs_ = msgs.clone(); + let msgs__ = msgs.clone(); + let rsps = streamer::allocate(recycler.clone()); + let rsps_ = rsps.clone(); + let l = msgs__.read().unwrap().packets.len(); + rsps.write() + .unwrap() + .packets + .resize(l, streamer::Packet::default()); + { + let mut num = 0; + let mut ursps = rsps.write().unwrap(); + for packet in msgs.read().unwrap().packets.iter() { + let sz = packet.size; + let req = deserialize(&packet.data[0..sz])?; + if let Some(resp) = self.process_request(req) { + let rsp = ursps.packets.get_mut(num).unwrap(); + let v = serialize(&resp)?; + let len = v.len(); + rsp.data[0..len].copy_from_slice(&v); + rsp.size = len; + rsp.set_addr(&packet.get_addr()); + num += 1; + } + } + ursps.packets.resize(num, streamer::Packet::default()); + } + s_sender.send(rsps_)?; + streamer::recycle(recycler, msgs_); + Ok(()) + } /// UDP Server that forwards messages to Accountant methods. - pub fn serve(self: &mut Self, addr: &str) -> io::Result<()> { - let socket = UdpSocket::bind(addr)?; - let mut buf = vec![0u8; 1024]; - loop { - //println!("skel: Waiting for incoming packets..."); - let (_sz, src) = socket.recv_from(&mut buf)?; + pub fn serve( + obj: Arc>, + addr: &str, + exit: Arc>, + ) -> Result<[Arc>; 3]> { + let read = UdpSocket::bind(addr)?; + // make sure we are on the same interface + let mut local = read.local_addr()?; + local.set_port(0); + let write = UdpSocket::bind(local)?; - // TODO: Return a descriptive error message if deserialization fails. - let req = deserialize(&buf).expect("deserialize request"); + let recycler = Arc::new(Mutex::new(Vec::new())); + let (s_reader, r_reader) = channel(); + let t_receiver = streamer::receiver(read, exit.clone(), recycler.clone(), s_reader)?; - if let Some(resp) = self.process_request(req) { - socket.send_to(&serialize(&resp).expect("serialize response"), &src)?; - } - } + let (s_sender, r_sender) = channel(); + let t_sender = streamer::sender(write, exit.clone(), recycler.clone(), r_sender); + + let t_server = spawn(move || { + match Arc::try_unwrap(obj) { + Ok(me) => loop { + let e = me.lock() + .unwrap() + .process(&r_reader, &s_sender, recycler.clone()); + if e.is_err() && *exit.lock().unwrap() { + break; + } + }, + _ => (), + }; + }); + Ok([Arc::new(t_receiver), Arc::new(t_sender), Arc::new(t_server)]) } } diff --git a/src/accountant_stub.rs b/src/accountant_stub.rs index cc8740d455..73177b0450 100644 --- a/src/accountant_stub.rs +++ b/src/accountant_stub.rs @@ -115,10 +115,11 @@ mod tests { use super::*; use accountant::Accountant; use accountant_skel::AccountantSkel; - use std::thread::{sleep, spawn}; + use std::thread::sleep; use std::time::Duration; use mint::Mint; use signature::{KeyPair, KeyPairUtil}; + use std::sync::{Arc, Mutex}; #[test] fn test_accountant_stub() { @@ -127,7 +128,9 @@ mod tests { let alice = Mint::new(10_000); let acc = Accountant::new(&alice, None); let bob_pubkey = KeyPair::new().pubkey(); - spawn(move || AccountantSkel::new(acc).serve(addr).unwrap()); + let exit = Arc::new(Mutex::new(false)); + let acc = Arc::new(Mutex::new(AccountantSkel::new(acc))); + let threads = AccountantSkel::serve(acc, addr, exit.clone()).unwrap(); sleep(Duration::from_millis(30)); let socket = UdpSocket::bind(send_addr).unwrap(); @@ -137,5 +140,12 @@ mod tests { .unwrap(); acc.wait_on_signature(&sig).unwrap(); assert_eq!(acc.get_balance(&bob_pubkey).unwrap().unwrap(), 500); + *exit.lock().unwrap() = true; + for t in threads.iter() { + match Arc::try_unwrap((*t).clone()) { + Ok(j) => j.join().expect("join"), + _ => (), + } + } } } diff --git a/src/bin/testnode.rs b/src/bin/testnode.rs index 22a73e6eb4..4fcc5ef48c 100644 --- a/src/bin/testnode.rs +++ b/src/bin/testnode.rs @@ -4,6 +4,7 @@ extern crate silk; use silk::accountant_skel::AccountantSkel; use silk::accountant::Accountant; use std::io::{self, BufRead}; +use std::sync::{Arc, Mutex}; fn main() { let addr = "127.0.0.1:8000"; @@ -13,7 +14,8 @@ fn main() { .lines() .map(|line| serde_json::from_str(&line.unwrap()).unwrap()); let acc = Accountant::new_from_entries(entries, Some(1000)); - let mut skel = AccountantSkel::new(acc); + let exit = Arc::new(Mutex::new(false)); + let skel = Arc::new(Mutex::new(AccountantSkel::new(acc))); eprintln!("Listening on {}", addr); - skel.serve(addr).unwrap(); + let _threads = AccountantSkel::serve(skel, addr, exit.clone()).unwrap(); } diff --git a/src/result.rs b/src/result.rs index 36257647ca..86f2118c83 100644 --- a/src/result.rs +++ b/src/result.rs @@ -1,6 +1,7 @@ use serde_json; use std; use std::any::Any; +use bincode; #[derive(Debug)] pub enum Error { @@ -10,6 +11,7 @@ pub enum Error { JoinError(Box), RecvError(std::sync::mpsc::RecvError), RecvTimeoutError(std::sync::mpsc::RecvTimeoutError), + Serialize(std::boxed::Box), SendError, } @@ -51,6 +53,11 @@ impl std::convert::From for Error { Error::AddrParse(e) } } +impl std::convert::From> for Error { + fn from(e: std::boxed::Box) -> Error { + Error::Serialize(e) + } +} #[cfg(test)] mod tests { diff --git a/src/streamer.rs b/src/streamer.rs index 21541858bf..48d8db74cf 100644 --- a/src/streamer.rs +++ b/src/streamer.rs @@ -1,12 +1,12 @@ use std::sync::{Arc, Mutex, RwLock}; -use std::sync::mpsc::{Receiver, Sender}; +use std::sync::mpsc; use std::time::Duration; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::thread::{spawn, JoinHandle}; use result::{Error, Result}; const BLOCK_SIZE: usize = 1024 * 8; -const PACKET_SIZE: usize = 256; +pub const PACKET_SIZE: usize = 256; #[derive(Clone)] pub struct Packet { @@ -76,8 +76,10 @@ pub struct PacketData { pub packets: Vec, } -type SharedPacketData = Arc>; -type Recycler = Arc>>; +pub type SharedPacketData = Arc>; +pub type Recycler = Arc>>; +pub type Receiver = mpsc::Receiver; +pub type Sender = mpsc::Sender; impl PacketData { pub fn new() -> PacketData { @@ -92,6 +94,7 @@ impl PacketData { p.size = 0; match socket.recv_from(&mut p.data) { Err(_) if i > 0 => { + trace!("got {:?} messages", i); break; } Err(e) => { @@ -126,13 +129,13 @@ impl PacketData { } } -fn allocate(recycler: Recycler) -> SharedPacketData { +pub fn allocate(recycler: Recycler) -> SharedPacketData { let mut gc = recycler.lock().expect("lock"); gc.pop() .unwrap_or_else(|| Arc::new(RwLock::new(PacketData::new()))) } -fn recycle(recycler: Recycler, msgs: SharedPacketData) { +pub fn recycle(recycler: Recycler, msgs: SharedPacketData) { let mut gc = recycler.lock().expect("lock"); gc.push(msgs); } @@ -141,7 +144,7 @@ fn recv_loop( sock: &UdpSocket, exit: Arc>, recycler: Recycler, - channel: Sender, + channel: Sender, ) -> Result<()> { loop { let msgs = allocate(recycler.clone()); @@ -167,7 +170,7 @@ pub fn receiver( sock: UdpSocket, exit: Arc>, recycler: Recycler, - channel: Sender, + channel: Sender, ) -> Result> { let timer = Duration::new(1, 0); sock.set_read_timeout(Some(timer))?; @@ -177,7 +180,7 @@ pub fn receiver( })) } -fn recv_send(sock: &UdpSocket, recycler: Recycler, r: &Receiver) -> Result<()> { +fn recv_send(sock: &UdpSocket, recycler: Recycler, r: &Receiver) -> Result<()> { let timer = Duration::new(1, 0); let msgs = r.recv_timeout(timer)?; let msgs_ = msgs.clone(); @@ -191,7 +194,7 @@ pub fn sender( sock: UdpSocket, exit: Arc>, recycler: Recycler, - r: Receiver, + r: Receiver, ) -> JoinHandle<()> { spawn(move || loop { if recv_send(&sock, recycler.clone(), &r).is_err() && *exit.lock().unwrap() { @@ -208,10 +211,9 @@ mod test { use std::time::Duration; use std::time::SystemTime; use std::thread::{spawn, JoinHandle}; - use std::sync::mpsc::{channel, Receiver}; + use std::sync::mpsc::channel; use result::Result; - use streamer::{allocate, receiver, recycle, sender, Packet, Recycler, SharedPacketData, - PACKET_SIZE}; + use streamer::{allocate, receiver, recycle, sender, Packet, Receiver, Recycler, PACKET_SIZE}; fn producer(addr: &SocketAddr, recycler: Recycler, exit: Arc>) -> JoinHandle<()> { let send = UdpSocket::bind("0.0.0.0:0").unwrap(); @@ -235,7 +237,7 @@ mod test { recycler: Recycler, exit: Arc>, rvs: Arc>, - r: Receiver, + r: Receiver, ) -> JoinHandle<()> { spawn(move || loop { if *exit.lock().unwrap() { @@ -289,9 +291,8 @@ mod test { run_streamer_bench().unwrap(); } - fn get_msgs(r: Receiver, num: &mut usize) { - let mut tries = 0; - loop { + fn get_msgs(r: Receiver, num: &mut usize) { + for _ in [0..5].iter() { let timer = Duration::new(1, 0); match r.recv_timeout(timer) { Ok(m) => *num += m.read().unwrap().packets.len(), @@ -300,10 +301,6 @@ mod test { if *num == 10 { break; } - if tries == 5 { - break; - } - tries += 1; } } #[cfg(ipv6)]