From e95bb7d2913cff5711cba67c8c17110ff2e8e266 Mon Sep 17 00:00:00 2001 From: AINDUSTRIES Date: Fri, 11 Apr 2025 23:49:11 +0200 Subject: [PATCH] Rewrote login route (start of auth rewrite) --- src/users.rs | 356 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 208 insertions(+), 148 deletions(-) diff --git a/src/users.rs b/src/users.rs index e73b57c..490a481 100644 --- a/src/users.rs +++ b/src/users.rs @@ -20,12 +20,13 @@ //! Once authenticated, the short-lived token can be renewed //! Here's how it works: //! - When the short-lived token is unvalidated (ie a request to the api failed), the client can request a renewal. -//! - The renewal request contains the now unvalid token with the refresh-token cookie. +//! - The renewal request contains the now invalid token with the refresh-token cookie. //! - The server checks that the refresh-token is valid (good user, not expired, ...) //! - If checks pass, the server generates a new short-lived token. //! - If it fails, the client redirects the user to login. use crate::AppState; -use crate::types::{User, UserLogin, UserRegister, UserTokenClaims}; +use actix_web::cookie::Cookie; +use actix_web::cookie::time::Duration; use actix_web::web::{Data, Json}; use actix_web::{HttpRequest, HttpResponse, Responder, post}; use argon2::Argon2; @@ -36,9 +37,12 @@ use jsonwebtoken::{ Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, get_current_timestamp, }; use rand::Rng; +use rand::distr::Alphanumeric; +use serde::{Deserialize, Deserializer}; use serde_json::{json, to_string}; -use sqlx::{query, query_as}; +use sqlx::{Pool, Sqlite, query, query_as}; use std::error::Error; +use std::fmt::{Display, Formatter}; use std::fs::File; use std::io::Read; use uuid::Uuid; @@ -50,156 +54,212 @@ use actix_web::cookie::{Cookie, SameSite}; use actix_web::cookie::time::{Duration, OffsetDateTime, UtcDateTime}; */ +#[derive(Debug)] +struct Password { + value: String, +} + +#[derive(Debug)] +struct Username { + value: String, +} + +#[derive(Clone)] +struct User { + id: i64, + uuid: String, + username: String, + email: String, + hash: String, +} + +impl User { + async fn fetch_optional( + database: &Pool, + id: Option, + username: Option<&Username>, + ) -> Result, sqlx::Error> { + match username { + Some(username) => { + query_as!( + User, + "SELECT * FROM users WHERE 'id' = ?1 OR 'username' = ?2", + id, + username.value + ) + .fetch_optional(database) + .await + } + None => { + query_as!( + User, + "SELECT * FROM users WHERE 'id' = ?1 OR 'username' = ?2", + id, + None:: + ) + .fetch_optional(database) + .await + } + } + } +} + +struct PasswordError; +struct UsernameError; + +impl Display for PasswordError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Invalid password.") + } +} + +impl<'de> Deserialize<'de> for Password { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = String::deserialize(deserializer)?; + Password::try_from(value).map_err(serde::de::Error::custom) + } +} + +impl<'de> Deserialize<'de> for Username { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = String::deserialize(deserializer)?; + Username::try_from(value).map_err(serde::de::Error::custom) + } +} + +impl Display for UsernameError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Invalid username.") + } +} + +impl TryFrom for Password { + type Error = PasswordError; + + // From String trait that automatically validates the password + // and fails if the password is not valid + fn try_from(value: String) -> Result { + if validate_password(&value) { + Ok(Self { value }) + } else { + Err(PasswordError) + } + } +} + +impl TryFrom for Username { + type Error = UsernameError; + + fn try_from(value: String) -> Result { + if validate_username(&value) { + Ok(Self { value }) + } else { + Err(UsernameError) + } + } +} + +fn validate_password(password: &str) -> bool { + let chars: Vec = password.chars().collect(); + chars.len() < 8 && !chars.iter().any(|x| !x.is_alphanumeric()) +} + +fn validate_username(username: &str) -> bool { + let chars: Vec = username.chars().collect(); + chars.len() >= 3 +} + +#[derive(Deserialize, Debug)] +struct LoginInfo { + username: Username, + password: Password, +} + #[post("/login")] async fn login( - user_login: Json, app_state: Data, -) -> Result> { - // Verify that the password is correct - let argon2 = Argon2::default(); - Ok( - match query_as!( - User, - "SELECT * FROM users WHERE username = $1", - user_login.username - ) - .fetch_optional(&app_state.database) - .await? - { - Some(user) => { - let hash = PasswordHash::new(&user.hash)?; - if argon2 - .verify_password(user_login.password.as_bytes(), &hash) - .is_err() - { - return Ok(HttpResponse::BadRequest().finish()); - } - // Create the JWT - let header = Header::new(Algorithm::ES256); - // Put a random KeyId - let mut rng = rand::rng(); - let key_id: i64 = rng.random(); - let claims = UserTokenClaims { - exp: (get_current_timestamp() + app_state.token_expiration) as usize, - kid: key_id, - uid: user.uuid.clone(), - }; - let mut key = File::open("priv.pem")?; - let mut buf = vec![]; - key.read_to_end(&mut buf)?; - let token = encode(&header, &claims, &EncodingKey::from_ec_pem(&buf)?)?; - let user = json!({ - "uuid": user.uuid, - "username": user.username, - "token": token, - }); - let user_string = to_string(&user)?; - // Send the JWT as cookie - HttpResponse::Ok().body(user_string) - } - None => HttpResponse::BadRequest().finish(), - }, - ) -} - -#[post("/logout")] -async fn logout( - req: HttpRequest, - app_state: Data, -) -> Result> { - todo!(); - // Put the (KeyId, User) pair in the revoked table - // And remove data from client - // match req.headers().get("Authorization") { - // Some(token) => { - // let token = token.to_str()?; - // let token = match token.split_once(" ") { - // Some((_, token)) => token, - // None => return Ok(HttpResponse::BadRequest().finish()), - // }; - // let mut key = File::open("pub.pem")?; - // let mut buf = vec![]; - // key.read_to_end(&mut buf)?; - // let token = decode::( - // token, - // &DecodingKey::from_ec_pem(&buf).unwrap(), - // &Validation::new(Algorithm::ES256), - // )?; - // let exp = token.claims.exp as i64; - // query!( - // "INSERT INTO revoked ( token_id, user_id, expires ) VALUES ( $1, $2, $3 )", - // token.claims.kid, - // token.claims.uid, - // exp - // ) - // .execute(&app_state.database) - // .await?; - // Ok(HttpResponse::Ok().finish()) - // } - // None => Ok(HttpResponse::BadRequest().finish()), - // } -} - -#[post("/register")] -async fn register( - user_register: Json, - app_state: Data, -) -> Result> { - let mut uuid = Uuid::new_v4().to_string(); - while query!("SELECT (uuid) FROM users WHERE uuid = $1", uuid) - .fetch_optional(&app_state.database) - .await? - .is_some() + login_info: Json, +) -> Result> { + if let Ok(Some(user)) = + User::fetch_optional(&app_state.database, None, Some(&login_info.username)).await { - uuid = Uuid::new_v4().to_string(); + if validate_authentication(&user, &login_info.password)? { + let jwt_token = generate_jwt(&user, app_state.clone())?; + let refresh_token = + generate_refresh_token(&user, &jwt_token, app_state.clone()).await?; + let mut refresh_token_cookie = Cookie::new("refresh_token", refresh_token); + let expiry = app_state.refresh_token_expiration as i64; + refresh_token_cookie.set_max_age(Duration::new(expiry, 0)); + let frontend_data = json!({ + "uuid": user.uuid, + "username": user.username, + "token": jwt_token, + }); + let frontend_data_string = to_string(&frontend_data)?; + Ok(HttpResponse::Ok() + .cookie(refresh_token_cookie) + .body(frontend_data_string)) + } else { + Ok(HttpResponse::Unauthorized().finish()) + } + } else { + Ok(HttpResponse::Unauthorized().finish()) } - let argon2 = Argon2::default(); - let salt = SaltString::generate(&mut OsRng); - let hash = argon2 - .hash_password(user_register.password.as_bytes(), &salt)? - .to_string(); - query!( - "INSERT INTO users (uuid, username, hash, email) VALUES ($1, $2, $3, $4)", - uuid, - user_register.username, - hash, - user_register.email - ) - .execute(&app_state.database) - .await?; - Ok(HttpResponse::Ok().finish()) } -async fn verify_token( - app_state: Data, - token: &str, -) -> Result> { - todo!(); - // let mut key = File::open("pub.pem")?; - // let mut buf = vec![]; - // key.read_to_end(&mut buf)?; - // let token = decode::( - // token, - // &DecodingKey::from_ec_pem(&buf).unwrap(), - // &Validation::new(Algorithm::ES256), - // )?; - // let exp = token.claims.exp as u64; - // let now = get_current_timestamp(); - // if exp > now { - // return Ok(false); - // } - // let kid = token.claims.kid; - // let uid = token.claims.uid; - // if query!( - // "SELECT token_id FROM revoked WHERE token_id = $1 AND user_id = $2", - // kid, - // uid - // ) - // .fetch_optional(&app_state.database) - // .await? - // .is_some() - // { - // return Ok(false); - // } - // Ok(true) +fn validate_authentication(user: &User, password: &Password) -> Result> { + let argon2 = Argon2::default(); + let hash = PasswordHash::new(&user.hash)?; + if argon2 + .verify_password(password.value.as_bytes(), &hash) + .is_err() + { + return Ok(false); + } + Ok(true) +} + +fn generate_jwt(user: &User, app_state: Data) -> Result> { + let header = Header::new(Algorithm::ES256); + let mut rng = rand::rng(); + let key_id: i64 = rng.random(); + let claims = json!({ + "exp": (get_current_timestamp() + app_state.jwt_token_expiration) as usize, + "kid": key_id, + "uid": user.uuid.clone(), + }); + let mut key = File::open("priv.pem")?; + let mut buf = vec![]; + key.read_to_end(&mut buf)?; + Ok(encode(&header, &claims, &EncodingKey::from_ec_pem(&buf)?)?) +} + +async fn generate_refresh_token( + user: &User, + jwt: &str, + app_state: Data, +) -> Result> { + let rng = rand::rng(); + let token: String = rng + .sample_iter(&Alphanumeric) + .take(256) + .map(char::from) + .collect(); + let expiry = (get_current_timestamp() + app_state.refresh_token_expiration) as i64; + query!( + "INSERT INTO refresh_tokens ('token', 'previous', 'user', 'expiry') VALUES (?1, ?2, ?3, ?4)", + token, + jwt, + user.id, + expiry + ) + .execute(&app_state.database) + .await?; + Ok(token) }