diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs index 1ecde06dd3..f6f445e4ff 100644 --- a/core/src/banking_stage.rs +++ b/core/src/banking_stage.rs @@ -3404,8 +3404,7 @@ mod tests { .unwrap(); let mut packets = vec![Packet::default(); 2]; - let (_, num_received) = - recv_mmsg(recv_socket, &mut packets[..]).unwrap_or_default(); + let num_received = recv_mmsg(recv_socket, &mut packets[..]).unwrap_or_default(); assert_eq!(num_received, expected_num_forwarded, "{}", name); } @@ -3504,8 +3503,7 @@ mod tests { .unwrap(); let mut packets = vec![Packet::default(); 2]; - let (_, num_received) = - recv_mmsg(recv_socket, &mut packets[..]).unwrap_or_default(); + let num_received = recv_mmsg(recv_socket, &mut packets[..]).unwrap_or_default(); assert_eq!(num_received, expected_ids.len(), "{}", name); for (i, expected_id) in expected_ids.iter().enumerate() { assert_eq!(packets[i].meta.size, 1); diff --git a/streamer/src/packet.rs b/streamer/src/packet.rs index b0abe551a3..34404143f1 100644 --- a/streamer/src/packet.rs +++ b/streamer/src/packet.rs @@ -41,7 +41,7 @@ pub fn recv_from(batch: &mut PacketBatch, socket: &UdpSocket, max_wait_ms: u64) trace!("recv_from err {:?}", e); return Err(e); } - Ok((_, npkts)) => { + Ok(npkts) => { if i == 0 { socket.set_nonblocking(true)?; } @@ -112,6 +112,10 @@ mod tests { } send_to(&batch, &send_socket, &SocketAddrSpace::Unspecified).unwrap(); + batch + .packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); let recvd = recv_from(&mut batch, &recv_socket, 1).unwrap(); assert_eq!(recvd, batch.packets.len()); diff --git a/streamer/src/recvmmsg.rs b/streamer/src/recvmmsg.rs index 28337d8485..2a7a07bcff 100644 --- a/streamer/src/recvmmsg.rs +++ b/streamer/src/recvmmsg.rs @@ -2,7 +2,7 @@ pub use solana_perf::packet::NUM_RCVMMSGS; use { - crate::packet::Packet, + crate::packet::{Meta, Packet}, std::{cmp, io, net::UdpSocket}, }; #[cfg(target_os = "linux")] @@ -10,14 +10,14 @@ use { itertools::izip, libc::{iovec, mmsghdr, sockaddr_storage, socklen_t, AF_INET, AF_INET6, MSG_WAITFORONE}, nix::sys::socket::InetAddr, - std::{mem, os::unix::io::AsRawFd}, + std::{convert::TryFrom, mem, os::unix::io::AsRawFd}, }; #[cfg(not(target_os = "linux"))] -pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result<(usize, usize)> { +pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result { + debug_assert!(packets.iter().all(|pkt| pkt.meta == Meta::default())); let mut i = 0; let count = cmp::min(NUM_RCVMMSGS, packets.len()); - let mut total_size = 0; for p in packets.iter_mut().take(count) { p.meta.size = 0; match socket.recv_from(&mut p.data) { @@ -28,7 +28,6 @@ pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result<(usiz return Err(e); } Ok((nrecv, from)) => { - total_size += nrecv; p.meta.size = nrecv; p.meta.set_addr(&from); if i == 0 { @@ -38,7 +37,7 @@ pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result<(usiz } i += 1; } - Ok((total_size, i)) + Ok(i) } #[cfg(target_os = "linux")] @@ -67,7 +66,9 @@ fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option #[cfg(target_os = "linux")] #[allow(clippy::uninit_assumed_init)] -pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result<(usize, usize)> { +pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result { + // Assert that there are no leftovers in packets. + debug_assert!(packets.iter().all(|pkt| pkt.meta == Meta::default())); const SOCKADDR_STORAGE_SIZE: usize = mem::size_of::(); let mut hdrs: [mmsghdr; NUM_RCVMMSGS] = unsafe { mem::zeroed() }; @@ -95,25 +96,18 @@ pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result<(usize, }; let nrecv = unsafe { libc::recvmmsg(sock_fd, &mut hdrs[0], count as u32, MSG_WAITFORONE, &mut ts) }; - if nrecv < 0 { + let nrecv = if nrecv < 0 { return Err(io::Error::last_os_error()); + } else { + usize::try_from(nrecv).unwrap() + }; + for (addr, hdr, pkt) in izip!(&addrs, &hdrs, packets.iter_mut()).take(nrecv) { + pkt.meta.size = hdr.msg_len as usize; + if let Some(addr) = cast_socket_addr(addr, hdr) { + pkt.meta.set_addr(&addr.to_std()); + } } - let mut npkts = 0; - let mut total_size = 0; - - izip!(&addrs, &hdrs, packets.iter_mut()) - .take(nrecv as usize) - .filter_map(|(addr, hdr, pkt)| { - let addr = cast_socket_addr(addr, hdr)?.to_std(); - Some((addr, hdr, pkt)) - }) - .for_each(|(addr, hdr, pkt)| { - pkt.meta.size = hdr.msg_len as usize; - pkt.meta.set_addr(&addr); - npkts += 1; - total_size += pkt.meta.size; - }); - Ok((total_size, npkts)) + Ok(nrecv) } #[cfg(test)] @@ -147,7 +141,7 @@ mod tests { } let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(sent, recv); for packet in packets.iter().take(recv) { assert_eq!(packet.meta.size, PACKET_DATA_SIZE); @@ -173,14 +167,17 @@ mod tests { } let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(TEST_NUM_MSGS, recv); for packet in packets.iter().take(recv) { assert_eq!(packet.meta.size, PACKET_DATA_SIZE); assert_eq!(packet.meta.addr(), saddr); } - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(sent - TEST_NUM_MSGS, recv); for packet in packets.iter().take(recv) { assert_eq!(packet.meta.size, PACKET_DATA_SIZE); @@ -212,7 +209,7 @@ mod tests { let start = Instant::now(); let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(TEST_NUM_MSGS, recv); for packet in packets.iter().take(recv) { assert_eq!(packet.meta.size, PACKET_DATA_SIZE); @@ -220,6 +217,9 @@ mod tests { } reader.set_nonblocking(true).unwrap(); + packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); let _recv = recv_mmsg(&reader, &mut packets[..]); assert!(start.elapsed().as_secs() < 5); } @@ -249,7 +249,7 @@ mod tests { let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(TEST_NUM_MSGS, recv); for packet in packets.iter().take(sent1) { assert_eq!(packet.meta.size, PACKET_DATA_SIZE); @@ -260,7 +260,10 @@ mod tests { assert_eq!(packet.meta.addr(), saddr2); } - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(sent1 + sent2 - TEST_NUM_MSGS, recv); for packet in packets.iter().take(recv) { assert_eq!(packet.meta.size, PACKET_DATA_SIZE); diff --git a/streamer/src/sendmmsg.rs b/streamer/src/sendmmsg.rs index 6e434f5ed5..47abcc0af8 100644 --- a/streamer/src/sendmmsg.rs +++ b/streamer/src/sendmmsg.rs @@ -175,7 +175,7 @@ mod tests { assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(32, recv); } @@ -206,11 +206,11 @@ mod tests { assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(16, recv); let mut packets = vec![Packet::default(); 32]; - let recv = recv_mmsg(&reader2, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader2, &mut packets[..]).unwrap(); assert_eq!(16, recv); } @@ -241,19 +241,19 @@ mod tests { assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); assert_eq!(1, recv); let mut packets = vec![Packet::default(); 32]; - let recv = recv_mmsg(&reader2, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader2, &mut packets[..]).unwrap(); assert_eq!(1, recv); let mut packets = vec![Packet::default(); 32]; - let recv = recv_mmsg(&reader3, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader3, &mut packets[..]).unwrap(); assert_eq!(1, recv); let mut packets = vec![Packet::default(); 32]; - let recv = recv_mmsg(&reader4, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader4, &mut packets[..]).unwrap(); assert_eq!(1, recv); } diff --git a/streamer/tests/recvmmsg.rs b/streamer/tests/recvmmsg.rs index 614dc16ee1..7fa18739df 100644 --- a/streamer/tests/recvmmsg.rs +++ b/streamer/tests/recvmmsg.rs @@ -2,7 +2,7 @@ use { solana_streamer::{ - packet::{Packet, PACKET_DATA_SIZE}, + packet::{Meta, Packet, PACKET_DATA_SIZE}, recvmmsg::*, }, std::{net::UdpSocket, time::Instant}, @@ -25,7 +25,7 @@ pub fn test_recv_mmsg_batch_size() { } let mut packets = vec![Packet::default(); TEST_BATCH_SIZE]; let now = Instant::now(); - let recv = recv_mmsg(&reader, &mut packets[..]).unwrap().1; + let recv = recv_mmsg(&reader, &mut packets[..]).unwrap(); elapsed_in_max_batch += now.elapsed().as_nanos(); assert_eq!(TEST_BATCH_SIZE, recv); }); @@ -40,10 +40,13 @@ pub fn test_recv_mmsg_batch_size() { let mut recv = 0; let now = Instant::now(); while let Ok(num) = recv_mmsg(&reader, &mut packets[..]) { - recv += num.1; + recv += num; if recv >= TEST_BATCH_SIZE { break; } + packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); } elapsed_in_small_batch += now.elapsed().as_nanos(); assert_eq!(TEST_BATCH_SIZE, recv);