Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 34 additions & 2 deletions crates/atuin-client/src/api_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ use reqwest::{
use atuin_common::{
api::{
AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest,
ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse,
SyncHistoryResponse,
ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse,
SendVerificationResponse, StatusResponse, SyncHistoryResponse, VerificationTokenRequest,
VerificationTokenResponse,
},
record::RecordStatus,
};
Expand Down Expand Up @@ -403,4 +404,35 @@ impl<'a> Client<'a> {
bail!("Unknown error");
}
}

// Either request a verification email if token is null, or validate a token
pub async fn verify(&self, token: Option<String>) -> Result<(bool, bool)> {
// could dedupe this a bit, but it's simple at the moment
let (email_sent, verified) = if let Some(token) = token {
let url = format!("{}/api/v0/account/verify", self.sync_addr);
let url = Url::parse(url.as_str())?;

let resp = self
.client
.post(url)
.json(&VerificationTokenRequest { token })
.send()
.await?;
let resp = handle_resp_error(resp).await?;
let resp = resp.json::<VerificationTokenResponse>().await?;

(false, resp.verified)
} else {
let url = format!("{}/api/v0/account/send-verification", self.sync_addr);
let url = Url::parse(url.as_str())?;

let resp = self.client.post(url).send().await?;
let resp = handle_resp_error(resp).await?;
let resp = resp.json::<SendVerificationResponse>().await?;

(resp.email_sent, resp.verified)
};

Ok((email_sent, verified))
}
}
2 changes: 2 additions & 0 deletions crates/atuin-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ semver = { workspace = true }
thiserror = { workspace = true }
directories = { workspace = true }
sysinfo = "0.30.7"
base64 = { workspace = true }
getrandom = "0.2"

lazy_static = "1.4.0"

Expand Down
16 changes: 16 additions & 0 deletions crates/atuin-common/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ pub struct RegisterResponse {
#[derive(Debug, Serialize, Deserialize)]
pub struct DeleteUserResponse {}

#[derive(Debug, Serialize, Deserialize)]
pub struct SendVerificationResponse {
pub email_sent: bool,
pub verified: bool,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct VerificationTokenRequest {
pub token: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct VerificationTokenResponse {
pub verified: bool,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChangePasswordRequest {
pub current_password: String,
Expand Down
33 changes: 30 additions & 3 deletions crates/atuin-common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,30 @@ use std::path::PathBuf;

use eyre::{eyre, Result};

use rand::RngCore;
use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};
use getrandom::getrandom;
use uuid::Uuid;

pub fn random_bytes<const N: usize>() -> [u8; N] {
/// Generate N random bytes, using a cryptographically secure source
pub fn crypto_random_bytes<const N: usize>() -> [u8; N] {
// rand say they are in principle safe for crypto purposes, but that it is perhaps a better
// idea to use getrandom for things such as passwords.
let mut ret = [0u8; N];

rand::thread_rng().fill_bytes(&mut ret);
getrandom(&mut ret).expect("Failed to generate random bytes!");

ret
}

/// Generate N random bytes using a cryptographically secure source, return encoded as a string
pub fn crypto_random_string<const N: usize>() -> String {
let bytes = crypto_random_bytes::<N>();

// We only use this to create a random string, and won't be reversing it to find the original
// data - no padding is OK there. It may be in URLs.
BASE64_URL_SAFE_NO_PAD.encode(bytes)
}

pub fn uuid_v7() -> Uuid {
Uuid::now_v7()
}
Expand Down Expand Up @@ -178,6 +191,7 @@ impl<T: AsRef<str>> Escapable for T {}

#[cfg(test)]
mod tests {
use pretty_assertions::assert_ne;
use time::Month;

use super::*;
Expand Down Expand Up @@ -292,4 +306,17 @@ mod tests {
Cow::Owned(_)
));
}

#[test]
fn dumb_random_test() {
// Obviously not a test of randomness, but make sure we haven't made some
// catastrophic error

assert_ne!(crypto_random_string::<1>(), crypto_random_string::<1>());
assert_ne!(crypto_random_string::<2>(), crypto_random_string::<2>());
assert_ne!(crypto_random_string::<4>(), crypto_random_string::<4>());
assert_ne!(crypto_random_string::<8>(), crypto_random_string::<8>());
assert_ne!(crypto_random_string::<16>(), crypto_random_string::<16>());
assert_ne!(crypto_random_string::<32>(), crypto_random_string::<32>());
}
}
5 changes: 5 additions & 0 deletions crates/atuin-server-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
async fn get_user(&self, username: &str) -> DbResult<User>;
async fn get_user_session(&self, u: &User) -> DbResult<Session>;
async fn add_user(&self, user: &NewUser) -> DbResult<i64>;

async fn user_verified(&self, id: i64) -> DbResult<bool>;
async fn verify_user(&self, id: i64) -> DbResult<()>;
async fn user_verification_token(&self, id: i64) -> DbResult<String>;

async fn update_user_password(&self, u: &User) -> DbResult<()>;

async fn total_history(&self) -> DbResult<i64>;
Expand Down
1 change: 1 addition & 0 deletions crates/atuin-server-database/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub struct User {
pub username: String,
pub email: String,
pub password: String,
pub verified: Option<OffsetDateTime>,
}

pub struct Session {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
alter table users add verified_at timestamp with time zone default null;

create table user_verification_token(
id bigserial primary key,
user_id bigint unique references users(id),
token text,
valid_until timestamp with time zone
);
99 changes: 91 additions & 8 deletions crates/atuin-server-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Range;

use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use atuin_common::utils::crypto_random_string;
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt;
Expand All @@ -11,7 +12,7 @@ use sqlx::postgres::PgPoolOptions;
use sqlx::Row;

use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::instrument;
use tracing::{instrument, trace};
use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};

Expand Down Expand Up @@ -100,18 +101,100 @@ impl Database for Postgres {

#[instrument(skip_all)]
async fn get_user(&self, username: &str) -> DbResult<User> {
sqlx::query_as("select id, username, email, password from users where username = $1")
.bind(username)
.fetch_one(&self.pool)
.await
.map_err(fix_error)
.map(|DbUser(user)| user)
sqlx::query_as(
"select id, username, email, password, verified_at from users where username = $1",
)
.bind(username)
.fetch_one(&self.pool)
.await
.map_err(fix_error)
.map(|DbUser(user)| user)
}

#[instrument(skip_all)]
async fn user_verified(&self, id: i64) -> DbResult<bool> {
let res: (bool,) =
sqlx::query_as("select verified_at is not null from users where id = $1")
.bind(id)
.fetch_one(&self.pool)
.await
.map_err(fix_error)?;

Ok(res.0)
}

#[instrument(skip_all)]
async fn verify_user(&self, id: i64) -> DbResult<()> {
sqlx::query(
"update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
)
.bind(id)
.execute(&self.pool)
.await
.map_err(fix_error)?;

Ok(())
}

/// Return a valid verification token for the user
/// If the user does not have any token, create one, insert it, and return
/// If the user has a token, but it's invalid, delete it, create a new one, return
/// If the user already has a valid token, return it
#[instrument(skip_all)]
async fn user_verification_token(&self, id: i64) -> DbResult<String> {
const TOKEN_VALID_MINUTES: i64 = 15;

// First we check if there is a verification token
let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
"select token, valid_until from user_verification_token where user_id = $1",
)
.bind(id)
.fetch_optional(&self.pool)
.await
.map_err(fix_error)?;

let token = if let Some((token, valid_until)) = token {
trace!("Token for user {id} valid until {valid_until}");

// We have a token, AND it's still valid
if valid_until > time::OffsetDateTime::now_utc() {
token
} else {
// token has expired. generate a new one, return it
let token = crypto_random_string::<24>();

sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
.bind(id)
.bind(&token)
.bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
.execute(&self.pool)
.await
.map_err(fix_error)?;

token
}
} else {
// No token in the database! Generate one, insert it
let token = crypto_random_string::<24>();

sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
.bind(id)
.bind(&token)
.bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
.execute(&self.pool)
.await
.map_err(fix_error)?;

token
};

Ok(token)
}

#[instrument(skip_all)]
async fn get_session_user(&self, token: &str) -> DbResult<User> {
sqlx::query_as(
"select users.id, users.username, users.email, users.password from users
"select users.id, users.username, users.email, users.password, users.verified_at from users
inner join sessions
on users.id = sessions.user_id
and sessions.token = $1",
Expand Down
1 change: 1 addition & 0 deletions crates/atuin-server-postgres/src/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ impl<'a> FromRow<'a, PgRow> for DbUser {
username: row.try_get("username")?,
email: row.try_get("email")?,
password: row.try_get("password")?,
verified: row.try_get("verified_at")?,
}))
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/atuin-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ argon2 = "0.5"
semver = { workspace = true }
metrics-exporter-prometheus = "0.12.1"
metrics = "0.21.1"
postmark = {version= "0.10.0", features=["reqwest"]}
Loading