From d69016720176697d16b6b7b083ee8fdd813b3ee4 Mon Sep 17 00:00:00 2001 From: zonyitoo Date: Sun, 8 May 2022 17:09:17 +0800 Subject: [PATCH] basic wrappers for sendmmsg/recvmmsg (#823) --- .../src/net/sys/unix/bsd/freebsd.rs | 166 ++++++++++++++++- .../shadowsocks/src/net/sys/unix/bsd/macos.rs | 175 +++++++++++++++++- .../shadowsocks/src/net/sys/unix/linux/mod.rs | 166 ++++++++++++++++- crates/shadowsocks/src/net/udp.rs | 103 ++++++++++- .../src/relay/udprelay/proxy_socket.rs | 111 +++++------ 5 files changed, 645 insertions(+), 76 deletions(-) diff --git a/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs b/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs index 245c8bc94d45..c66d9b289022 100644 --- a/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs +++ b/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs @@ -4,12 +4,14 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, os::unix::io::{AsRawFd, RawFd}, pin::Pin, + ptr, + sync::atomic::{AtomicBool, Ordering}, task::{self, Poll}, }; -use log::{error, warn}; +use log::{debug, error, warn}; use pin_project::pin_project; -use socket2::{Domain, Protocol, Socket, Type}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::{TcpSocket, TcpStream as TokioTcpStream, UdpSocket}, @@ -18,6 +20,7 @@ use tokio_tfo::TfoStream; use crate::net::{ sys::{set_common_sockopt_after_connect, set_common_sockopt_for_connect, socket_bind_dual_stack}, + udp::{BatchRecvMessage, BatchSendMessage}, AddrFamily, ConnectOpts, }; @@ -241,3 +244,162 @@ pub async fn create_outbound_udp_socket(af: AddrFamily, config: &ConnectOpts) -> Ok(socket) } + +static SUPPORT_BATCH_SEND_RECV_MSG: AtomicBool = AtomicBool::new(true); + +fn recvmsg_fallback(sock: &S, msg: &mut BatchRecvMessage<'_>) -> io::Result<()> { + let mut hdr: libc::msghdr = unsafe { mem::zeroed() }; + + let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t; + let sock_addr = unsafe { SockAddr::new(addr_storage, addr_len) }; + hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_namelen = sock_addr.len() as _; + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + let ret = unsafe { libc::recvmsg(sock.as_raw_fd(), &mut hdr as *mut _, 0) }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + msg.addr = sock_addr.as_socket().expect("SockAddr.as_socket"); + msg.data_len = ret as usize; + + Ok(()) +} + +pub fn batch_recvmsg(sock: &S, msgs: &mut [BatchRecvMessage<'_>]) -> io::Result { + if msgs.is_empty() { + return Ok(0); + } + + if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Acquire) { + recvmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + + let mut vec_msg_name = Vec::with_capacity(msgs.len()); + let mut vec_msg_hdr = Vec::with_capacity(msgs.len()); + + for msg in msgs.iter_mut() { + let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() }; + + let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t; + + vec_msg_name.push(unsafe { SockAddr::new(addr_storage, addr_len) }); + let sock_addr = vec_msg_name.last_mut().unwrap(); + hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_hdr.msg_namelen = sock_addr.len() as _; + + hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_hdr.msg_iovlen = msg.data.len() as _; + + vec_msg_hdr.push(hdr); + } + + let ret = unsafe { + libc::recvmmsg( + sock.as_raw_fd(), + vec_msg_hdr.as_mut_ptr(), + vec_msg_hdr.len() as _, + 0, + ptr::null(), + ) + }; + if ret < 0 { + let err = io::Error::last_os_error(); + if let Some(libc::ENOSYS) = err.raw_os_error() { + debug!("recvmmsg is not supported, fallback to recvmsg, error: {:?}", err); + SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Release); + + recvmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + return Err(err); + } + + for idx in 0..ret as usize { + let msg = &mut msgs[idx]; + let hdr = &vec_msg_hdr[idx]; + let name = &vec_msg_name[idx]; + msg.addr = name.as_socket().expect("SockAddr.as_socket"); + msg.data_len = hdr.msg_len as usize; + } + + Ok(ret as usize) +} + +fn sendmsg_fallback(sock: &S, msg: &mut BatchSendMessage<'_>) -> io::Result<()> { + let mut hdr: libc::msghdr = unsafe { mem::zeroed() }; + + let sock_addr = msg.addr.map(SockAddr::from); + if let Some(ref sa) = sock_addr { + hdr.msg_name = sa.as_ptr() as *mut _; + hdr.msg_namelen = sa.len() as _; + } + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + let ret = unsafe { libc::sendmsg(sock.as_raw_fd(), &hdr as *const _, 0) }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + msg.data_len = ret as usize; + + Ok(()) +} + +pub fn batch_sendmsg(sock: &S, msgs: &mut [BatchSendMessage<'_>]) -> io::Result { + if msgs.is_empty() { + return Ok(0); + } + + if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Acquire) { + sendmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + + let mut vec_msg_name = Vec::with_capacity(msgs.len()); + let mut vec_msg_hdr = Vec::with_capacity(msgs.len()); + + for msg in msgs.iter_mut() { + let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() }; + + if let Some(addr) = msg.addr { + vec_msg_name.push(SockAddr::from(addr)); + let sock_addr = vec_msg_name.last_mut().unwrap(); + hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_hdr.msg_namelen = sock_addr.len() as _; + } + + hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_hdr.msg_iovlen = msg.data.len() as _; + + vec_msg_hdr.push(hdr); + } + + let ret = unsafe { libc::sendmmsg(sock.as_raw_fd(), vec_msg_hdr.as_mut_ptr(), vec_msg_hdr.len() as _, 0) }; + if ret < 0 { + let err = io::Error::last_os_error(); + if let Some(libc::ENOSYS) = err.raw_os_error() { + debug!("sendmmsg is not supported, fallback to sendmsg, error: {:?}", err); + SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Release); + + sendmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + return Err(err); + } + + for idx in 0..ret as usize { + let msg = &mut msgs[idx]; + let hdr = &vec_msg_hdr[idx]; + msg.data_len = hdr.msg_len as usize; + } + + Ok(ret as usize) +} diff --git a/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs b/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs index 36ad3099f4a1..3871a5e5eecc 100644 --- a/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs +++ b/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs @@ -5,12 +5,13 @@ use std::{ os::unix::io::{AsRawFd, RawFd}, pin::Pin, ptr, + sync::atomic::{AtomicBool, Ordering}, task::{self, Poll}, }; -use log::{error, warn}; +use log::{debug, error, warn}; use pin_project::pin_project; -use socket2::{Domain, Protocol, Socket, Type}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::{TcpSocket, TcpStream as TokioTcpStream, UdpSocket}, @@ -19,6 +20,7 @@ use tokio_tfo::TfoStream; use crate::net::{ sys::{set_common_sockopt_after_connect, set_common_sockopt_for_connect, socket_bind_dual_stack}, + udp::{BatchRecvMessage, BatchSendMessage}, AddrFamily, ConnectOpts, }; @@ -273,3 +275,172 @@ pub async fn create_outbound_udp_socket(af: AddrFamily, config: &ConnectOpts) -> Ok(socket) } + +/// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/socket.h +#[repr(C)] +struct msghdr_x { + msg_name: *mut libc::c_void, //< optional address + msg_namelen: libc::socklen_t, //< size of address + msg_iov: *mut libc::iovec, //< scatter/gather array + msg_iovlen: libc::c_int, //< # elements in msg_iov + msg_control: *mut libc::c_void, //< ancillary data, see below + msg_controllen: libc::socklen_t, //< ancillary data buffer len + msg_flags: libc::c_int, //< flags on received message + msg_datalen: libc::size_t, //< byte length of buffer in msg_iov +} + +extern "C" { + fn recvmsg_x(s: libc::c_int, msgp: *const msghdr_x, cnt: libc::c_uint, flags: libc::c_int) -> libc::ssize_t; + fn sendmsg_x(s: libc::c_int, msgp: *const msghdr_x, cnt: libc::c_uint, flags: libc::c_int) -> libc::ssize_t; +} + +static SUPPORT_BATCH_SEND_RECV_MSG: AtomicBool = AtomicBool::new(true); + +fn recvmsg_fallback(sock: &S, msg: &mut BatchRecvMessage<'_>) -> io::Result<()> { + let mut hdr: libc::msghdr = unsafe { mem::zeroed() }; + + let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t; + let sock_addr = unsafe { SockAddr::new(addr_storage, addr_len) }; + hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_namelen = sock_addr.len() as _; + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + let ret = unsafe { libc::recvmsg(sock.as_raw_fd(), &mut hdr as *mut _, 0) }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + msg.addr = sock_addr.as_socket().expect("SockAddr.as_socket"); + msg.data_len = ret as usize; + + Ok(()) +} + +pub fn batch_recvmsg(sock: &S, msgs: &mut [BatchRecvMessage<'_>]) -> io::Result { + if msgs.is_empty() { + return Ok(0); + } + + if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Acquire) { + recvmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + + let mut vec_msg_name = Vec::with_capacity(msgs.len()); + let mut vec_msg_hdr = Vec::with_capacity(msgs.len()); + + for msg in msgs.iter_mut() { + let mut hdr: msghdr_x = unsafe { mem::zeroed() }; + + let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t; + + vec_msg_name.push(unsafe { SockAddr::new(addr_storage, addr_len) }); + let sock_addr = vec_msg_name.last_mut().unwrap(); + hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_namelen = sock_addr.len() as _; + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + vec_msg_hdr.push(hdr); + } + + let ret = unsafe { recvmsg_x(sock.as_raw_fd(), vec_msg_hdr.as_ptr(), vec_msg_hdr.len() as _, 0) }; + if ret < 0 { + let err = io::Error::last_os_error(); + if let Some(libc::ENOSYS) = err.raw_os_error() { + debug!("recvmsg_x is not supported, fallback to recvmsg, error: {:?}", err); + SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Release); + + recvmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + return Err(err); + } + + for idx in 0..ret as usize { + let msg = &mut msgs[idx]; + let hdr = &vec_msg_hdr[idx]; + let name = &vec_msg_name[idx]; + msg.addr = name.as_socket().expect("SockAddr.as_socket"); + msg.data_len = hdr.msg_datalen as usize; + } + + Ok(ret as usize) +} + +fn sendmsg_fallback(sock: &S, msg: &mut BatchSendMessage<'_>) -> io::Result<()> { + let mut hdr: libc::msghdr = unsafe { mem::zeroed() }; + + let sock_addr = msg.addr.map(SockAddr::from); + if let Some(ref sa) = sock_addr { + hdr.msg_name = sa.as_ptr() as *mut _; + hdr.msg_namelen = sa.len() as _; + } + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + let ret = unsafe { libc::sendmsg(sock.as_raw_fd(), &hdr as *const _, 0) }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + msg.data_len = ret as usize; + + Ok(()) +} + +pub fn batch_sendmsg(sock: &S, msgs: &mut [BatchSendMessage<'_>]) -> io::Result { + if msgs.is_empty() { + return Ok(0); + } + + if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Acquire) { + sendmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + + let mut vec_msg_name = Vec::with_capacity(msgs.len()); + let mut vec_msg_hdr = Vec::with_capacity(msgs.len()); + + for msg in msgs.iter_mut() { + let mut hdr: msghdr_x = unsafe { mem::zeroed() }; + + if let Some(addr) = msg.addr { + vec_msg_name.push(SockAddr::from(addr)); + let sock_addr = vec_msg_name.last_mut().unwrap(); + hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_namelen = sock_addr.len() as _; + } + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + vec_msg_hdr.push(hdr); + } + + let ret = unsafe { sendmsg_x(sock.as_raw_fd(), vec_msg_hdr.as_ptr(), vec_msg_hdr.len() as _, 0) }; + if ret < 0 { + let err = io::Error::last_os_error(); + if let Some(libc::ENOSYS) = err.raw_os_error() { + debug!("sendmsg_x is not supported, fallback to sendmsg, error: {:?}", err); + SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Release); + + sendmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + return Err(err); + } + + for idx in 0..ret as usize { + let msg = &mut msgs[idx]; + let hdr = &vec_msg_hdr[idx]; + msg.data_len = hdr.msg_datalen as usize; + } + + Ok(ret as usize) +} diff --git a/crates/shadowsocks/src/net/sys/unix/linux/mod.rs b/crates/shadowsocks/src/net/sys/unix/linux/mod.rs index 4c6969510968..f2601d0eb0ac 100644 --- a/crates/shadowsocks/src/net/sys/unix/linux/mod.rs +++ b/crates/shadowsocks/src/net/sys/unix/linux/mod.rs @@ -4,13 +4,15 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, os::unix::io::{AsRawFd, RawFd}, pin::Pin, + ptr, + sync::atomic::{AtomicBool, Ordering}, task::{self, Poll}, }; use cfg_if::cfg_if; -use log::{error, warn}; +use log::{debug, error, warn}; use pin_project::pin_project; -use socket2::{Domain, Protocol, Socket, Type}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::{TcpSocket, TcpStream as TokioTcpStream, UdpSocket}, @@ -19,6 +21,7 @@ use tokio_tfo::TfoStream; use crate::net::{ sys::{set_common_sockopt_after_connect, set_common_sockopt_for_connect, socket_bind_dual_stack}, + udp::{BatchRecvMessage, BatchSendMessage}, AddrFamily, ConnectOpts, }; @@ -363,3 +366,162 @@ cfg_if! { } } } + +static SUPPORT_BATCH_SEND_RECV_MSG: AtomicBool = AtomicBool::new(true); + +fn recvmsg_fallback(sock: &S, msg: &mut BatchRecvMessage<'_>) -> io::Result<()> { + let mut hdr: libc::msghdr = unsafe { mem::zeroed() }; + + let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t; + let sock_addr = unsafe { SockAddr::new(addr_storage, addr_len) }; + hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_namelen = sock_addr.len() as _; + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + let ret = unsafe { libc::recvmsg(sock.as_raw_fd(), &mut hdr as *mut _, 0) }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + msg.addr = sock_addr.as_socket().expect("SockAddr.as_socket"); + msg.data_len = ret as usize; + + Ok(()) +} + +pub fn batch_recvmsg(sock: &S, msgs: &mut [BatchRecvMessage<'_>]) -> io::Result { + if msgs.is_empty() { + return Ok(0); + } + + if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Acquire) { + recvmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + + let mut vec_msg_name = Vec::with_capacity(msgs.len()); + let mut vec_msg_hdr = Vec::with_capacity(msgs.len()); + + for msg in msgs.iter_mut() { + let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() }; + + let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t; + + vec_msg_name.push(unsafe { SockAddr::new(addr_storage, addr_len) }); + let sock_addr = vec_msg_name.last_mut().unwrap(); + hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_hdr.msg_namelen = sock_addr.len() as _; + + hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_hdr.msg_iovlen = msg.data.len() as _; + + vec_msg_hdr.push(hdr); + } + + let ret = unsafe { + libc::recvmmsg( + sock.as_raw_fd(), + vec_msg_hdr.as_mut_ptr(), + vec_msg_hdr.len() as _, + 0, + ptr::null_mut(), + ) + }; + if ret < 0 { + let err = io::Error::last_os_error(); + if let Some(libc::ENOSYS) = err.raw_os_error() { + debug!("recvmmsg is not supported, fallback to recvmsg, error: {:?}", err); + SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Release); + + recvmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + return Err(err); + } + + for idx in 0..ret as usize { + let msg = &mut msgs[idx]; + let hdr = &vec_msg_hdr[idx]; + let name = &vec_msg_name[idx]; + msg.addr = name.as_socket().expect("SockAddr.as_socket"); + msg.data_len = hdr.msg_len as usize; + } + + Ok(ret as usize) +} + +fn sendmsg_fallback(sock: &S, msg: &mut BatchSendMessage<'_>) -> io::Result<()> { + let mut hdr: libc::msghdr = unsafe { mem::zeroed() }; + + let sock_addr = msg.addr.map(SockAddr::from); + if let Some(ref sa) = sock_addr { + hdr.msg_name = sa.as_ptr() as *mut _; + hdr.msg_namelen = sa.len() as _; + } + + hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_iovlen = msg.data.len() as _; + + let ret = unsafe { libc::sendmsg(sock.as_raw_fd(), &hdr as *const _, 0) }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + msg.data_len = ret as usize; + + Ok(()) +} + +pub fn batch_sendmsg(sock: &S, msgs: &mut [BatchSendMessage<'_>]) -> io::Result { + if msgs.is_empty() { + return Ok(0); + } + + if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Acquire) { + sendmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + + let mut vec_msg_name = Vec::with_capacity(msgs.len()); + let mut vec_msg_hdr = Vec::with_capacity(msgs.len()); + + for msg in msgs.iter_mut() { + let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() }; + + if let Some(addr) = msg.addr { + vec_msg_name.push(SockAddr::from(addr)); + let sock_addr = vec_msg_name.last_mut().unwrap(); + hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _; + hdr.msg_hdr.msg_namelen = sock_addr.len() as _; + } + + hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _; + hdr.msg_hdr.msg_iovlen = msg.data.len() as _; + + vec_msg_hdr.push(hdr); + } + + let ret = unsafe { libc::sendmmsg(sock.as_raw_fd(), vec_msg_hdr.as_mut_ptr(), vec_msg_hdr.len() as _, 0) }; + if ret < 0 { + let err = io::Error::last_os_error(); + if let Some(libc::ENOSYS) = err.raw_os_error() { + debug!("sendmmsg is not supported, fallback to sendmsg, error: {:?}", err); + SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Release); + + sendmsg_fallback(sock, &mut msgs[0])?; + return Ok(1); + } + return Err(err); + } + + for idx in 0..ret as usize { + let msg = &mut msgs[idx]; + let hdr = &vec_msg_hdr[idx]; + msg.data_len = hdr.msg_len as usize; + } + + Ok(ret as usize) +} diff --git a/crates/shadowsocks/src/net/udp.rs b/crates/shadowsocks/src/net/udp.rs index 195b81c13f28..c49065ce64a0 100644 --- a/crates/shadowsocks/src/net/udp.rs +++ b/crates/shadowsocks/src/net/udp.rs @@ -1,12 +1,15 @@ //! UDP socket wrappers use std::{ - io, + io::{self, ErrorKind, IoSlice, IoSliceMut}, net::SocketAddr, ops::{Deref, DerefMut}, + task::{Context as TaskContext, Poll}, }; +use futures::{future, ready}; use pin_project::pin_project; +use tokio::io::Interest; use crate::{context::Context, relay::socks5::Address, ServerAddr}; @@ -17,6 +20,32 @@ use super::{ ConnectOpts, }; +/// Message struct for `batch_send` +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "macos", + target_os = "freebsd" +))] +pub struct BatchSendMessage<'a> { + pub addr: Option, + pub data: &'a [IoSlice<'a>], + pub data_len: usize, +} + +/// Message struct for `batch_recv` +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "macos", + target_os = "freebsd" +))] +pub struct BatchRecvMessage<'a> { + pub addr: SocketAddr, + pub data: &'a mut [IoSliceMut<'a>], + pub data_len: usize, +} + /// Wrappers for outbound `UdpSocket` #[pin_project] pub struct UdpSocket(#[pin] tokio::net::UdpSocket); @@ -93,6 +122,78 @@ impl UdpSocket { pub async fn connect_any_with_opts>(af: AF, opts: &ConnectOpts) -> io::Result { create_outbound_udp_socket(af.into(), opts).await.map(UdpSocket) } + + /// Batch send packets + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "macos", + target_os = "freebsd" + ))] + pub fn poll_batch_send( + &self, + cx: &mut TaskContext<'_>, + msgs: &mut [BatchSendMessage<'_>], + ) -> Poll> { + use super::sys::batch_sendmsg; + + loop { + ready!(self.0.poll_send_ready(cx))?; + + match self.0.try_io(Interest::WRITABLE, || batch_sendmsg(&self.0, msgs)) { + Ok(n) => return Ok(n).into(), + Err(ref err) if err.kind() == ErrorKind::WouldBlock => {} + Err(err) => return Err(err).into(), + } + } + } + + /// Batch send packets + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "macos", + target_os = "freebsd" + ))] + pub async fn batch_send(&self, msgs: &mut [BatchSendMessage<'_>]) -> io::Result { + future::poll_fn(|cx| self.poll_batch_send(cx, msgs)).await + } + + /// Batch recv packets + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "macos", + target_os = "freebsd" + ))] + pub fn poll_batch_recv( + &self, + cx: &mut TaskContext<'_>, + msgs: &mut [BatchRecvMessage<'_>], + ) -> Poll> { + use super::sys::batch_recvmsg; + + loop { + ready!(self.0.poll_recv_ready(cx))?; + + match self.0.try_io(Interest::READABLE, || batch_recvmsg(&self.0, msgs)) { + Ok(n) => return Ok(n).into(), + Err(ref err) if err.kind() == ErrorKind::WouldBlock => {} + Err(err) => return Err(err).into(), + } + } + } + + /// Batch recv packets + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "macos", + target_os = "freebsd" + ))] + pub async fn batch_recv(&self, msgs: &mut [BatchRecvMessage<'_>]) -> io::Result { + future::poll_fn(|cx| self.poll_batch_recv(cx, msgs)).await + } } impl Deref for UdpSocket { diff --git a/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs b/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs index c155e8d666b4..114b3077d2b0 100644 --- a/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs +++ b/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs @@ -5,10 +5,7 @@ use std::{io, net::SocketAddr, time::Duration}; use bytes::BytesMut; use log::{trace, warn}; use once_cell::sync::Lazy; -use tokio::{ - net::{ToSocketAddrs, UdpSocket}, - time, -}; +use tokio::{net::ToSocketAddrs, time}; use crate::{ config::{ServerAddr, ServerConfig}, @@ -40,7 +37,7 @@ pub enum UdpSocketType { /// UDP client for communicating with ShadowSocks' server pub struct ProxySocket { socket_type: UdpSocketType, - socket: UdpSocket, + socket: ShadowUdpSocket, method: CipherKind, key: Box<[u8]>, send_timeout: Option, @@ -70,17 +67,20 @@ impl ProxySocket { UdpSocketType::Client, context, svr_cfg, - socket.into(), + socket, )) } /// Create a `ProxySocket` from a `UdpSocket` - pub fn from_socket( + pub fn from_socket( socket_type: UdpSocketType, context: SharedContext, svr_cfg: &ServerConfig, - socket: UdpSocket, - ) -> ProxySocket { + socket: S, + ) -> ProxySocket + where + S: Into, + { let key = svr_cfg.key().to_vec().into_boxed_slice(); let method = svr_cfg.method(); @@ -88,7 +88,7 @@ impl ProxySocket { ProxySocket { socket_type, - socket, + socket: socket.into(), method, key, send_timeout: None, @@ -122,10 +122,27 @@ impl ProxySocket { UdpSocketType::Server, context, svr_cfg, - socket.into(), + socket, )) } + fn encrypt_send_buffer( + &self, + addr: &Address, + control: &UdpSocketControlData, + payload: &[u8], + send_buf: &mut BytesMut, + ) { + match self.socket_type { + UdpSocketType::Client => { + encrypt_client_payload(&self.context, self.method, &self.key, addr, control, payload, send_buf) + } + UdpSocketType::Server => { + encrypt_server_payload(&self.context, self.method, &self.key, addr, control, payload, send_buf) + } + } + } + /// Send a UDP packet to addr through proxy #[inline] pub async fn send(&self, addr: &Address, payload: &[u8]) -> io::Result { @@ -140,27 +157,7 @@ impl ProxySocket { payload: &[u8], ) -> io::Result { let mut send_buf = BytesMut::new(); - - match self.socket_type { - UdpSocketType::Client => encrypt_client_payload( - &self.context, - self.method, - &self.key, - addr, - control, - payload, - &mut send_buf, - ), - UdpSocketType::Server => encrypt_server_payload( - &self.context, - self.method, - &self.key, - addr, - control, - payload, - &mut send_buf, - ), - } + self.encrypt_send_buffer(addr, control, payload, &mut send_buf); trace!( "UDP server client send to {}, control: {:?}, payload length {} bytes, packet length {} bytes", @@ -205,27 +202,7 @@ impl ProxySocket { payload: &[u8], ) -> io::Result { let mut send_buf = BytesMut::new(); - - match self.socket_type { - UdpSocketType::Client => encrypt_client_payload( - &self.context, - self.method, - &self.key, - addr, - control, - payload, - &mut send_buf, - ), - UdpSocketType::Server => encrypt_server_payload( - &self.context, - self.method, - &self.key, - addr, - control, - payload, - &mut send_buf, - ), - } + self.encrypt_send_buffer(addr, control, payload, &mut send_buf); trace!( "UDP server client send to, addr {}, control: {:?}, payload length {} bytes, packet length {} bytes", @@ -255,6 +232,16 @@ impl ProxySocket { Ok(send_len) } + async fn decrypt_recv_buffer( + &self, + recv_buf: &mut [u8], + ) -> io::Result<(usize, Address, Option)> { + match self.socket_type { + UdpSocketType::Client => decrypt_server_payload(&self.context, self.method, &self.key, recv_buf).await, + UdpSocketType::Server => decrypt_client_payload(&self.context, self.method, &self.key, recv_buf).await, + } + } + /// Receive packet from Shadowsocks' UDP server /// /// This function will use `recv_buf` to store intermediate data, so it has to be big enough to store the whole shadowsocks' packet @@ -283,14 +270,7 @@ impl ProxySocket { }, }; - let (n, addr, control) = match self.socket_type { - UdpSocketType::Client => { - decrypt_server_payload(&self.context, self.method, &self.key, &mut recv_buf[..recv_n]).await? - } - UdpSocketType::Server => { - decrypt_client_payload(&self.context, self.method, &self.key, &mut recv_buf[..recv_n]).await? - } - }; + let (n, addr, control) = self.decrypt_recv_buffer(&mut recv_buf[..recv_n]).await?; trace!( "UDP server client receive from {}, control: {:?}, packet length {} bytes, payload length {} bytes", @@ -333,14 +313,7 @@ impl ProxySocket { }, }; - let (n, addr, control) = match self.socket_type { - UdpSocketType::Client => { - decrypt_server_payload(&self.context, self.method, &self.key, &mut recv_buf[..recv_n]).await? - } - UdpSocketType::Server => { - decrypt_client_payload(&self.context, self.method, &self.key, &mut recv_buf[..recv_n]).await? - } - }; + let (n, addr, control) = self.decrypt_recv_buffer(&mut recv_buf[..recv_n]).await?; trace!( "UDP server client receive from {}, addr {}, control: {:?}, packet length {} bytes, payload length {} bytes",