Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 84 additions & 16 deletions src/http/axum_implementation/extractors/announce_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,95 @@ where
type Rejection = Response;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let raw_query = parts.uri.query();

if raw_query.is_none() {
return Err(responses::error::Error::from(ParseAnnounceQueryError::MissingParams {
location: Location::caller(),
})
.into_response());
match extract_announce_from(parts.uri.query()) {
Ok(announce_request) => Ok(ExtractRequest(announce_request)),
Err(error) => Err(error.into_response()),
}
}
}

let query = raw_query.unwrap().parse::<Query>();
fn extract_announce_from(maybe_raw_query: Option<&str>) -> Result<Announce, responses::error::Error> {
if maybe_raw_query.is_none() {
return Err(responses::error::Error::from(ParseAnnounceQueryError::MissingParams {
location: Location::caller(),
}));
}

if let Err(error) = query {
return Err(responses::error::Error::from(error).into_response());
}
let query = maybe_raw_query.unwrap().parse::<Query>();

let announce_request = Announce::try_from(query.unwrap());
if let Err(error) = query {
return Err(responses::error::Error::from(error));
}

if let Err(error) = announce_request {
return Err(responses::error::Error::from(error).into_response());
}
let announce_request = Announce::try_from(query.unwrap());

if let Err(error) = announce_request {
return Err(responses::error::Error::from(error));
}

Ok(announce_request.unwrap())
}

#[cfg(test)]
mod tests {
use std::str::FromStr;

use super::extract_announce_from;
use crate::http::axum_implementation::requests::announce::{Announce, Compact, Event};
use crate::http::axum_implementation::responses::error::Error;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::peer;

fn assert_error_response(error: &Error, error_message: &str) {
assert!(
error.failure_reason.contains(error_message),
"Error response does not contain message: '{error_message}'. Error: {error:?}"
);
}

#[test]
fn it_should_extract_the_announce_request_from_the_url_query_params() {
let raw_query = "info_hash=%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0&peer_addr=2.137.87.41&downloaded=0&uploaded=0&peer_id=-qB00000000000000001&port=17548&left=0&event=completed&compact=0";

let announce = extract_announce_from(Some(raw_query)).unwrap();

assert_eq!(
announce,
Announce {
info_hash: InfoHash::from_str("3b245504cf5f11bbdbe1201cea6a6bf45aee1bc0").unwrap(),
peer_id: peer::Id(*b"-qB00000000000000001"),
port: 17548,
downloaded: Some(0),
uploaded: Some(0),
left: Some(0),
event: Some(Event::Completed),
compact: Some(Compact::NotAccepted),
}
);
}

#[test]
fn it_should_reject_a_request_without_query_params() {
let response = extract_announce_from(None).unwrap_err();

assert_error_response(
&response,
"Cannot parse query params for announce request: missing query params for announce request",
);
}

#[test]
fn it_should_reject_a_request_with_a_query_that_cannot_be_parsed() {
let invalid_query = "param1=value1=value2";
let response = extract_announce_from(Some(invalid_query)).unwrap_err();

assert_error_response(&response, "Cannot parse query params");
}

#[test]
fn it_should_reject_a_request_with_a_query_that_cannot_be_parsed_into_an_announce_request() {
let response = extract_announce_from(Some("param1=value1")).unwrap_err();

Ok(ExtractRequest(announce_request.unwrap()))
assert_error_response(&response, "Cannot parse query params for announce request");
}
}
101 changes: 70 additions & 31 deletions src/http/axum_implementation/extractors/key.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Wrapper for Axum `Path` extractor to return custom errors.
use std::panic::Location;

use axum::async_trait;
use axum::extract::rejection::PathRejection;
use axum::extract::{FromRequestParts, Path};
use axum::http::request::Parts;
use axum::response::{IntoResponse, Response};
Expand All @@ -19,37 +21,74 @@ where
type Rejection = Response;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
match Path::<KeyParam>::from_request_parts(parts, state).await {
Ok(key_param) => {
let Ok(key) = key_param.0.value().parse::<Key>() else {
return Err(responses::error::Error::from(
auth::Error::InvalidKeyFormat {
location: Location::caller()
})
.into_response())
};
Ok(Extract(key))
}
Err(rejection) => match rejection {
axum::extract::rejection::PathRejection::FailedToDeserializePathParams(_) => {
return Err(responses::error::Error::from(auth::Error::InvalidKeyFormat {
location: Location::caller(),
})
.into_response())
}
axum::extract::rejection::PathRejection::MissingPathParams(_) => {
return Err(responses::error::Error::from(auth::Error::MissingAuthKey {
location: Location::caller(),
})
.into_response())
}
_ => {
return Err(responses::error::Error::from(auth::Error::CannotExtractKeyParam {
location: Location::caller(),
})
.into_response())
}
},
// Extract `key` from URL path with Axum `Path` extractor
let maybe_path_with_key = Path::<KeyParam>::from_request_parts(parts, state).await;

match extract_key(maybe_path_with_key) {
Ok(key) => Ok(Extract(key)),
Err(error) => Err(error.into_response()),
}
}
}

fn extract_key(path_extractor_result: Result<Path<KeyParam>, PathRejection>) -> Result<Key, responses::error::Error> {
match path_extractor_result {
Ok(key_param) => match parse_key(&key_param.0.value()) {
Ok(key) => Ok(key),
Err(error) => Err(error),
},
Err(path_rejection) => Err(custom_error(&path_rejection)),
}
}

fn parse_key(key: &str) -> Result<Key, responses::error::Error> {
let key = key.parse::<Key>();

match key {
Ok(key) => Ok(key),
Err(_parse_key_error) => Err(responses::error::Error::from(auth::Error::InvalidKeyFormat {
location: Location::caller(),
})),
}
}

fn custom_error(rejection: &PathRejection) -> responses::error::Error {
match rejection {
axum::extract::rejection::PathRejection::FailedToDeserializePathParams(_) => {
responses::error::Error::from(auth::Error::InvalidKeyFormat {
location: Location::caller(),
})
}
axum::extract::rejection::PathRejection::MissingPathParams(_) => {
responses::error::Error::from(auth::Error::MissingAuthKey {
location: Location::caller(),
})
}
_ => responses::error::Error::from(auth::Error::CannotExtractKeyParam {
location: Location::caller(),
}),
}
}

#[cfg(test)]
mod tests {

use super::parse_key;
use crate::http::axum_implementation::responses::error::Error;

fn assert_error_response(error: &Error, error_message: &str) {
assert!(
error.failure_reason.contains(error_message),
"Error response does not contain message: '{error_message}'. Error: {error:?}"
);
}

#[test]
fn it_should_return_an_authentication_error_if_the_key_cannot_be_parsed() {
let invalid_key = "invalid_key";

let response = parse_key(invalid_key).unwrap_err();

assert_error_response(&response, "Authentication error: Invalid format for authentication key param");
}
}
1 change: 0 additions & 1 deletion src/http/axum_implementation/extractors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod announce_request;
pub mod key;
pub mod peer_ip;
pub mod remote_client_ip;
pub mod scrape_request;
54 changes: 0 additions & 54 deletions src/http/axum_implementation/extractors/peer_ip.rs

This file was deleted.

4 changes: 3 additions & 1 deletion src/http/axum_implementation/extractors/remote_client_ip.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Wrapper for two Axum extractors to get the relevant information
//! to resolve the remote client IP.
use std::net::{IpAddr, SocketAddr};

use axum::async_trait;
Expand All @@ -18,7 +20,7 @@ use serde::{Deserialize, Serialize};
/// `right_most_x_forwarded_for` = 126.0.0.2
/// `connection_info_ip` = 126.0.0.3
///
/// More info about inner extractors :<https://github.com/imbolc/axum-client-ip>
/// More info about inner extractors: <https://github.com/imbolc/axum-client-ip>
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct RemoteClientIp {
pub right_most_x_forwarded_for: Option<IpAddr>,
Expand Down
Loading