Skip to content

Commit f46df38

Browse files
committed
feat: gracefully shutdown udp servers
1 parent f0c0e95 commit f46df38

File tree

5 files changed

+82
-26
lines changed

5 files changed

+82
-26
lines changed

Cargo.lock

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ config = "0.11"
3131
derive_more = "0.99"
3232
thiserror = "1.0"
3333
aquatic_udp_protocol = { git = "https://github.com/greatest-ape/aquatic" }
34+
futures = "0.3.21"

src/main.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,21 @@ use torrust_tracker::torrust_http_tracker::server::HttpServer;
77

88
#[tokio::main]
99
async fn main() {
10+
// torrust config
1011
let config = match Configuration::load_from_file() {
1112
Ok(config) => Arc::new(config),
1213
Err(error) => {
1314
panic!("{}", error)
1415
}
1516
};
1617

17-
logging::setup_logging(&config);
18-
1918
// the singleton torrent tracker that gets passed to the HTTP and UDP server
2019
let tracker = Arc::new(TorrentTracker::new(config.clone()).unwrap_or_else(|e| {
2120
panic!("{}", e)
2221
}));
2322

23+
logging::setup_logging(&config);
24+
2425
// load persistent torrents if enabled
2526
if config.persistence {
2627
info!("Loading persistent torrents into memory...");
@@ -38,10 +39,17 @@ async fn main() {
3839
let _api_server = start_api_server(&config.http_api, tracker.clone());
3940
}
4041

42+
let (tx, rx) = tokio::sync::watch::channel(false);
43+
let mut udp_server_handles = Vec::new();
44+
4145
// start the udp blocks
4246
for udp_tracker in &config.udp_trackers {
47+
// used to send kill signal to thread
48+
4349
if udp_tracker.enabled {
44-
let _ = start_udp_tracker_server(&udp_tracker, tracker.clone()).await;
50+
udp_server_handles.push(
51+
start_udp_tracker_server(&udp_tracker, tracker.clone(), rx.clone()).await
52+
)
4553
}
4654
}
4755

@@ -52,10 +60,16 @@ async fn main() {
5260
}
5361

5462
// handle the signals here
55-
let ctrl_c = tokio::signal::ctrl_c();
5663
tokio::select! {
57-
_ = ctrl_c => {
64+
_ = tokio::signal::ctrl_c() => {
5865
info!("Torrust shutting down..");
66+
67+
// send kill signal
68+
let _ = tx.send(true);
69+
70+
// await for all udp servers to shutdown
71+
futures::future::join_all(udp_server_handles).await;
72+
5973
// Save torrents if enabled
6074
if config.persistence {
6175
info!("Saving torrents into SQL from memory...");
@@ -118,13 +132,13 @@ fn start_http_tracker_server(config: &HttpTrackerConfig, tracker: Arc<TorrentTra
118132
})
119133
}
120134

121-
async fn start_udp_tracker_server(config: &UdpTrackerConfig, tracker: Arc<TorrentTracker>) -> JoinHandle<()> {
122-
let udp_server = UdpServer::new(tracker, config).await.unwrap_or_else(|e| {
135+
async fn start_udp_tracker_server(config: &UdpTrackerConfig, tracker: Arc<TorrentTracker>, rx: tokio::sync::watch::Receiver<bool>) -> JoinHandle<()> {
136+
let udp_server = UdpServer::new(tracker, &config.bind_address).await.unwrap_or_else(|e| {
123137
panic!("Could not start UDP server: {}", e);
124138
});
125139

126140
info!("Starting UDP server on: {}", config.bind_address);
127141
tokio::spawn(async move {
128-
udp_server.start().await;
142+
udp_server.start(rx).await;
129143
})
130144
}

src/torrust_udp_tracker/handlers.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub async fn authenticate(info_hash: &InfoHash, tracker: Arc<TorrentTracker>) ->
2424
}
2525
}
2626

27-
pub async fn handle_packet(remote_addr: SocketAddr, payload: &[u8], tracker: Arc<TorrentTracker>) -> Response {
27+
pub async fn handle_packet(remote_addr: SocketAddr, payload: Vec<u8>, tracker: Arc<TorrentTracker>) -> Response {
2828
match Request::from_bytes(&payload[..payload.len()], MAX_SCRAPE_TORRENTS).map_err(|_| ServerError::InternalServerError) {
2929
Ok(request) => {
3030
let transaction_id = match &request {

src/torrust_udp_tracker/server.rs

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,54 @@
1+
use std::future::Future;
12
use std::io::Cursor;
23
use std::net::{SocketAddr};
34
use std::sync::Arc;
45
use aquatic_udp_protocol::{Response};
5-
use log::debug;
6+
use log::{debug, info};
67
use tokio::net::UdpSocket;
7-
use crate::{TorrentTracker, UdpTrackerConfig};
8+
use crate::{TorrentTracker};
89
use crate::torrust_udp_tracker::{handle_packet, MAX_PACKET_SIZE};
910

1011
pub struct UdpServer {
11-
socket: UdpSocket,
12+
socket: Arc<UdpSocket>,
1213
tracker: Arc<TorrentTracker>,
1314
}
1415

1516
impl UdpServer {
16-
pub async fn new(tracker: Arc<TorrentTracker>, config: &UdpTrackerConfig) -> Result<UdpServer, std::io::Error> {
17-
let srv = UdpSocket::bind(&config.bind_address).await?;
17+
pub async fn new(tracker: Arc<TorrentTracker>, bind_address: &str) -> tokio::io::Result<UdpServer> {
18+
let socket = UdpSocket::bind(bind_address).await?;
1819

1920
Ok(UdpServer {
20-
socket: srv,
21+
socket: Arc::new(socket),
2122
tracker,
2223
})
2324
}
2425

25-
pub async fn start(&self) {
26+
pub async fn start(&self, rx: tokio::sync::watch::Receiver<bool>) {
2627
loop {
28+
let mut rx = rx.clone();
2729
let mut data = [0; MAX_PACKET_SIZE];
28-
if let Ok((valid_bytes, remote_addr)) = self.socket.recv_from(&mut data).await {
29-
let data = &data[..valid_bytes];
30-
debug!("Received {} bytes from {}", data.len(), remote_addr);
31-
debug!("{:?}", data);
32-
let response = handle_packet(remote_addr, data, self.tracker.clone()).await;
33-
self.send_response(remote_addr, response).await;
30+
let socket = self.socket.clone();
31+
let tracker = self.tracker.clone();
32+
33+
tokio::select! {
34+
_ = rx.changed() => {
35+
info!("Stopping UDP server: {}...", socket.local_addr().unwrap());
36+
break;
37+
}
38+
Ok((valid_bytes, remote_addr)) = socket.recv_from(&mut data) => {
39+
let payload = data[..valid_bytes].to_vec();
40+
41+
debug!("Received {} bytes from {}", payload.len(), remote_addr);
42+
debug!("{:?}", payload);
43+
44+
let response = handle_packet(remote_addr, payload, tracker).await;
45+
UdpServer::send_response(socket, remote_addr, response).await;
46+
}
3447
}
3548
}
3649
}
3750

38-
async fn send_response(&self, remote_addr: SocketAddr, response: Response) {
51+
async fn send_response(socket: Arc<UdpSocket>, remote_addr: SocketAddr, response: Response) {
3952
debug!("sending response to: {:?}", &remote_addr);
4053

4154
let buffer = vec![0u8; MAX_PACKET_SIZE];
@@ -47,14 +60,14 @@ impl UdpServer {
4760
let inner = cursor.get_ref();
4861

4962
debug!("{:?}", &inner[..position]);
50-
self.send_packet(&remote_addr, &inner[..position]).await;
63+
UdpServer::send_packet(socket, &remote_addr, &inner[..position]).await;
5164
}
5265
Err(_) => { debug!("could not write response to bytes."); }
5366
}
5467
}
5568

56-
async fn send_packet(&self, remote_addr: &SocketAddr, payload: &[u8]) {
69+
async fn send_packet(socket: Arc<UdpSocket>, remote_addr: &SocketAddr, payload: &[u8]) {
5770
// doesn't matter if it reaches or not
58-
let _ = self.socket.send_to(payload, remote_addr).await;
71+
let _ = socket.send_to(payload, remote_addr).await;
5972
}
6073
}

0 commit comments

Comments
 (0)