Skip to content

Commit ba6b26d

Browse files
committed
test: add test for connect request in udp::handler
1 parent 028e40b commit ba6b26d

File tree

4 files changed

+198
-24
lines changed

4 files changed

+198
-24
lines changed

src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::sync::Arc;
22

33
use log::info;
4+
use torrust_tracker::tracker::statistics::StatsTracker;
45
use torrust_tracker::tracker::tracker::TorrentTracker;
56
use torrust_tracker::{logging, setup, static_time, Configuration};
67

@@ -19,8 +20,11 @@ async fn main() {
1920
}
2021
};
2122

23+
// Initialize stats tracker
24+
let stats_tracker = StatsTracker::new_running_instance();
25+
2226
// Initialize Torrust tracker
23-
let tracker = match TorrentTracker::new(config.clone()) {
27+
let tracker = match TorrentTracker::new(config.clone(), Box::new(stats_tracker)) {
2428
Ok(tracker) => Arc::new(tracker),
2529
Err(error) => {
2630
panic!("{}", error)

src/tracker/statistics.rs

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
use async_trait::async_trait;
12
use std::sync::Arc;
2-
33
use tokio::sync::mpsc::error::SendError;
44
use tokio::sync::mpsc::Sender;
55
use tokio::sync::{mpsc, RwLock, RwLockReadGuard};
66

77
const CHANNEL_BUFFER_SIZE: usize = 65_535;
88

9-
#[derive(Debug)]
9+
#[derive(Debug, PartialEq)]
1010
pub enum TrackerStatisticsEvent {
1111
Tcp4Announce,
1212
Tcp4Scrape,
@@ -61,25 +61,19 @@ pub struct StatsTracker {
6161
}
6262

6363
impl StatsTracker {
64+
pub fn new_running_instance() -> Self {
65+
let mut stats_tracker = Self::new();
66+
stats_tracker.run_worker();
67+
stats_tracker
68+
}
69+
6470
pub fn new() -> Self {
6571
Self {
6672
channel_sender: None,
6773
stats: Arc::new(RwLock::new(TrackerStatistics::new())),
6874
}
6975
}
7076

71-
pub async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics> {
72-
self.stats.read().await
73-
}
74-
75-
pub async fn send_event(&self, event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>> {
76-
if let Some(tx) = &self.channel_sender {
77-
Some(tx.send(event).await)
78-
} else {
79-
None
80-
}
81-
}
82-
8377
pub fn run_worker(&mut self) {
8478
let (tx, mut rx) = mpsc::channel::<TrackerStatisticsEvent>(CHANNEL_BUFFER_SIZE);
8579

@@ -134,3 +128,35 @@ impl StatsTracker {
134128
});
135129
}
136130
}
131+
132+
#[async_trait]
133+
pub trait TrackerStatisticsEventSender: Sync + Send {
134+
async fn send_event(&self, event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>>;
135+
}
136+
137+
#[async_trait]
138+
impl TrackerStatisticsEventSender for StatsTracker {
139+
async fn send_event(&self, event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>> {
140+
if let Some(tx) = &self.channel_sender {
141+
Some(tx.send(event).await)
142+
} else {
143+
None
144+
}
145+
}
146+
}
147+
148+
#[async_trait]
149+
pub trait TrackerStatisticsRepository: Sync + Send {
150+
async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics>;
151+
}
152+
153+
#[async_trait]
154+
impl TrackerStatisticsRepository for StatsTracker {
155+
async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics> {
156+
self.stats.read().await
157+
}
158+
}
159+
160+
pub trait TrackerStatsService: TrackerStatisticsEventSender + TrackerStatisticsRepository {}
161+
162+
impl TrackerStatsService for StatsTracker {}

src/tracker/tracker.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::databases::database::Database;
1212
use crate::mode::TrackerMode;
1313
use crate::peer::TorrentPeer;
1414
use crate::protocol::common::InfoHash;
15-
use crate::statistics::{StatsTracker, TrackerStatistics, TrackerStatisticsEvent};
15+
use crate::statistics::{TrackerStatistics, TrackerStatisticsEvent, TrackerStatsService};
1616
use crate::tracker::key;
1717
use crate::tracker::key::AuthKey;
1818
use crate::tracker::torrent::{TorrentEntry, TorrentError, TorrentStats};
@@ -24,19 +24,13 @@ pub struct TorrentTracker {
2424
keys: RwLock<std::collections::HashMap<String, AuthKey>>,
2525
whitelist: RwLock<std::collections::HashSet<InfoHash>>,
2626
torrents: RwLock<std::collections::BTreeMap<InfoHash, TorrentEntry>>,
27-
stats_tracker: StatsTracker,
27+
stats_tracker: Box<dyn TrackerStatsService>,
2828
database: Box<dyn Database>,
2929
}
3030

3131
impl TorrentTracker {
32-
pub fn new(config: Arc<Configuration>) -> Result<TorrentTracker, r2d2::Error> {
32+
pub fn new(config: Arc<Configuration>, stats_tracker: Box<dyn TrackerStatsService>) -> Result<TorrentTracker, r2d2::Error> {
3333
let database = database::connect_database(&config.db_driver, &config.db_path)?;
34-
let mut stats_tracker = StatsTracker::new();
35-
36-
// starts a thread for updating tracker stats
37-
if config.tracker_usage_statistics {
38-
stats_tracker.run_worker();
39-
}
4034

4135
Ok(TorrentTracker {
4236
config: config.clone(),

src/udp/handlers.rs

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,153 @@ fn handle_error(e: ServerError, transaction_id: TransactionId) -> Response {
236236
message: message.into(),
237237
})
238238
}
239+
240+
#[cfg(test)]
241+
mod tests {
242+
use std::{
243+
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
244+
sync::Arc,
245+
};
246+
247+
use tokio::sync::{mpsc::error::SendError, RwLock, RwLockReadGuard};
248+
249+
use crate::{
250+
protocol::utils::get_connection_id,
251+
statistics::{
252+
StatsTracker, TrackerStatistics, TrackerStatisticsEvent, TrackerStatisticsEventSender, TrackerStatisticsRepository,
253+
TrackerStatsService,
254+
},
255+
tracker::tracker::TorrentTracker,
256+
udp::handle_connect,
257+
Configuration,
258+
};
259+
use aquatic_udp_protocol::{ConnectRequest, ConnectResponse, Response, TransactionId};
260+
use async_trait::async_trait;
261+
262+
fn default_tracker_config() -> Arc<Configuration> {
263+
Arc::new(Configuration::default())
264+
}
265+
266+
fn initialized_tracker() -> Arc<TorrentTracker> {
267+
Arc::new(TorrentTracker::new(default_tracker_config(), Box::new(StatsTracker::new_running_instance())).unwrap())
268+
}
269+
270+
fn sample_remote_addr() -> SocketAddr {
271+
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
272+
}
273+
274+
fn sample_connect_request() -> ConnectRequest {
275+
ConnectRequest {
276+
transaction_id: TransactionId(0i32),
277+
}
278+
}
279+
280+
fn sample_ipv4_socket_address() -> SocketAddr {
281+
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
282+
}
283+
284+
fn sample_ipv6_socket_address() -> SocketAddr {
285+
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080)
286+
}
287+
288+
#[tokio::test]
289+
async fn a_connect_response_should_contain_the_same_transaction_id_as_the_connect_request() {
290+
let request = ConnectRequest {
291+
transaction_id: TransactionId(0i32),
292+
};
293+
294+
let response = handle_connect(sample_remote_addr(), &request, initialized_tracker())
295+
.await
296+
.unwrap();
297+
298+
assert_eq!(
299+
response,
300+
Response::Connect(ConnectResponse {
301+
connection_id: get_connection_id(&sample_remote_addr()),
302+
transaction_id: request.transaction_id
303+
})
304+
);
305+
}
306+
307+
#[tokio::test]
308+
async fn a_connect_response_should_contain_a_new_connection_id() {
309+
let request = ConnectRequest {
310+
transaction_id: TransactionId(0i32),
311+
};
312+
313+
let response = handle_connect(sample_remote_addr(), &request, initialized_tracker())
314+
.await
315+
.unwrap();
316+
317+
assert_eq!(
318+
response,
319+
Response::Connect(ConnectResponse {
320+
connection_id: get_connection_id(&sample_remote_addr()),
321+
transaction_id: request.transaction_id
322+
})
323+
);
324+
}
325+
326+
struct TrackerStatsServiceMock {
327+
stats: Arc<RwLock<TrackerStatistics>>,
328+
expected_event: Option<TrackerStatisticsEvent>,
329+
}
330+
331+
impl TrackerStatsServiceMock {
332+
fn new() -> Self {
333+
Self {
334+
stats: Arc::new(RwLock::new(TrackerStatistics::new())),
335+
expected_event: None,
336+
}
337+
}
338+
339+
fn should_throw_event(&mut self, expected_event: TrackerStatisticsEvent) {
340+
self.expected_event = Some(expected_event);
341+
}
342+
}
343+
344+
#[async_trait]
345+
impl TrackerStatisticsEventSender for TrackerStatsServiceMock {
346+
async fn send_event(&self, _event: TrackerStatisticsEvent) -> Option<Result<(), SendError<TrackerStatisticsEvent>>> {
347+
if self.expected_event.is_some() {
348+
assert_eq!(_event, *self.expected_event.as_ref().unwrap());
349+
}
350+
None
351+
}
352+
}
353+
354+
#[async_trait]
355+
impl TrackerStatisticsRepository for TrackerStatsServiceMock {
356+
async fn get_stats(&self) -> RwLockReadGuard<'_, TrackerStatistics> {
357+
self.stats.read().await
358+
}
359+
}
360+
361+
impl TrackerStatsService for TrackerStatsServiceMock {}
362+
363+
#[tokio::test]
364+
async fn it_should_send_the_upd4_connect_event_when_a_client_tries_to_connect_using_a_ip4_socket_address() {
365+
let mut tracker_stats_service = Box::new(TrackerStatsServiceMock::new());
366+
367+
let client_socket_address = sample_ipv4_socket_address();
368+
tracker_stats_service.should_throw_event(TrackerStatisticsEvent::Udp4Connect);
369+
370+
let torrent_tracker = Arc::new(TorrentTracker::new(default_tracker_config(), tracker_stats_service).unwrap());
371+
handle_connect(client_socket_address, &sample_connect_request(), torrent_tracker)
372+
.await
373+
.unwrap();
374+
}
375+
376+
#[tokio::test]
377+
async fn it_should_send_the_upd6_connect_event_when_a_client_tries_to_connect_using_a_ip6_socket_address() {
378+
let mut tracker_stats_service = Box::new(TrackerStatsServiceMock::new());
379+
380+
let client_socket_address = sample_ipv6_socket_address();
381+
tracker_stats_service.should_throw_event(TrackerStatisticsEvent::Udp6Connect);
382+
383+
let torrent_tracker = Arc::new(TorrentTracker::new(default_tracker_config(), tracker_stats_service).unwrap());
384+
handle_connect(client_socket_address, &sample_connect_request(), torrent_tracker)
385+
.await
386+
.unwrap();
387+
}
388+
}

0 commit comments

Comments
 (0)