Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/apis/v1/context/auth_key/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub async fn delete_auth_key_handler(
) -> Response {
match Key::from_str(&seconds_valid_or_key.0) {
Err(_) => invalid_auth_key_param_response(&seconds_valid_or_key.0),
Ok(key) => match tracker.remove_auth_key(&key.to_string()).await {
Ok(key) => match tracker.remove_auth_key(&key).await {
Ok(_) => ok_response(),
Err(e) => failed_to_delete_key_response(e),
},
Expand Down
16 changes: 5 additions & 11 deletions src/databases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use async_trait::async_trait;

use self::error::Error;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth;
use crate::tracker::auth::{self, Key};

pub(self) struct Builder<T>
where
Expand Down Expand Up @@ -63,25 +63,19 @@ pub trait Database: Sync + Send {

async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>;

// todo: replace type `&str` with `&InfoHash`
async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error>;
async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result<Option<InfoHash>, Error>;

async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

// todo: replace type `&str` with `&Key`
async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error>;
async fn get_key_from_keys(&self, key: &Key) -> Result<Option<auth::ExpiringKey>, Error>;

async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result<usize, Error>;

// todo: replace type `&str` with `&Key`
async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error>;
async fn remove_key_from_keys(&self, key: &Key) -> Result<usize, Error>;

async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result<bool, Error> {
Ok(self
.get_info_hash_from_whitelist(&info_hash.clone().to_string())
.await?
.is_some())
Ok(self.get_info_hash_from_whitelist(info_hash).await?.is_some())
}
}
16 changes: 9 additions & 7 deletions src/databases/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ impl Database for Mysql {
Ok(conn.exec_drop(COMMAND, params! { info_hash_str, completed })?)
}

async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error> {
async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result<Option<InfoHash>, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let select = conn.exec_first::<String, _, _>(
"SELECT info_hash FROM whitelist WHERE info_hash = :info_hash",
params! { info_hash },
params! { "info_hash" => info_hash.to_hex_string() },
)?;

let info_hash = select.map(|f| InfoHash::from_str(&f).expect("Failed to decode InfoHash String from DB!"));
Expand Down Expand Up @@ -183,11 +183,13 @@ impl Database for Mysql {
Ok(1)
}

async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error> {
async fn get_key_from_keys(&self, key: &Key) -> Result<Option<auth::ExpiringKey>, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let query =
conn.exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key });
let query = conn.exec_first::<(String, i64), _, _>(
"SELECT `key`, valid_until FROM `keys` WHERE `key` = :key",
params! { "key" => key.to_string() },
);

let key = query?;

Expand All @@ -211,10 +213,10 @@ impl Database for Mysql {
Ok(1)
}

async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error> {
async fn remove_key_from_keys(&self, key: &Key) -> Result<usize, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { key })?;
conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { "key" => key.to_string() })?;

Ok(1)
}
Expand Down
14 changes: 7 additions & 7 deletions src/databases/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ impl Database for Sqlite {
}
}

async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error> {
async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result<Option<InfoHash>, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?;

let mut rows = stmt.query([info_hash])?;
let mut rows = stmt.query([info_hash.to_hex_string()])?;

let query = rows.next()?;

Expand Down Expand Up @@ -200,7 +200,7 @@ impl Database for Sqlite {
}
}

async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error> {
async fn get_key_from_keys(&self, key: &Key) -> Result<Option<auth::ExpiringKey>, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?;
Expand All @@ -211,9 +211,9 @@ impl Database for Sqlite {

Ok(key.map(|f| {
let expiry: i64 = f.get(1).unwrap();
let id: String = f.get(0).unwrap();
let key: String = f.get(0).unwrap();
auth::ExpiringKey {
key: id.parse::<Key>().unwrap(),
key: key.parse::<Key>().unwrap(),
valid_until: DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs()),
}
}))
Expand All @@ -237,10 +237,10 @@ impl Database for Sqlite {
}
}

async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error> {
async fn remove_key_from_keys(&self, key: &Key) -> Result<usize, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key])?;
let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key.to_string()])?;

if deleted == 1 {
// should only remove a single record.
Expand Down
12 changes: 12 additions & 0 deletions src/protocol/info_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ impl InfoHash {
pub fn bytes(&self) -> [u8; 20] {
self.0
}

#[must_use]
pub fn to_hex_string(&self) -> String {
self.to_string()
}
}

impl std::fmt::Display for InfoHash {
Expand Down Expand Up @@ -197,6 +202,13 @@ mod tests {
assert_eq!(output, "ffffffffffffffffffffffffffffffffffffffff");
}

#[test]
fn an_info_hash_should_return_its_a_40_utf8_lowercased_char_hex_representations_as_string() {
let info_hash = InfoHash::from_str("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF").unwrap();

assert_eq!(info_hash.to_hex_string(), "ffffffffffffffffffffffffffffffffffffffff");
}

#[test]
fn an_info_hash_can_be_created_from_a_valid_20_byte_array_slice() {
let info_hash: InfoHash = [255u8; 20].as_slice().into();
Expand Down
11 changes: 5 additions & 6 deletions src/tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,9 @@ impl Tracker {
/// # Panics
///
/// Will panic if key cannot be converted into a valid `Key`.
pub async fn remove_auth_key(&self, key: &str) -> Result<(), databases::error::Error> {
// todo: change argument `key: &str` to `key: &Key`
pub async fn remove_auth_key(&self, key: &Key) -> Result<(), databases::error::Error> {
self.database.remove_key_from_keys(key).await?;
self.keys.write().await.remove(&key.parse::<Key>().unwrap());
self.keys.write().await.remove(key);
Ok(())
}

Expand Down Expand Up @@ -1175,12 +1174,12 @@ mod tests {
async fn it_should_remove_an_authentication_key() {
let tracker = private_tracker();

let key = tracker.generate_auth_key(Duration::from_secs(100)).await.unwrap();
let expiring_key = tracker.generate_auth_key(Duration::from_secs(100)).await.unwrap();

let result = tracker.remove_auth_key(&key.id().to_string()).await;
let result = tracker.remove_auth_key(&expiring_key.id()).await;

assert!(result.is_ok());
assert!(tracker.verify_auth_key(&key.id()).await.is_err());
assert!(tracker.verify_auth_key(&expiring_key.id()).await.is_err());
}

#[tokio::test]
Expand Down