diff --git a/crates/shadowsocks/src/relay/socks5.rs b/crates/shadowsocks/src/relay/socks5.rs index da4beb9f30f9..032ba1873cd2 100644 --- a/crates/shadowsocks/src/relay/socks5.rs +++ b/crates/shadowsocks/src/relay/socks5.rs @@ -12,7 +12,7 @@ use std::{ vec, }; -use bytes::{BufMut, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub use self::consts::{ @@ -213,6 +213,47 @@ pub enum Address { } impl Address { + /// read from a cursor + pub fn read_cursor>(cur: &mut io::Cursor) -> Result { + if cur.remaining() < 2 { + return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into()); + } + + let atyp = cur.get_u8(); + match atyp { + consts::SOCKS5_ADDR_TYPE_IPV4 => { + if cur.remaining() < 4 + 2 { + return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into()); + } + let addr = Ipv4Addr::from(cur.get_u32()); + let port = cur.get_u16(); + Ok(Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(addr, port)))) + } + consts::SOCKS5_ADDR_TYPE_IPV6 => { + if cur.remaining() < 16 + 2 { + return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into()); + } + let addr = Ipv6Addr::from(cur.get_u128()); + let port = cur.get_u16(); + Ok(Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new( + addr, port, 0, 0, + )))) + } + consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => { + let domain_len = cur.get_u8() as usize; + if cur.remaining() < domain_len { + return Err(Error::AddressDomainInvalidEncoding); + } + let mut buf = vec![0u8; domain_len]; + cur.copy_to_slice(&mut buf); + let port = cur.get_u16(); + let addr = String::from_utf8(buf).map_err(|_| Error::AddressDomainInvalidEncoding)?; + Ok(Address::DomainNameAddress(addr, port)) + } + _ => Err(Error::AddressTypeNotSupported(atyp)), + } + } + /// Parse from a `AsyncRead` pub async fn read_from(stream: &mut R) -> Result where diff --git a/crates/shadowsocks/src/relay/udprelay/aead.rs b/crates/shadowsocks/src/relay/udprelay/aead.rs index 24e7e61af0be..4455b4471b58 100644 --- a/crates/shadowsocks/src/relay/udprelay/aead.rs +++ b/crates/shadowsocks/src/relay/udprelay/aead.rs @@ -77,7 +77,7 @@ pub fn encrypt_payload_aead( } /// Decrypt UDP AEAD protocol packet -pub async fn decrypt_payload_aead( +pub fn decrypt_payload_aead( _context: &Context, method: CipherKind, key: &[u8], @@ -109,7 +109,7 @@ pub async fn decrypt_payload_aead( let data_len = data.len() - tag_len; let data = &mut data[..data_len]; - let (dn, addr) = parse_packet(data).await?; + let (dn, addr) = parse_packet(data)?; let data_length = data_len - dn; let data_start_idx = salt_len + dn; @@ -120,9 +120,10 @@ pub async fn decrypt_payload_aead( Ok((data_length, addr)) } -async fn parse_packet(buf: &[u8]) -> ProtocolResult<(usize, Address)> { +#[inline] +fn parse_packet(buf: &[u8]) -> ProtocolResult<(usize, Address)> { let mut cur = Cursor::new(buf); - match Address::read_from(&mut cur).await { + match Address::read_cursor(&mut cur) { Ok(address) => { let pos = cur.position() as usize; Ok((pos, address)) diff --git a/crates/shadowsocks/src/relay/udprelay/aead_2022.rs b/crates/shadowsocks/src/relay/udprelay/aead_2022.rs index 26b0c5f738c3..4cf967fbeb87 100644 --- a/crates/shadowsocks/src/relay/udprelay/aead_2022.rs +++ b/crates/shadowsocks/src/relay/udprelay/aead_2022.rs @@ -529,7 +529,7 @@ pub fn encrypt_client_payload_aead_2022( } /// Decrypt `Client -> Server` UDP AEAD protocol packet -pub async fn decrypt_client_payload_aead_2022( +pub fn decrypt_client_payload_aead_2022( context: &Context, method: CipherKind, key: &[u8], @@ -581,7 +581,7 @@ pub async fn decrypt_client_payload_aead_2022( user, }; - let addr = match Address::read_from(&mut cursor).await { + let addr = match Address::read_cursor(&mut cursor) { Ok(a) => a, Err(err) => return Err(ProtocolError::InvalidAddress(err)), }; @@ -641,7 +641,7 @@ pub fn encrypt_server_payload_aead_2022( } /// Decrypt `Server -> Client` UDP AEAD protocol packet -pub async fn decrypt_server_payload_aead_2022( +pub fn decrypt_server_payload_aead_2022( context: &Context, method: CipherKind, key: &[u8], @@ -687,7 +687,7 @@ pub async fn decrypt_server_payload_aead_2022( user: None, }; - let addr = match Address::read_from(&mut cursor).await { + let addr = match Address::read_cursor(&mut cursor) { Ok(a) => a, Err(err) => return Err(ProtocolError::InvalidAddress(err)), }; diff --git a/crates/shadowsocks/src/relay/udprelay/crypto_io.rs b/crates/shadowsocks/src/relay/udprelay/crypto_io.rs index 87d43a1aa204..e0c082565c31 100644 --- a/crates/shadowsocks/src/relay/udprelay/crypto_io.rs +++ b/crates/shadowsocks/src/relay/udprelay/crypto_io.rs @@ -132,7 +132,7 @@ pub fn encrypt_server_payload( } /// Decrypt `Client -> Server` payload from ShadowSocks UDP encrypted packet -pub async fn decrypt_client_payload( +pub fn decrypt_client_payload( context: &Context, method: CipherKind, key: &[u8], @@ -143,7 +143,7 @@ pub async fn decrypt_client_payload( CipherCategory::None => { let _ = user_manager; let mut cur = Cursor::new(payload); - match Address::read_from(&mut cur).await { + match Address::read_cursor(&mut cur) { Ok(address) => { let pos = cur.position() as usize; let payload = cur.into_inner(); @@ -157,27 +157,24 @@ pub async fn decrypt_client_payload( CipherCategory::Stream => { let _ = user_manager; decrypt_payload_stream(context, method, key, payload) - .await .map(|(n, a)| (n, a, None)) .map_err(Into::into) } CipherCategory::Aead => { let _ = user_manager; decrypt_payload_aead(context, method, key, payload) - .await .map(|(n, a)| (n, a, None)) .map_err(Into::into) } #[cfg(feature = "aead-cipher-2022")] CipherCategory::Aead2022 => decrypt_client_payload_aead_2022(context, method, key, payload, user_manager) - .await .map(|(n, a, c)| (n, a, Some(c))) .map_err(Into::into), } } /// Decrypt `Server -> Client` payload from ShadowSocks UDP encrypted packet -pub async fn decrypt_server_payload( +pub fn decrypt_server_payload( context: &Context, method: CipherKind, key: &[u8], @@ -186,7 +183,7 @@ pub async fn decrypt_server_payload( match method.category() { CipherCategory::None => { let mut cur = Cursor::new(payload); - match Address::read_from(&mut cur).await { + match Address::read_cursor(&mut cur) { Ok(address) => { let pos = cur.position() as usize; let payload = cur.into_inner(); @@ -198,16 +195,13 @@ pub async fn decrypt_server_payload( } #[cfg(feature = "stream-cipher")] CipherCategory::Stream => decrypt_payload_stream(context, method, key, payload) - .await .map(|(n, a)| (n, a, None)) .map_err(Into::into), CipherCategory::Aead => decrypt_payload_aead(context, method, key, payload) - .await .map(|(n, a)| (n, a, None)) .map_err(Into::into), #[cfg(feature = "aead-cipher-2022")] CipherCategory::Aead2022 => decrypt_server_payload_aead_2022(context, method, key, payload) - .await .map(|(n, a, c)| (n, a, Some(c))) .map_err(Into::into), } diff --git a/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs b/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs index 2de221353e49..fb2a9700221f 100644 --- a/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs +++ b/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs @@ -4,14 +4,15 @@ use std::{ io::{self, ErrorKind}, net::SocketAddr, sync::Arc, + task::{ready, Context, Poll}, time::Duration, }; use byte_string::ByteStr; use bytes::{Bytes, BytesMut}; -use log::{trace, warn}; +use log::{info, trace, warn}; use once_cell::sync::Lazy; -use tokio::{net::ToSocketAddrs, time}; +use tokio::{io::ReadBuf, net::ToSocketAddrs, time}; use crate::{ config::{ServerAddr, ServerConfig, ServerUserManager}, @@ -254,6 +255,95 @@ impl ProxySocket { Ok(send_len) } + /// poll family functions + /// the send_timeout is ignored. + pub fn poll_send(&self, addr: &Address, payload: &[u8], cx: &mut Context<'_>) -> Poll> { + self.poll_send_with_ctrl(addr, &DEFAULT_SOCKET_CONTROL, payload, cx) + } + + pub fn poll_send_with_ctrl( + &self, + addr: &Address, + control: &UdpSocketControlData, + payload: &[u8], + cx: &mut Context<'_>, + ) -> Poll> { + let mut send_buf = BytesMut::with_capacity(payload.len() + 256); + + self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?; + + trace!( + "UDP server client send to {}, control: {:?}, payload length {} bytes, packet length {} bytes", + addr, + control, + payload.len(), + send_buf.len() + ); + + let n_send_buf = send_buf.len(); + + match self.socket.poll_send(cx, &mut send_buf.freeze()).map_err(|x| x.into()) { + Poll::Ready(Ok(l)) => { + if l == n_send_buf { + Poll::Ready(Ok(payload.len())) + } else { + Poll::Ready(Err(io::Error::from(ErrorKind::WriteZero).into())) + } + } + x => x, + } + } + + pub fn poll_send_to( + &self, + target: SocketAddr, + addr: &Address, + payload: &[u8], + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_send_to_with_ctrl(target, addr, &DEFAULT_SOCKET_CONTROL, payload, cx) + } + + pub fn poll_send_to_with_ctrl( + &self, + target: SocketAddr, + addr: &Address, + control: &UdpSocketControlData, + payload: &[u8], + cx: &mut Context<'_>, + ) -> Poll> { + let mut send_buf = BytesMut::with_capacity(payload.len() + 256); + + self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?; + + info!( + "UDP server client send to {}, payload length {} bytes, packet length {} bytes", + target, + payload.len(), + send_buf.len() + ); + + let n_send_buf = send_buf.len(); + match self + .socket + .poll_send_to(cx, &mut send_buf.freeze(), target) + .map_err(|x| x.into()) + { + Poll::Ready(Ok(l)) => { + if l == n_send_buf { + Poll::Ready(Ok(payload.len())) + } else { + Poll::Ready(Err(io::Error::from(ErrorKind::WriteZero).into())) + } + } + x => x, + } + } + + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.socket.poll_send_ready(cx).map_err(|x| x.into()) + } + /// Send a UDP packet to target from proxy pub async fn send_to( &self, @@ -305,15 +395,15 @@ impl ProxySocket { Ok(send_len) } - async fn decrypt_recv_buffer( + fn decrypt_recv_buffer( &self, recv_buf: &mut [u8], user_manager: Option<&ServerUserManager>, ) -> ProtocolResult<(usize, Address, Option)> { match self.socket_type { - UdpSocketType::Client => decrypt_server_payload(&self.context, self.method, &self.key, recv_buf).await, + UdpSocketType::Client => decrypt_server_payload(&self.context, self.method, &self.key, recv_buf), UdpSocketType::Server => { - decrypt_client_payload(&self.context, self.method, &self.key, recv_buf, user_manager).await + decrypt_client_payload(&self.context, self.method, &self.key, recv_buf, user_manager) } } } @@ -346,10 +436,7 @@ impl ProxySocket { }, }; - let (n, addr, control) = match self - .decrypt_recv_buffer(&mut recv_buf[..recv_n], self.user_manager.as_deref()) - .await - { + let (n, addr, control) = match self.decrypt_recv_buffer(&mut recv_buf[..recv_n], self.user_manager.as_deref()) { Ok(x) => x, Err(err) => return Err(ProxySocketError::ProtocolError(err)), }; @@ -395,10 +482,7 @@ impl ProxySocket { }, }; - let (n, addr, control) = match self - .decrypt_recv_buffer(&mut recv_buf[..recv_n], self.user_manager.as_deref()) - .await - { + let (n, addr, control) = match self.decrypt_recv_buffer(&mut recv_buf[..recv_n], self.user_manager.as_deref()) { Ok(x) => x, Err(err) => return Err(ProxySocketError::ProtocolErrorWithPeer(target_addr, err)), }; @@ -415,6 +499,60 @@ impl ProxySocket { Ok((n, target_addr, addr, recv_n, control)) } + /// poll family functions. + /// the recv_timeout is ignored. + pub fn poll_recv( + &self, + cx: &mut Context<'_>, + recv_buf: &mut ReadBuf, + ) -> Poll> { + self.poll_recv_with_ctrl(cx, recv_buf) + .map(|r| r.map(|(n, a, rn, _)| (n, a, rn))) + } + + /// poll family functions + pub fn poll_recv_with_ctrl( + &self, + cx: &mut Context<'_>, + recv_buf: &mut ReadBuf, + ) -> Poll)>> { + ready!(self.socket.poll_recv(cx, recv_buf))?; + + let n_recv = recv_buf.filled().len(); + + match self.decrypt_recv_buffer(recv_buf.filled_mut(), self.user_manager.as_deref()) { + Ok(x) => Poll::Ready(Ok((x.0, x.1, n_recv, x.2))), + Err(err) => return Poll::Ready(Err(ProxySocketError::ProtocolError(err))), + } + } + + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + recv_buf: &mut ReadBuf, + ) -> Poll> { + self.poll_recv_from_with_ctrl(cx, recv_buf) + .map(|r| r.map(|(n, sa, a, rn, _)| (n, sa, a, rn))) + } + + pub fn poll_recv_from_with_ctrl( + &self, + cx: &mut Context<'_>, + recv_buf: &mut ReadBuf, + ) -> Poll)>> { + let src = ready!(self.socket.poll_recv_from(cx, recv_buf))?; + + let n_recv = recv_buf.filled().len(); + match self.decrypt_recv_buffer(recv_buf.filled_mut(), self.user_manager.as_deref()) { + Ok(x) => Poll::Ready(Ok((x.0, src, x.1, n_recv, x.2))), + Err(err) => return Poll::Ready(Err(ProxySocketError::ProtocolError(err))), + } + } + + pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.socket.poll_recv_ready(cx).map_err(|x| x.into()) + } + /// Get local addr of socket pub fn local_addr(&self) -> io::Result { self.socket.local_addr() diff --git a/crates/shadowsocks/src/relay/udprelay/stream.rs b/crates/shadowsocks/src/relay/udprelay/stream.rs index ef8730b66e0f..d6706c7954b4 100644 --- a/crates/shadowsocks/src/relay/udprelay/stream.rs +++ b/crates/shadowsocks/src/relay/udprelay/stream.rs @@ -66,7 +66,7 @@ pub fn encrypt_payload_stream( } /// Decrypt UDP stream protocol packet -pub async fn decrypt_payload_stream( +pub fn decrypt_payload_stream( _context: &Context, method: CipherKind, key: &[u8], @@ -87,7 +87,7 @@ pub async fn decrypt_payload_stream( assert!(cipher.decrypt_packet(data)); - let (dn, addr) = parse_packet(data).await?; + let (dn, addr) = parse_packet(data)?; let data_start_idx = iv_len + dn; let data_length = payload.len() - data_start_idx; @@ -96,9 +96,10 @@ pub async fn decrypt_payload_stream( Ok((data_length, addr)) } -async fn parse_packet(buf: &[u8]) -> ProtocolResult<(usize, Address)> { +#[inline] +fn parse_packet(buf: &[u8]) -> ProtocolResult<(usize, Address)> { let mut cur = Cursor::new(buf); - match Address::read_from(&mut cur).await { + match Address::read_cursor(&mut cur) { Ok(address) => { let pos = cur.position() as usize; Ok((pos, address))