Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cSpell.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"uroot",
"Vagaa",
"Vuze",
"whitespaces",
"Xtorrent",
"Xunlei",
"xxxxxxxxxxxxxxxxxxxxd",
Expand Down
159 changes: 159 additions & 0 deletions src/http/axum_implementation/extractors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use std::panic::Location;
use std::str::FromStr;

use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::StatusCode;
use thiserror::Error;

use super::query::Query;
use crate::http::percent_encoding::{percent_decode_info_hash, percent_decode_peer_id};
use crate::protocol::info_hash::{ConversionError, InfoHash};
use crate::tracker::peer::{self, IdConversionError};

pub struct ExtractAnnounceParams(pub AnnounceParams);

#[derive(Debug, PartialEq)]
pub struct AnnounceParams {
pub info_hash: InfoHash,
pub peer_id: peer::Id,
pub port: u16,
}

#[derive(Error, Debug)]
pub enum ParseAnnounceQueryError {
#[error("missing infohash {location}")]
MissingInfoHash { location: &'static Location<'static> },
#[error("invalid infohash {location}")]
InvalidInfoHash { location: &'static Location<'static> },
#[error("missing peer id {location}")]
MissingPeerId { location: &'static Location<'static> },
#[error("invalid peer id {location}")]
InvalidPeerId { location: &'static Location<'static> },
#[error("missing port {location}")]
MissingPort { location: &'static Location<'static> },
#[error("invalid port {location}")]
InvalidPort { location: &'static Location<'static> },
}

impl From<IdConversionError> for ParseAnnounceQueryError {
#[track_caller]
fn from(_err: IdConversionError) -> Self {
Self::InvalidPeerId {
location: Location::caller(),
}
}
}

impl From<ConversionError> for ParseAnnounceQueryError {
#[track_caller]
fn from(_err: ConversionError) -> Self {
Self::InvalidPeerId {
location: Location::caller(),
}
}
}

impl TryFrom<Query> for AnnounceParams {
type Error = ParseAnnounceQueryError;

fn try_from(query: Query) -> Result<Self, Self::Error> {
Ok(Self {
info_hash: extract_info_hash(&query)?,
peer_id: extract_peer_id(&query)?,
port: extract_port(&query)?,
})
}
}

fn extract_info_hash(query: &Query) -> Result<InfoHash, ParseAnnounceQueryError> {
match query.get_param("info_hash") {
Some(raw_info_hash) => Ok(percent_decode_info_hash(&raw_info_hash)?),
None => {
return Err(ParseAnnounceQueryError::MissingInfoHash {
location: Location::caller(),
})
}
}
}

fn extract_peer_id(query: &Query) -> Result<peer::Id, ParseAnnounceQueryError> {
match query.get_param("peer_id") {
Some(raw_peer_id) => Ok(percent_decode_peer_id(&raw_peer_id)?),
None => {
return Err(ParseAnnounceQueryError::MissingPeerId {
location: Location::caller(),
})
}
}
}

fn extract_port(query: &Query) -> Result<u16, ParseAnnounceQueryError> {
match query.get_param("port") {
Some(raw_port) => Ok(u16::from_str(&raw_port).map_err(|_e| ParseAnnounceQueryError::InvalidPort {
location: Location::caller(),
})?),
None => {
return Err(ParseAnnounceQueryError::MissingPort {
location: Location::caller(),
})
}
}
}

#[async_trait]
impl<S> FromRequestParts<S> for ExtractAnnounceParams
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let raw_query = parts.uri.query();

if raw_query.is_none() {
return Err((StatusCode::BAD_REQUEST, "missing query params"));
}

let query = raw_query.unwrap().parse::<Query>();

if query.is_err() {
return Err((StatusCode::BAD_REQUEST, "can't parse query params"));
}

let announce_params = AnnounceParams::try_from(query.unwrap());

if announce_params.is_err() {
return Err((StatusCode::BAD_REQUEST, "can't parse query params for announce request"));
}

Ok(ExtractAnnounceParams(announce_params.unwrap()))
}
}

#[cfg(test)]
mod tests {
use super::AnnounceParams;
use crate::http::axum_implementation::query::Query;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::peer;

#[test]
fn announce_request_params_should_be_extracted_from_url_query_params() {
let raw_query = "info_hash=%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0&peer_id=-qB00000000000000001&port=17548";

let query = raw_query.parse::<Query>().unwrap();

let announce_params = AnnounceParams::try_from(query).unwrap();

assert_eq!(
announce_params,
AnnounceParams {
info_hash: "3b245504cf5f11bbdbe1201cea6a6bf45aee1bc0".parse::<InfoHash>().unwrap(),
peer_id: "-qB00000000000000001".parse::<peer::Id>().unwrap(),
port: 17548,
}
);
}
}
25 changes: 25 additions & 0 deletions src/http/axum_implementation/handlers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::sync::Arc;

use axum::extract::State;
use axum::response::Json;

use super::extractors::ExtractAnnounceParams;
use super::resources::ok::Ok;
use super::responses::ok_response;
use crate::tracker::Tracker;

#[allow(clippy::unused_async)]
pub async fn get_status_handler() -> Json<Ok> {
ok_response()
}

/// # Panics
///
/// todo
#[allow(clippy::unused_async)]
pub async fn announce_handler(
State(_tracker): State<Arc<Tracker>>,
ExtractAnnounceParams(_announce_params): ExtractAnnounceParams,
) -> Json<Ok> {
todo!()
}
7 changes: 7 additions & 0 deletions src/http/axum_implementation/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pub mod extractors;
pub mod handlers;
pub mod query;
pub mod resources;
pub mod responses;
pub mod routes;
pub mod server;
138 changes: 138 additions & 0 deletions src/http/axum_implementation/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use std::collections::HashMap;
use std::panic::Location;
use std::str::FromStr;

use thiserror::Error;
pub struct Query {
params: HashMap<String, String>,
}

#[derive(Error, Debug)]
pub enum ParseQueryError {
#[error("invalid param {raw_param} in {location}")]
InvalidParam {
location: &'static Location<'static>,
raw_param: String,
},
}

impl FromStr for Query {
type Err = ParseQueryError;

fn from_str(raw_query: &str) -> Result<Self, Self::Err> {
let mut params: HashMap<String, String> = HashMap::new();

let raw_params = raw_query.trim().trim_start_matches('?').split('&').collect::<Vec<&str>>();

for raw_param in raw_params {
let param: Param = raw_param.parse()?;
params.insert(param.name, param.value);
}

Ok(Self { params })
}
}

#[derive(Debug, PartialEq)]
struct Param {
name: String,
value: String,
}

impl FromStr for Param {
type Err = ParseQueryError;

fn from_str(raw_param: &str) -> Result<Self, Self::Err> {
let pair = raw_param.split('=').collect::<Vec<&str>>();

if pair.len() > 2 {
return Err(ParseQueryError::InvalidParam {
location: Location::caller(),
raw_param: raw_param.to_owned(),
});
}

Ok(Self {
name: pair[0].to_owned(),
value: pair[1].to_owned(),
})
}
}

impl Query {
#[must_use]
pub fn get_param(&self, name: &str) -> Option<String> {
self.params.get(name).map(std::string::ToString::to_string)
}
}

#[cfg(test)]
mod tests {
use super::Query;
use crate::http::axum_implementation::query::Param;

#[test]
fn it_should_parse_the_query_params_from_an_url_query_string() {
let raw_query = "info_hash=%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0&peer_id=-qB00000000000000001&port=17548";

let query = raw_query.parse::<Query>().unwrap();

assert_eq!(
query.get_param("info_hash").unwrap(),
"%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0"
);
assert_eq!(query.get_param("peer_id").unwrap(), "-qB00000000000000001");
assert_eq!(query.get_param("port").unwrap(), "17548");
}

#[test]
fn it_should_fail_parsing_an_invalid_query_string() {
let invalid_raw_query = "name=value=value";

let query = invalid_raw_query.parse::<Query>();

assert!(query.is_err());
}

#[test]
fn it_should_ignore_the_preceding_question_mark_if_it_exists() {
let raw_query = "?name=value";

let query = raw_query.parse::<Query>().unwrap();

assert_eq!(query.get_param("name").unwrap(), "value");
}

#[test]
fn it_should_trim_whitespaces() {
let raw_query = " name=value ";

let query = raw_query.parse::<Query>().unwrap();

assert_eq!(query.get_param("name").unwrap(), "value");
}

#[test]
fn it_should_parse_a_single_query_param() {
let raw_param = "name=value";

let param = raw_param.parse::<Param>().unwrap();

assert_eq!(
param,
Param {
name: "name".to_string(),
value: "value".to_string(),
}
);
}

#[test]
fn it_should_fail_parsing_an_invalid_query_param() {
let invalid_raw_param = "name=value=value";

let query = invalid_raw_param.parse::<Param>();

assert!(query.is_err());
}
}
1 change: 1 addition & 0 deletions src/http/axum_implementation/resources/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod ok;
4 changes: 4 additions & 0 deletions src/http/axum_implementation/resources/ok.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct Ok {}
10 changes: 10 additions & 0 deletions src/http/axum_implementation/responses.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Resource responses

use axum::Json;

use super::resources::ok::Ok;

#[must_use]
pub fn ok_response() -> Json<Ok> {
Json(Ok {})
}
15 changes: 15 additions & 0 deletions src/http/axum_implementation/routes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use std::sync::Arc;

use axum::routing::get;
use axum::Router;

use super::handlers::{announce_handler, get_status_handler};
use crate::tracker::Tracker;

pub fn router(tracker: &Arc<Tracker>) -> Router {
Router::new()
// Status
.route("/status", get(get_status_handler))
// Announce request
.route("/announce", get(announce_handler).with_state(tracker.clone()))
}
Loading