diff --git a/neqo-client/Cargo.toml b/neqo-client/Cargo.toml index 08fe2f8fcd..e9b48221aa 100644 --- a/neqo-client/Cargo.toml +++ b/neqo-client/Cargo.toml @@ -15,7 +15,7 @@ clap = { version = "4.4", features = ["derive"] } futures = "0.3" hex = "0.4" log = { version = "0.4", default-features = false } -neqo-common = { path = "./../neqo-common" } +neqo-common = { path = "./../neqo-common", features = ["udp"] } neqo-crypto = { path = "./../neqo-crypto" } neqo-http3 = { path = "./../neqo-http3" } neqo-qpack = { path = "./../neqo-qpack" } diff --git a/neqo-client/src/main.rs b/neqo-client/src/main.rs index 3aede9e545..7584697a47 100644 --- a/neqo-client/src/main.rs +++ b/neqo-client/src/main.rs @@ -22,13 +22,12 @@ use std::{ }; use clap::Parser; -use common::IpTos; use futures::{ future::{select, Either}, FutureExt, TryFutureExt, }; use neqo_common::{ - self as common, event::Provider, hex, qdebug, qinfo, qlog::NeqoQlog, Datagram, Role, + self as common, event::Provider, hex, qdebug, qinfo, qlog::NeqoQlog, udp, Datagram, Role, }; use neqo_crypto::{ constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256}, @@ -42,7 +41,7 @@ use neqo_transport::{ EmptyConnectionIdGenerator, Error as TransportError, StreamId, StreamType, Version, }; use qlog::{events::EventImportance, streamer::QlogStreamer}; -use tokio::{net::UdpSocket, time::Sleep}; +use tokio::time::Sleep; use url::{Origin, Url}; #[derive(Debug)] @@ -351,21 +350,6 @@ impl QuicParameters { } } -async fn emit_datagram(socket: &UdpSocket, out_dgram: Datagram) -> Result<(), io::Error> { - let sent = match socket.send_to(&out_dgram, &out_dgram.destination()).await { - Ok(res) => res, - Err(ref err) if err.kind() != io::ErrorKind::WouldBlock => { - eprintln!("UDP send error: {err:?}"); - 0 - } - Err(e) => return Err(e), - }; - if sent != out_dgram.len() { - eprintln!("Unable to send all {} bytes of datagram", out_dgram.len()); - } - Ok(()) -} - fn get_output_file( url: &Url, output_dir: &Option, @@ -415,7 +399,7 @@ enum Ready { // Wait for the socket to be readable or the timeout to fire. async fn ready( - socket: &UdpSocket, + socket: &udp::Socket, mut timeout: Option<&mut Pin>>, ) -> Result { let socket_ready = Box::pin(socket.readable()).map_ok(|()| Ready::Socket); @@ -426,43 +410,6 @@ async fn ready( select(socket_ready, timeout_ready).await.factor_first().0 } -fn read_dgram( - socket: &UdpSocket, - local_address: &SocketAddr, -) -> Result, io::Error> { - let buf = &mut [0u8; 2048]; - let (sz, remote_addr) = match socket.try_recv_from(&mut buf[..]) { - Err(ref err) - if err.kind() == io::ErrorKind::WouldBlock - || err.kind() == io::ErrorKind::Interrupted => - { - return Ok(None) - } - Err(err) => { - eprintln!("UDP recv error: {err:?}"); - return Err(err); - } - Ok(res) => res, - }; - - if sz == buf.len() { - eprintln!("Might have received more than {} bytes", buf.len()); - } - - if sz == 0 { - eprintln!("zero length datagram received?"); - Ok(None) - } else { - Ok(Some(Datagram::new( - remote_addr, - *local_address, - IpTos::default(), - None, - &buf[..sz], - ))) - } -} - trait StreamHandler { fn process_header_ready(&mut self, stream_id: StreamId, fin: bool, headers: Vec
); fn process_data_readable( @@ -817,7 +764,7 @@ fn to_headers(values: &[impl AsRef]) -> Vec
{ struct ClientRunner<'a> { local_addr: SocketAddr, - socket: &'a UdpSocket, + socket: &'a udp::Socket, client: Http3Client, handler: Handler<'a>, timeout: Option>>, @@ -827,7 +774,7 @@ struct ClientRunner<'a> { impl<'a> ClientRunner<'a> { fn new( args: &'a mut Args, - socket: &'a UdpSocket, + socket: &'a udp::Socket, local_addr: SocketAddr, remote_addr: SocketAddr, hostname: &str, @@ -880,7 +827,7 @@ impl<'a> ClientRunner<'a> { match ready(self.socket, self.timeout.as_mut()).await? { Ready::Socket => loop { - let dgram = read_dgram(self.socket, &self.local_addr)?; + let dgram = self.socket.recv(&self.local_addr)?; if dgram.is_none() { break; } @@ -915,7 +862,8 @@ impl<'a> ClientRunner<'a> { loop { match self.client.process(dgram.take(), Instant::now()) { Output::Datagram(dgram) => { - emit_datagram(self.socket, dgram).await?; + self.socket.writable().await?; + self.socket.send(dgram)?; } Output::Callback(new_timeout) => { qinfo!("Setting timeout of {:?}", new_timeout); @@ -1051,16 +999,7 @@ async fn main() -> Res<()> { SocketAddr::V6(..) => SocketAddr::new(IpAddr::V6(Ipv6Addr::from([0; 16])), 0), }; - let socket = match std::net::UdpSocket::bind(local_addr) { - Err(e) => { - eprintln!("Unable to bind UDP socket: {e}"); - exit(1) - } - Ok(s) => s, - }; - socket.set_nonblocking(true)?; - let socket = UdpSocket::from_std(socket)?; - + let socket = udp::Socket::bind(local_addr)?; let real_local = socket.local_addr().unwrap(); println!( "{} Client connecting: {:?} -> {:?}", @@ -1125,17 +1064,16 @@ mod old { time::Instant, }; - use neqo_common::{event::Provider, qdebug, qinfo, Datagram}; + use neqo_common::{event::Provider, qdebug, qinfo, udp, Datagram}; use neqo_crypto::{AuthenticationStatus, ResumptionToken}; use neqo_transport::{ Connection, ConnectionEvent, EmptyConnectionIdGenerator, Error, Output, State, StreamId, StreamType, }; - use tokio::{net::UdpSocket, time::Sleep}; + use tokio::time::Sleep; use url::Url; - use super::{get_output_file, qlog_new, read_dgram, ready, Args, KeyUpdateState, Ready, Res}; - use crate::emit_datagram; + use super::{get_output_file, qlog_new, ready, Args, KeyUpdateState, Ready, Res}; struct HandlerOld<'b> { streams: HashMap>, @@ -1321,7 +1259,7 @@ mod old { pub struct ClientRunner<'a> { local_addr: SocketAddr, - socket: &'a UdpSocket, + socket: &'a udp::Socket, client: Connection, handler: HandlerOld<'a>, timeout: Option>>, @@ -1331,7 +1269,7 @@ mod old { impl<'a> ClientRunner<'a> { pub fn new( args: &'a Args, - socket: &'a UdpSocket, + socket: &'a udp::Socket, local_addr: SocketAddr, remote_addr: SocketAddr, origin: &str, @@ -1394,7 +1332,7 @@ mod old { match ready(self.socket, self.timeout.as_mut()).await? { Ready::Socket => loop { - let dgram = read_dgram(self.socket, &self.local_addr)?; + let dgram = self.socket.recv(&self.local_addr)?; if dgram.is_none() { break; } @@ -1430,7 +1368,8 @@ mod old { loop { match self.client.process(dgram.take(), Instant::now()) { Output::Datagram(dgram) => { - emit_datagram(self.socket, dgram).await?; + self.socket.writable().await?; + self.socket.send(dgram)?; } Output::Callback(new_timeout) => { qinfo!("Setting timeout of {:?}", new_timeout); diff --git a/neqo-common/Cargo.toml b/neqo-common/Cargo.toml index 25122e9b87..7017ff9600 100644 --- a/neqo-common/Cargo.toml +++ b/neqo-common/Cargo.toml @@ -15,13 +15,16 @@ enum-map = "2.7" env_logger = { version = "0.10", default-features = false } log = { version = "0.4", default-features = false } qlog = "0.12" +quinn-udp = { git = "https://github.com/quinn-rs/quinn/", rev = "a947962131aba8a6521253d03cc948b20098a2d6", optional = true } time = { version = "0.3", features = ["formatting"] } +tokio = { version = "1", features = ["net", "time", "macros", "rt", "rt-multi-thread"], optional = true } [dev-dependencies] test-fixture = { path = "../test-fixture" } [features] ci = [] +udp = ["dep:quinn-udp", "dep:tokio"] [target."cfg(windows)".dependencies.winapi] version = "0.3" diff --git a/neqo-common/src/datagram.rs b/neqo-common/src/datagram.rs index 1729c8ed8d..d6ed43bde1 100644 --- a/neqo-common/src/datagram.rs +++ b/neqo-common/src/datagram.rs @@ -53,6 +53,11 @@ impl Datagram { pub fn ttl(&self) -> Option { self.ttl } + + #[must_use] + pub(crate) fn into_data(self) -> Vec { + self.d + } } impl Deref for Datagram { diff --git a/neqo-common/src/lib.rs b/neqo-common/src/lib.rs index ee97408a41..7b7bf6a163 100644 --- a/neqo-common/src/lib.rs +++ b/neqo-common/src/lib.rs @@ -17,6 +17,8 @@ pub mod log; pub mod qlog; pub mod timer; pub mod tos; +#[cfg(feature = "udp")] +pub mod udp; use std::fmt::Write; diff --git a/neqo-common/src/tos.rs b/neqo-common/src/tos.rs index 80e073a1e4..3610f72750 100644 --- a/neqo-common/src/tos.rs +++ b/neqo-common/src/tos.rs @@ -46,6 +46,12 @@ impl From for IpTosEcn { } } +impl From for IpTosEcn { + fn from(value: IpTos) -> Self { + IpTosEcn::from(value.0 & 0x3) + } +} + /// Diffserv Codepoints, mapped to the upper six bits of the TOS field. /// #[derive(Copy, Clone, PartialEq, Eq, Enum, Default, Debug)] @@ -159,6 +165,12 @@ impl From for IpTosDscp { } } +impl From for IpTosDscp { + fn from(value: IpTos) -> Self { + IpTosDscp::from(value.0 & 0xfc) + } +} + /// The type-of-service field in an IP packet. #[allow(clippy::module_name_repetitions)] #[derive(Copy, Clone, PartialEq, Eq)] @@ -169,22 +181,37 @@ impl From for IpTos { Self(u8::from(v)) } } + impl From for IpTos { fn from(v: IpTosDscp) -> Self { Self(u8::from(v)) } } + impl From<(IpTosDscp, IpTosEcn)> for IpTos { fn from(v: (IpTosDscp, IpTosEcn)) -> Self { Self(u8::from(v.0) | u8::from(v.1)) } } + +impl From<(IpTosEcn, IpTosDscp)> for IpTos { + fn from(v: (IpTosEcn, IpTosDscp)) -> Self { + Self(u8::from(v.0) | u8::from(v.1)) + } +} + impl From for u8 { fn from(v: IpTos) -> Self { v.0 } } +impl From for IpTos { + fn from(v: u8) -> Self { + Self(v) + } +} + impl Debug for IpTos { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("IpTos") @@ -287,4 +314,12 @@ mod tests { let iptos_dscp: IpTos = dscp.into(); assert_eq!(u8::from(iptos_dscp), dscp as u8); } + + #[test] + fn u8_to_iptos() { + let tos = 0x8b; + let iptos: IpTos = (IpTosEcn::Ce, IpTosDscp::Af41).into(); + assert_eq!(tos, u8::from(iptos)); + assert_eq!(IpTos::from(tos), iptos); + } } diff --git a/neqo-common/src/udp.rs b/neqo-common/src/udp.rs new file mode 100644 index 0000000000..7ad0b97625 --- /dev/null +++ b/neqo-common/src/udp.rs @@ -0,0 +1,154 @@ +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(clippy::missing_errors_doc)] // Functions simply delegate to tokio and quinn-udp. +#![allow(clippy::missing_panics_doc)] // Functions simply delegate to tokio and quinn-udp. + +use std::{ + io::{self, IoSliceMut}, + net::{SocketAddr, ToSocketAddrs}, + slice, +}; + +use quinn_udp::{EcnCodepoint, RecvMeta, Transmit, UdpSocketState}; +use tokio::io::Interest; + +use crate::{Datagram, IpTos}; + +pub struct Socket { + socket: tokio::net::UdpSocket, + state: UdpSocketState, +} + +impl Socket { + /// Calls [`std::net::UdpSocket::bind`] and instantiates [`quinn_udp::UdpSocketState`]. + pub fn bind(addr: A) -> Result { + let socket = std::net::UdpSocket::bind(addr)?; + + Ok(Self { + state: quinn_udp::UdpSocketState::new((&socket).into())?, + socket: tokio::net::UdpSocket::from_std(socket)?, + }) + } + + /// See [`tokio::net::UdpSocket::local_addr`]. + pub fn local_addr(&self) -> io::Result { + self.socket.local_addr() + } + + /// See [`tokio::net::UdpSocket::writable`]. + pub async fn writable(&self) -> Result<(), io::Error> { + self.socket.writable().await + } + + /// See [`tokio::net::UdpSocket::readable`]. + pub async fn readable(&self) -> Result<(), io::Error> { + self.socket.readable().await + } + + /// Send the UDP datagram on the specified socket. + pub fn send(&self, d: Datagram) -> io::Result { + let transmit = Transmit { + destination: d.destination(), + ecn: EcnCodepoint::from_bits(Into::::into(d.tos())), + contents: d.into_data().into(), + segment_size: None, + src_ip: None, + }; + + let n = self.socket.try_io(Interest::WRITABLE, || { + self.state + .send((&self.socket).into(), slice::from_ref(&transmit)) + })?; + + assert_eq!(n, 1, "only passed one slice"); + + Ok(n) + } + + /// Receive a UDP datagram on the specified socket. + pub fn recv(&self, local_address: &SocketAddr) -> Result, io::Error> { + let mut buf = [0; u16::MAX as usize]; + + let mut meta = RecvMeta::default(); + + match self.socket.try_io(Interest::READABLE, || { + self.state.recv( + (&self.socket).into(), + &mut [IoSliceMut::new(&mut buf)], + slice::from_mut(&mut meta), + ) + }) { + Ok(n) => { + assert_eq!(n, 1, "only passed one slice"); + } + Err(ref err) + if err.kind() == io::ErrorKind::WouldBlock + || err.kind() == io::ErrorKind::Interrupted => + { + return Ok(None) + } + Err(err) => { + return Err(err); + } + }; + + if meta.len == 0 { + eprintln!("zero length datagram received?"); + return Ok(None); + } + + if meta.len == buf.len() { + eprintln!("Might have received more than {} bytes", buf.len()); + } + + Ok(Some(Datagram::new( + meta.addr, + *local_address, + meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), + None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 + &buf[..meta.len], + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{IpTos, IpTosDscp, IpTosEcn}; + + #[tokio::test] + async fn datagram_tos() -> Result<(), io::Error> { + let sender = Socket::bind("127.0.0.1:0")?; + let receiver_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + let receiver = Socket::bind(receiver_addr)?; + + let datagram = Datagram::new( + sender.local_addr()?, + receiver.local_addr()?, + IpTos::from((IpTosDscp::Le, IpTosEcn::Ect1)), + None, + "Hello, world!".as_bytes().to_vec(), + ); + + sender.writable().await?; + sender.send(datagram.clone())?; + + receiver.readable().await?; + let received_datagram = receiver + .recv(&receiver_addr) + .expect("receive to succeed") + .expect("receive to yield datagram"); + + // Assert that the ECN is correct. + assert_eq!( + IpTosEcn::from(datagram.tos()), + IpTosEcn::from(received_datagram.tos()) + ); + + Ok(()) + } +} diff --git a/neqo-server/Cargo.toml b/neqo-server/Cargo.toml index 7a83685c9f..b2c36ed21f 100644 --- a/neqo-server/Cargo.toml +++ b/neqo-server/Cargo.toml @@ -14,7 +14,7 @@ license.workspace = true clap = { version = "4.4", features = ["derive"] } futures = "0.3" log = { version = "0.4", default-features = false } -neqo-common = { path = "./../neqo-common" } +neqo-common = { path = "./../neqo-common", features = ["udp"] } neqo-crypto = { path = "./../neqo-crypto" } neqo-http3 = { path = "./../neqo-http3" } neqo-qpack = { path = "./../neqo-qpack" } diff --git a/neqo-server/src/main.rs b/neqo-server/src/main.rs index 0c07cb61b7..66450ef5d6 100644 --- a/neqo-server/src/main.rs +++ b/neqo-server/src/main.rs @@ -27,7 +27,7 @@ use futures::{ future::{select, select_all, Either}, FutureExt, }; -use neqo_common::{hex, qdebug, qinfo, qwarn, Datagram, Header, IpTos}; +use neqo_common::{hex, qinfo, qwarn, udp, Datagram, Header}; use neqo_crypto::{ constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256}, generate_ech_keys, init_db, random, AntiReplay, Cipher, @@ -40,7 +40,7 @@ use neqo_transport::{ ConnectionIdGenerator, ConnectionParameters, Output, RandomConnectionIdGenerator, StreamType, Version, }; -use tokio::{net::UdpSocket, time::Sleep}; +use tokio::time::Sleep; use crate::old_https::Http09Server; @@ -305,21 +305,6 @@ impl QuicParameters { } } -async fn emit_packet(socket: &mut UdpSocket, out_dgram: Datagram) { - let sent = match socket.send_to(&out_dgram, &out_dgram.destination()).await { - Err(ref err) => { - if err.kind() != io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted { - eprintln!("UDP send error: {err:?}"); - } - 0 - } - Ok(res) => res, - }; - if sent != out_dgram.len() { - eprintln!("Unable to send all {} bytes of datagram", out_dgram.len()); - } -} - fn qns_read_response(filename: &str) -> Option> { let mut file_path = PathBuf::from("/www"); file_path.push(filename.trim_matches(|p| p == '/')); @@ -578,48 +563,11 @@ impl HttpServer for SimpleServer { } } -fn read_dgram( - socket: &mut UdpSocket, - local_address: &SocketAddr, -) -> Result, io::Error> { - let buf = &mut [0u8; 2048]; - let (sz, remote_addr) = match socket.try_recv_from(&mut buf[..]) { - Err(ref err) - if err.kind() == io::ErrorKind::WouldBlock - || err.kind() == io::ErrorKind::Interrupted => - { - return Ok(None) - } - Err(err) => { - eprintln!("UDP recv error: {err:?}"); - return Err(err); - } - Ok(res) => res, - }; - - if sz == buf.len() { - eprintln!("Might have received more than {} bytes", buf.len()); - } - - if sz == 0 { - eprintln!("zero length datagram received?"); - Ok(None) - } else { - Ok(Some(Datagram::new( - remote_addr, - *local_address, - IpTos::default(), - None, - &buf[..sz], - ))) - } -} - struct ServersRunner { args: Args, server: Box, timeout: Option>>, - sockets: Vec<(SocketAddr, UdpSocket)>, + sockets: Vec<(SocketAddr, udp::Socket)>, } impl ServersRunner { @@ -632,11 +580,11 @@ impl ServersRunner { let sockets = hosts .into_iter() .map(|host| { - let socket = std::net::UdpSocket::bind(host)?; + let socket = udp::Socket::bind(host)?; let local_addr = socket.local_addr()?; println!("Server waiting for connection on: {local_addr:?}"); - socket.set_nonblocking(true)?; - Ok((host, UdpSocket::from_std(socket)?)) + + Ok((host, socket)) }) .collect::>()?; let server = Self::create_server(&args); @@ -683,7 +631,7 @@ impl ServersRunner { } /// Tries to find a socket, but then just falls back to sending from the first. - fn find_socket(&mut self, addr: SocketAddr) -> &mut UdpSocket { + fn find_socket(&mut self, addr: SocketAddr) -> &mut udp::Socket { let ((_host, first_socket), rest) = self.sockets.split_first_mut().unwrap(); rest.iter_mut() .map(|(_host, socket)| socket) @@ -696,12 +644,13 @@ impl ServersRunner { .unwrap_or(first_socket) } - async fn process(&mut self, mut dgram: Option<&Datagram>) { + async fn process(&mut self, mut dgram: Option<&Datagram>) -> Result<(), io::Error> { loop { match self.server.process(dgram.take(), self.args.now()) { Output::Datagram(dgram) => { let socket = self.find_socket(dgram.source()); - emit_packet(socket, dgram).await; + socket.writable().await?; + socket.send(dgram)?; } Output::Callback(new_timeout) => { qinfo!("Setting timeout of {:?}", new_timeout); @@ -709,11 +658,11 @@ impl ServersRunner { break; } Output::None => { - qdebug!("Output::None"); break; } } } + Ok(()) } // Wait for any of the sockets to be readable or the timeout to fire. @@ -740,20 +689,20 @@ impl ServersRunner { match self.ready().await? { Ready::Socket(inx) => loop { let (host, socket) = self.sockets.get_mut(inx).unwrap(); - let dgram = read_dgram(socket, host)?; + let dgram = socket.recv(host)?; if dgram.is_none() { break; } - self.process(dgram.as_ref()).await; + self.process(dgram.as_ref()).await?; }, Ready::Timeout => { self.timeout = None; - self.process(None).await; + self.process(None).await?; } } self.server.process_events(&self.args, self.args.now()); - self.process(None).await; + self.process(None).await?; } } }