Skip to content

Commit 5b09e60

Browse files
authored
Add TcpSocket, a basic TCP socket builder (#1358)
This provides `TcpSocket`, a basic API for building a TCP socket. The goal is not to provide comprehensive coverage of all system options, but to provide an API for the most common cases. This is added now as a replacement for the removal of `TcpStream::connect_std`. The `connect_std` function from v0.6 was used until now as the strategy to set socket option before obtaining a mio TcpStream. Providing some strategy for customizing a `TcpStream` is required for Hyper to be able to upgrade.
1 parent 87a15fc commit 5b09e60

File tree

12 files changed

+316
-126
lines changed

12 files changed

+316
-126
lines changed

src/net/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
1010
cfg_tcp! {
1111
mod tcp;
12-
pub use self::tcp::{TcpListener, TcpStream};
12+
pub use self::tcp::{TcpListener, TcpSocket, TcpStream};
1313
}
1414

1515
cfg_udp! {

src/net/tcp/listener.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
55
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
66
use std::{fmt, io};
77

8-
use super::TcpStream;
8+
use super::{TcpSocket, TcpStream};
99
use crate::io_source::IoSource;
1010
use crate::{event, sys, Interest, Registry, Token};
1111

@@ -49,7 +49,20 @@ impl TcpListener {
4949
/// 3. Bind the socket to the specified address.
5050
/// 4. Calls `listen` on the socket to prepare it to receive new connections.
5151
pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
52-
sys::tcp::bind(addr).map(TcpListener::from_std)
52+
let socket = TcpSocket::new_for_addr(addr)?;
53+
54+
// On platforms with Berkeley-derived sockets, this allows to quickly
55+
// rebind a socket, without needing to wait for the OS to clean up the
56+
// previous one.
57+
//
58+
// On Windows, this allows rebinding sockets which are actively in use,
59+
// which allows “socket hijacking”, so we explicitly don't set it here.
60+
// https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
61+
#[cfg(not(windows))]
62+
socket.set_reuseaddr(true)?;
63+
64+
socket.bind(addr)?;
65+
socket.listen(1024)
5366
}
5467

5568
/// Creates a new `TcpListener` from a standard `net::TcpListener`.

src/net/tcp/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
mod listener;
22
pub use self::listener::TcpListener;
33

4+
mod socket;
5+
pub use self::socket::TcpSocket;
6+
47
mod stream;
58
pub use self::stream::TcpStream;

src/net/tcp/socket.rs

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
use crate::net::{TcpStream, TcpListener};
2+
use crate::sys;
3+
4+
use std::io;
5+
use std::mem;
6+
use std::net::SocketAddr;
7+
#[cfg(unix)]
8+
use std::os::unix::io::{AsRawFd, RawFd, FromRawFd};
9+
10+
/// A non-blocking TCP socket used to configure a stream or listener.
11+
///
12+
/// The `TcpSocket` type wraps the operating-system's socket handle. This type
13+
/// is used to configure the socket before establishing a connection or start
14+
/// listening for inbound connections.
15+
///
16+
/// The socket will be closed when the value is dropped.
17+
#[derive(Debug)]
18+
pub struct TcpSocket {
19+
sys: sys::tcp::TcpSocket,
20+
}
21+
22+
impl TcpSocket {
23+
/// Create a new IPv4 TCP socket.
24+
///
25+
/// This calls `socket(2)`.
26+
pub fn new_v4() -> io::Result<TcpSocket> {
27+
sys::tcp::new_v4_socket().map(|sys| TcpSocket {
28+
sys
29+
})
30+
}
31+
32+
/// Create a new IPv6 TCP socket.
33+
///
34+
/// This calls `socket(2)`.
35+
pub fn new_v6() -> io::Result<TcpSocket> {
36+
sys::tcp::new_v6_socket().map(|sys| TcpSocket {
37+
sys
38+
})
39+
}
40+
41+
pub(crate) fn new_for_addr(addr: SocketAddr) -> io::Result<TcpSocket> {
42+
if addr.is_ipv4() {
43+
TcpSocket::new_v4()
44+
} else {
45+
TcpSocket::new_v6()
46+
}
47+
}
48+
49+
/// Bind `addr` to the TCP socket.
50+
pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
51+
sys::tcp::bind(self.sys, addr)
52+
}
53+
54+
/// Connect the socket to `addr`.
55+
///
56+
/// This consumes the socket and performs the connect operation. Once the
57+
/// connection completes, the socket is now a non-blocking `TcpStream` and
58+
/// can be used as such.
59+
pub fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
60+
let stream = sys::tcp::connect(self.sys, addr)?;
61+
62+
// Don't close the socket
63+
mem::forget(self);
64+
Ok(TcpStream::from_std(stream))
65+
}
66+
67+
/// Listen for inbound connections, converting the socket to a
68+
/// `TcpListener`.
69+
pub fn listen(self, backlog: u32) -> io::Result<TcpListener> {
70+
let listener = sys::tcp::listen(self.sys, backlog)?;
71+
72+
// Don't close the socket
73+
mem::forget(self);
74+
Ok(TcpListener::from_std(listener))
75+
}
76+
77+
/// Sets the value of `SO_REUSEADDR` on this socket.
78+
pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
79+
sys::tcp::set_reuseaddr(self.sys, reuseaddr)
80+
}
81+
}
82+
83+
impl Drop for TcpSocket {
84+
fn drop(&mut self) {
85+
sys::tcp::close(self.sys);
86+
}
87+
}
88+
89+
#[cfg(unix)]
90+
impl AsRawFd for TcpSocket {
91+
fn as_raw_fd(&self) -> RawFd {
92+
self.sys
93+
}
94+
}
95+
96+
#[cfg(unix)]
97+
impl FromRawFd for TcpSocket {
98+
/// Converts a `RawFd` to a `TcpStream`.
99+
///
100+
/// # Notes
101+
///
102+
/// The caller is responsible for ensuring that the socket is in
103+
/// non-blocking mode.
104+
unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket {
105+
TcpSocket { sys: fd }
106+
}
107+
}

src/net/tcp/stream.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
77
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
88

99
use crate::io_source::IoSource;
10-
use crate::{event, sys, Interest, Registry, Token};
10+
use crate::{event, Interest, Registry, Token};
11+
use crate::net::TcpSocket;
1112

1213
/// A non-blocking TCP stream between a local socket and a remote socket.
1314
///
@@ -48,7 +49,8 @@ impl TcpStream {
4849
/// Create a new TCP stream and issue a non-blocking connect to the
4950
/// specified address.
5051
pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
51-
sys::tcp::connect(addr).map(TcpStream::from_std)
52+
let socket = TcpSocket::new_for_addr(addr)?;
53+
socket.connect(addr)
5254
}
5355

5456
/// Creates a new `TcpStream` from a standard `net::TcpStream`.

src/sys/shell/tcp.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,33 @@
11
use std::io;
22
use std::net::{self, SocketAddr};
33

4-
pub fn connect(_: SocketAddr) -> io::Result<net::TcpStream> {
4+
pub(crate) type TcpSocket = i32;
5+
6+
pub(crate) fn new_v4_socket() -> io::Result<TcpSocket> {
7+
os_required!();
8+
}
9+
10+
pub(crate) fn new_v6_socket() -> io::Result<TcpSocket> {
11+
os_required!();
12+
}
13+
14+
pub(crate) fn bind(_socket: TcpSocket, _addr: SocketAddr) -> io::Result<()> {
15+
os_required!();
16+
}
17+
18+
pub(crate) fn connect(_: TcpSocket, _addr: SocketAddr) -> io::Result<net::TcpStream> {
19+
os_required!();
20+
}
21+
22+
pub(crate) fn listen(_: TcpSocket, _: u32) -> io::Result<net::TcpListener> {
23+
os_required!();
24+
}
25+
26+
pub(crate) fn close(_: TcpSocket) {
527
os_required!();
628
}
729

8-
pub fn bind(_: SocketAddr) -> io::Result<net::TcpListener> {
30+
pub(crate) fn set_reuseaddr(_: TcpSocket, _: bool) -> io::Result<()> {
931
os_required!();
1032
}
1133

src/sys/unix/net.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#[cfg(all(feature = "os-poll", any(feature = "tcp", feature = "udp")))]
22
use std::net::SocketAddr;
33

4-
#[cfg(all(feature = "os-poll", any(feature = "tcp", feature = "udp")))]
4+
#[cfg(all(feature = "os-poll", any(feature = "udp")))]
55
pub(crate) fn new_ip_socket(
66
addr: SocketAddr,
77
socket_type: libc::c_int,

src/sys/unix/tcp.rs

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,59 @@ use std::mem::{size_of, MaybeUninit};
33
use std::net::{self, SocketAddr};
44
use std::os::unix::io::{AsRawFd, FromRawFd};
55

6-
use crate::sys::unix::net::{new_ip_socket, socket_addr, to_socket_addr};
6+
use crate::sys::unix::net::{new_socket, socket_addr, to_socket_addr};
77

8-
pub fn connect(addr: SocketAddr) -> io::Result<net::TcpStream> {
9-
new_ip_socket(addr, libc::SOCK_STREAM)
10-
.and_then(|socket| {
11-
let (raw_addr, raw_addr_length) = socket_addr(&addr);
12-
syscall!(connect(socket, raw_addr, raw_addr_length))
13-
.or_else(|err| match err {
14-
// Connect hasn't finished, but that is fine.
15-
ref err if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(0),
16-
err => Err(err),
17-
})
18-
.map(|_| socket)
19-
.map_err(|err| {
20-
// Close the socket if we hit an error, ignoring the error
21-
// from closing since we can't pass back two errors.
22-
let _ = unsafe { libc::close(socket) };
23-
err
24-
})
25-
})
26-
.map(|socket| unsafe { net::TcpStream::from_raw_fd(socket) })
8+
pub type TcpSocket = libc::c_int;
9+
10+
pub(crate) fn new_v4_socket() -> io::Result<TcpSocket> {
11+
new_socket(libc::AF_INET, libc::SOCK_STREAM)
2712
}
2813

29-
pub fn bind(addr: SocketAddr) -> io::Result<net::TcpListener> {
30-
new_ip_socket(addr, libc::SOCK_STREAM).and_then(|socket| {
31-
// Set SO_REUSEADDR (mirrors what libstd does).
32-
syscall!(setsockopt(
33-
socket,
34-
libc::SOL_SOCKET,
35-
libc::SO_REUSEADDR,
36-
&1 as *const libc::c_int as *const libc::c_void,
37-
size_of::<libc::c_int>() as libc::socklen_t,
38-
))
39-
.and_then(|_| {
40-
let (raw_addr, raw_addr_length) = socket_addr(&addr);
41-
syscall!(bind(socket, raw_addr, raw_addr_length))
42-
})
43-
.and_then(|_| syscall!(listen(socket, 1024)))
44-
.map_err(|err| {
45-
// Close the socket if we hit an error, ignoring the error
46-
// from closing since we can't pass back two errors.
47-
let _ = unsafe { libc::close(socket) };
48-
err
49-
})
50-
.map(|_| unsafe { net::TcpListener::from_raw_fd(socket) })
51-
})
14+
pub(crate) fn new_v6_socket() -> io::Result<TcpSocket> {
15+
new_socket(libc::AF_INET6, libc::SOCK_STREAM)
16+
}
17+
18+
pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> {
19+
let (raw_addr, raw_addr_length) = socket_addr(&addr);
20+
syscall!(bind(socket, raw_addr, raw_addr_length))?;
21+
Ok(())
22+
}
23+
24+
pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result<net::TcpStream> {
25+
let (raw_addr, raw_addr_length) = socket_addr(&addr);
26+
27+
match syscall!(connect(socket, raw_addr, raw_addr_length)) {
28+
Err(err) if err.raw_os_error() != Some(libc::EINPROGRESS) => {
29+
Err(err)
30+
}
31+
_ => {
32+
Ok(unsafe { net::TcpStream::from_raw_fd(socket) })
33+
}
34+
}
35+
}
36+
37+
pub(crate) fn listen(socket: TcpSocket, backlog: u32) -> io::Result<net::TcpListener> {
38+
use std::convert::TryInto;
39+
40+
let backlog = backlog.try_into().unwrap_or(i32::max_value());
41+
syscall!(listen(socket, backlog))?;
42+
Ok(unsafe { net::TcpListener::from_raw_fd(socket) })
43+
}
44+
45+
pub(crate) fn close(socket: TcpSocket) {
46+
let _ = unsafe { net::TcpStream::from_raw_fd(socket) };
47+
}
48+
49+
pub(crate) fn set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<()> {
50+
let val: libc::c_int = if reuseaddr { 1 } else { 0 };
51+
syscall!(setsockopt(
52+
socket,
53+
libc::SOL_SOCKET,
54+
libc::SO_REUSEADDR,
55+
&val as *const libc::c_int as *const libc::c_void,
56+
size_of::<libc::c_int>() as libc::socklen_t,
57+
))?;
58+
Ok(())
5259
}
5360

5461
pub fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> {

src/sys/windows/net.rs

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
use std::io;
22
use std::mem::size_of_val;
33
use std::net::SocketAddr;
4-
#[cfg(all(feature = "os-poll", feature = "tcp"))]
5-
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
64
use std::sync::Once;
75

86
use winapi::ctypes::c_int;
97
use winapi::shared::ws2def::SOCKADDR;
108
use winapi::um::winsock2::{
11-
ioctlsocket, socket, FIONBIO, INVALID_SOCKET, PF_INET, PF_INET6, SOCKET,
9+
ioctlsocket, socket, FIONBIO, INVALID_SOCKET, SOCKET,
1210
};
1311

1412
/// Initialise the network stack for Windows.
@@ -23,12 +21,19 @@ pub(crate) fn init() {
2321
}
2422

2523
/// Create a new non-blocking socket.
26-
pub(crate) fn new_socket(addr: SocketAddr, socket_type: c_int) -> io::Result<SOCKET> {
24+
#[cfg(feature = "udp")]
25+
pub(crate) fn new_ip_socket(addr: SocketAddr, socket_type: c_int) -> io::Result<SOCKET> {
26+
use winapi::um::winsock2::{PF_INET, PF_INET6};
27+
2728
let domain = match addr {
2829
SocketAddr::V4(..) => PF_INET,
2930
SocketAddr::V6(..) => PF_INET6,
3031
};
3132

33+
new_socket(domain, socket_type)
34+
}
35+
36+
pub(crate) fn new_socket(domain: c_int, socket_type: c_int) -> io::Result<SOCKET> {
3237
syscall!(
3338
socket(domain, socket_type, 0),
3439
PartialEq::eq,
@@ -51,19 +56,3 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (*const SOCKADDR, c_int) {
5156
),
5257
}
5358
}
54-
55-
#[cfg(all(feature = "os-poll", feature = "tcp"))]
56-
pub(crate) fn inaddr_any(other: SocketAddr) -> SocketAddr {
57-
match other {
58-
SocketAddr::V4(..) => {
59-
let any = Ipv4Addr::new(0, 0, 0, 0);
60-
let addr = SocketAddrV4::new(any, 0);
61-
SocketAddr::V4(addr)
62-
}
63-
SocketAddr::V6(..) => {
64-
let any = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
65-
let addr = SocketAddrV6::new(any, 0, 0, 0);
66-
SocketAddr::V6(addr)
67-
}
68-
}
69-
}

0 commit comments

Comments
 (0)