diff --git a/src/main.rs b/src/main.rs index 1f1f100..afcf5fb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,22 @@ use argon2::{ - password_hash::{ - rand_core::OsRng, - PasswordHash, PasswordHasher, PasswordVerifier, SaltString, - }, + password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Argon2, }; -use bytes::Buf; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use chrono::{DateTime, Days, Utc}; #[cfg(target_os = "linux")] use daemonize::Daemonize; -use futures; -use http_body_util::{BodyExt, Empty, Full}; -use hyper::body::{Body, Frame, Incoming, SizeHint}; -use hyper::header::{LOCATION, SET_COOKIE, COOKIE}; -use hyper::server::conn::http1; -use hyper::service::service_fn; -use hyper::{Error, Method, Request, Response, StatusCode}; +use http_body_util::{BodyExt, Full}; +use hyper::{ + body::{Body as HyperBody, Incoming, Frame}, + header::{COOKIE, SET_COOKIE}, + server::conn::http1, + service::service_fn, + Error, Method, Request, Response, StatusCode, +}; use hyper_util::rt::{TokioIo, TokioTimer}; use rand::distributions::{Alphanumeric, DistString}; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::{from_reader, Value}; use sqlx::sqlite::SqlitePool; use std::collections::HashMap; @@ -33,27 +30,32 @@ use std::str::FromStr; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use std::time::SystemTime; -use hyper::http::HeaderValue; -use serde::de::DeserializeOwned; use tokio::net::TcpListener; -// Some functions could return an empty body, will try using the following enum: -enum ResponseBody { +enum Body { Full(Full), Empty, } -impl Body for ResponseBody { +impl Body { + fn new(data: T) -> Self + where + Bytes: From, + { + Body::Full(Full::new(Bytes::from(data))) + } +} + +impl HyperBody for Body { type Data = Bytes; type Error = Infallible; - fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll, Self::Error>>> { + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { match &mut *self.get_mut() { - Self::Full(incoming) => { - Pin::new(incoming).poll_frame(cx) - }, - Self::Empty => { - Poll::Ready(None) - } + Self::Full(incoming) => Pin::new(incoming).poll_frame(cx), + Self::Empty => Poll::Ready(None), } } } @@ -93,15 +95,21 @@ struct Settings { bind_address: String, } -async fn service(req: Request, db: Arc>) -> Result, Error> { +async fn service( + req: Request, + db: Arc>, +) -> Result, Error> { match req.method() { - &Method::GET => { get(req, db).await } - &Method::POST => { post(req, db).await } - _ => { Ok(Response::builder().status(StatusCode::IM_A_TEAPOT).body(ResponseBody::Empty).unwrap()) } + &Method::GET => get(req, db).await, + &Method::POST => post(req, db).await, + _ => Ok(Response::builder() + .status(StatusCode::IM_A_TEAPOT) + .body(Body::Empty) + .unwrap()), } } -async fn get(req: Request, db: Arc>) -> Result, Error> { +async fn get(req: Request, db: Arc>) -> Result, Error> { let path = req.uri().path(); if path.starts_with("/static") { get_file(path).await @@ -114,25 +122,38 @@ async fn get(req: Request, db: Arc>) -> Result, path: &str, db: Arc>) -> Result, Error> { - let mut routes = env::current_dir().expect("Could not get app directory (Required to get routes)"); +async fn get_page( + req: &Request, + path: &str, + db: Arc>, +) -> Result, Error> { + let mut routes = + env::current_dir().expect("Could not get app directory (Required to get routes)"); routes.push("routes.json"); let file = File::open(routes).expect("Could not open routes file."); let map: Value = from_reader(file).expect("Could not parse routes, please verify syntax."); match map.get(path) { Some(Value::Object(s)) => { - let perm = is_authorised(req, db).await; - if s.get("permission").unwrap().as_i64().unwrap() <= perm { + let authorised = is_authorised(req, db, s.get("permission").unwrap().as_u64().unwrap() as u8).await; + if authorised { get_file(s.get("file").unwrap().as_str().unwrap()).await } else { - get_file(map.get("/unauthorised").unwrap().get("file").unwrap().as_str().unwrap()).await + get_file( + map.get("/unauthorised") + .unwrap() + .get("file") + .unwrap() + .as_str() + .unwrap(), + ) + .await } - }, - _ => not_found().await + } + _ => not_found().await, } } -async fn get_file(mut path: &str) -> Result, Error> { +async fn get_file(mut path: &str) -> Result, Error> { let mut file_path = env::current_dir().expect("Could not get app directory."); if path.starts_with(r"/") { path = path.strip_prefix(r"/").unwrap(); @@ -147,29 +168,52 @@ async fn get_file(mut path: &str) -> Result, Error> { "js" => "text/javascript", "html" => "text/html", "css" => "text/css", - _ => "" + _ => "", }; - Ok(Response::builder().header("content-type", content_type).body(ResponseBody::Full(Full::new(Bytes::from(buf)))).unwrap()) + Ok(Response::builder() + .header("content-type", content_type) + .body(Body::new(buf)) + .unwrap()) } - Err(_) => not_found().await + Err(_) => not_found().await, } } -async fn get_data(path: &str, req: &Request, db: Arc>) -> Result, Error> { +async fn get_data( + path: &str, + req: &Request, + db: Arc>, +) -> Result, Error> { let pool = db.clone().lock().unwrap().clone(); match path { "/data/players" => { - let items = sqlx::query!(r#"SELECT * FROM players"#).fetch_all(&pool).await.unwrap(); - let players: Vec = items.iter().map(|x| Player { id: x.id, name: x.name.clone() }).collect(); - Ok(Response::new(ResponseBody::Full(Full::new(Bytes::from(serde_json::to_string(&players).unwrap()))))) + let items = sqlx::query!(r#"SELECT * FROM players"#) + .fetch_all(&pool) + .await + .unwrap(); + let players: Vec = items + .iter() + .map(|x| Player { + id: x.id, + name: x.name.clone(), + }) + .collect(); + Ok(Response::new(Body::new( + serde_json::to_string(&players).unwrap(), + ))) } "/data/votes" => { let votes = get_votes(req, db).await; - Ok(Response::new(ResponseBody::Full(Full::new(Bytes::from(serde_json::to_string(&votes).unwrap()))))) + Ok(Response::new(Body::new( + serde_json::to_string(&votes).unwrap(), + ))) } "/data/results" => { let votes = get_votes(req, db).await; - let ids: Vec<(i64, i64)> = votes.iter().map(|x| (x.plus_player_id, x.minus_player_id)).collect(); + let ids: Vec<(i64, i64)> = votes + .iter() + .map(|x| (x.plus_player_id, x.minus_player_id)) + .collect(); let mut plus_results: HashMap = HashMap::new(); let mut minus_results: HashMap = HashMap::new(); @@ -191,14 +235,16 @@ async fn get_data(path: &str, req: &Request, db: Arc let mut plus_results: Vec<(i64, i64)> = plus_results.into_iter().collect(); let mut minus_results: Vec<(i64, i64)> = minus_results.into_iter().collect(); - plus_results.sort_by(|a, b| { b.1.cmp(&a.1) }); - minus_results.sort_by(|a, b| { b.1.cmp(&a.1) }); + plus_results.sort_by(|a, b| b.1.cmp(&a.1)); + minus_results.sort_by(|a, b| b.1.cmp(&a.1)); let sorted_results = vec![plus_results, minus_results]; - Ok(Response::new(ResponseBody::Full(Full::new(Bytes::from(serde_json::to_string(&sorted_results).unwrap()))))) + Ok(Response::new(Body::new( + serde_json::to_string(&sorted_results).unwrap(), + ))) } - _ => not_found().await + _ => not_found().await, } } @@ -215,26 +261,39 @@ async fn get_votes(req: &Request, db: Arc>) -> Vec Some(DateTime::from(SystemTime::now())) + None => Some(DateTime::from(SystemTime::now())), }; if date.is_none() { return Vec::new(); } let formatted_date = format!("{}", date.unwrap().format("%d/%m/%Y")); - let items = sqlx::query!(r#"SELECT * FROM votes WHERE submit_date = ?1 ORDER BY id"#, formatted_date).fetch_all(&pool).await.unwrap(); - items.iter().map(|x| Vote { - plus_player_id: x.plus_player_id, - plus_nickname: x.plus_nickname.clone(), - plus_reason: x.plus_reason.clone(), - minus_player_id: x.minus_player_id, - minus_nickname: x.minus_nickname.clone(), - minus_reason: x.minus_reason.clone(), - }).collect() + let items = sqlx::query!( + r#"SELECT * FROM votes WHERE submit_date = ?1 ORDER BY id"#, + formatted_date + ) + .fetch_all(&pool) + .await + .unwrap(); + items + .iter() + .map(|x| Vote { + plus_player_id: x.plus_player_id, + plus_nickname: x.plus_nickname.clone(), + plus_reason: x.plus_reason.clone(), + minus_player_id: x.minus_player_id, + minus_nickname: x.minus_nickname.clone(), + minus_reason: x.minus_reason.clone(), + }) + .collect() } -async fn get_admin(req: &Request, path: &str, db: Arc>) -> Result, Error> { - let perm = is_authorised(req, db.clone()).await; - if perm < 3 { +async fn get_admin( + req: &Request, + path: &str, + db: Arc>, +) -> Result, Error> { + let authorised = is_authorised(req, db.clone(), 3).await; + if authorised { return not_found().await; } if path == "/admin" { @@ -242,43 +301,45 @@ async fn get_admin(req: &Request, path: &str, db: Arc = users.iter().map(|x| (x.username.clone(), x.permissions)).collect(); + let users = sqlx::query!(r#"SELECT username, permissions FROM users"#) + .fetch_all(&pool) + .await + .unwrap(); + let users: Vec<(String, i64)> = users + .iter() + .map(|x| (x.username.clone(), x.permissions)) + .collect(); let stringed = serde_json::to_string(&users).unwrap_or("".to_string()); - return Ok(Response::builder().body(ResponseBody::Full(Full::new(Bytes::from(stringed)))).unwrap()); + return Ok(Response::builder().body(Body::new(stringed)).unwrap()); } not_found().await } -async fn post(req: Request, db: Arc>) -> Result, Error> { +async fn post(req: Request, db: Arc>) -> Result, Error> { let path = req.uri().path(); if path.starts_with("/admin") { return post_admin(req, db).await; } match path { - "/vote" => { - post_vote(req, db).await - } - "/login" => { - login(req, db).await - } - "/register" => { - register(req, db).await - } - "/logout" => { - logout().await - } - _ => { - not_found().await - } + "/vote" => post_vote(req, db).await, + "/login" => login(req, db).await, + "/register" => register(req, db).await, + "/logout" => logout().await, + _ => not_found().await, } } -async fn post_vote(req: Request, db: Arc>) -> Result, Error> { +async fn post_vote( + req: Request, + db: Arc>, +) -> Result, Error> { let body = req.into_body().collect().await?; let data: Result = from_reader(body.aggregate().reader()); if data.is_err() { - return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()); } let vote = data.unwrap(); let timestamp: DateTime = DateTime::from(SystemTime::now()); @@ -295,40 +356,57 @@ async fn post_vote(req: Request, db: Arc>) -> Result vote.minus_reason, formatted).execute(&mut *conn).await; if result.is_err() { - return Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(ResponseBody::Empty).unwrap()); + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::Empty) + .unwrap()); } - Ok(Response::builder().body(ResponseBody::Full(Full::new(Bytes::new()))).unwrap()) + Ok(Response::builder().body(Body::Empty).unwrap()) } -async fn post_admin(req: Request, db: Arc>) -> Result, Error> { - let perm = is_authorised(&req, db.clone()).await; - if perm < 3 { +async fn post_admin( + req: Request, + db: Arc>, +) -> Result, Error> { + let authorised = is_authorised(&req, db.clone(), 3).await; + if authorised { return get_page(&req, "/unauthorised", db).await; } let path = req.uri().path(); match path { "/admin/post/user" => { req_json::(req).await; - }, - "/admin/post/vote" => {}, - "/admin/post/player" => {}, + } + "/admin/post/vote" => {} + "/admin/post/player" => {} _ => {} } not_found().await } -async fn login(req: Request, db: Arc>) -> Result, Error> { +async fn login( + req: Request, + db: Arc>, +) -> Result, Error> { let body = req.into_body().collect().await; let data: Result = from_reader(body?.aggregate().reader()); if data.is_err() { - return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()); } let data = data.unwrap(); if !check_username(&data.username) { - return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Full(Full::new(Bytes::from("Bad Request")))).unwrap()); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()); } let pool = db.clone().lock().unwrap().clone(); - let result = sqlx::query!(r#"SELECT * FROM users WHERE username=?1"#, data.username).fetch_optional(&pool).await; + let result = sqlx::query!(r#"SELECT * FROM users WHERE username=?1"#, data.username) + .fetch_optional(&pool) + .await; match result { Ok(Some(user)) => { let argon = Argon2::default(); @@ -339,81 +417,139 @@ async fn login(req: Request, db: Arc>) -> Result { - Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()) - } + Err(_) => Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()), } } - Ok(None) => { - Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()) - } - Err(_) => { Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(ResponseBody::Empty).unwrap()) } + Ok(None) => Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()), + Err(_) => Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::Empty) + .unwrap()), } } -async fn register(req: Request, db: Arc>) -> Result, Error> { +async fn register( + req: Request, + db: Arc>, +) -> Result, Error> { match req_json::(req).await { - None => Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()), + None => Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()), Some(login) => { if !check_username(&login.username) { - return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()); } if !check_password(&login.password) { - return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()); } let pool = db.clone().lock().unwrap().clone(); let mut conn = pool.acquire().await.unwrap(); - let exists = sqlx::query!(r#"SELECT id FROM users WHERE username=?1"#, login.username).fetch_optional(&mut *conn).await; + let exists = sqlx::query!(r#"SELECT id FROM users WHERE username=?1"#, login.username) + .fetch_optional(&mut *conn) + .await; if exists.unwrap().is_some() { - return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::Empty) + .unwrap()); } let argon2 = Argon2::default(); - let hash = argon2.hash_password(login.password.as_bytes(), &SaltString::generate(&mut OsRng)).unwrap().to_string(); + let hash = argon2 + .hash_password(login.password.as_bytes(), &SaltString::generate(&mut OsRng)) + .unwrap() + .to_string(); let token = Alphanumeric.sample_string(&mut OsRng, 256); let result = sqlx::query!(r#"INSERT INTO users ( username, saltyhash, permissions, token) VALUES ( ?1, ?2, ?3, ?4 )"#, login.username, hash, 0, token).execute(&mut *conn).await; match result { - Ok(_) => Ok(Response::builder().body(ResponseBody::Full(Full::new(Bytes::new()))).unwrap()), - Err(_) => { Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(ResponseBody::Empty).unwrap()) } + Ok(_) => Ok(Response::builder().body(Body::Empty).unwrap()), + Err(_) => Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::Empty) + .unwrap()), } - }, + } } } -async fn logout() -> Result, Error> { +async fn logout() -> Result, Error> { let date: DateTime = DateTime::from(SystemTime::now()); Ok(Response::builder() - //.status(StatusCode::SEE_OTHER) - //.header(LOCATION, "/") - .header(SET_COOKIE, format!("token=''; Expires={}; Secure; HttpOnly; SameSite=Strict", date.to_rfc2822())) - .header(SET_COOKIE, format!("logged=false; Expires={}; Secure; HttpOnly; SameSite=Strict", date.to_rfc2822())) - .body(ResponseBody::Empty).unwrap()) + .header( + SET_COOKIE, + format!( + "token=''; Expires={}; Secure; HttpOnly; SameSite=Strict", + date.to_rfc2822() + ), + ) + .header( + SET_COOKIE, + format!( + "logged=false; Expires={}; Secure; HttpOnly; SameSite=Strict", + date.to_rfc2822() + ), + ) + .body(Body::Empty) + .unwrap()) } -async fn is_authorised(req: &Request, db: Arc>) -> i64 { +async fn is_authorised(req: &Request, db: Arc>, level: u8) -> bool { let cookies = req.headers().get(COOKIE); let token = match cookies { - Some(cookies) => { - cookies - .to_str() - .unwrap_or("") - .split("; ") - .find(|x| x.starts_with("token=")) - .unwrap_or("") - .strip_prefix("token=") - .unwrap_or("")}, - None => "" + Some(cookies) => cookies + .to_str() + .unwrap_or("") + .split("; ") + .find(|x| x.starts_with("token=")) + .unwrap_or("") + .strip_prefix("token=") + .unwrap_or(""), + None => "", }; let pool = db.clone().lock().unwrap().clone(); - let user = sqlx::query!(r#"SELECT permissions FROM users WHERE token=?1"#, token).fetch_optional(&pool).await; + let user = sqlx::query!(r#"SELECT permissions FROM users WHERE token=?1"#, token) + .fetch_optional(&pool) + .await; match user { - Ok(Some(user)) => user.permissions, - _ => 0 + Ok(Some(user)) => { + let perm = user.permissions as u8; + perm >= level + }, + _ => match level { + 0 => true, + _ => false + }, } } @@ -451,18 +587,21 @@ fn check_password(password: &String) -> bool { up && num && sym } -async fn not_found() -> Result, Error> { +async fn not_found() -> Result, Error> { let mut file_path = env::current_dir().expect("Could not get app directory."); file_path.push("static/html/404.html"); let mut file = File::open(file_path).unwrap(); let mut buf = Vec::new(); file.read_to_end(&mut buf).unwrap(); - Ok(Response::builder().status(StatusCode::NOT_FOUND).body(ResponseBody::Full(Full::new(Bytes::from(buf)))).unwrap()) + Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::new(buf)) + .unwrap()) } async fn req_json(req: Request) -> Option where - T: DeserializeOwned + T: DeserializeOwned, { let body = req.into_body().collect().await.unwrap(); match from_reader(body.aggregate().reader()) { @@ -472,10 +611,13 @@ where } fn get_settings() -> Settings { - let mut settings_path = env::current_dir().expect("Could not get app directory. (Required to read settings)"); + let mut settings_path = + env::current_dir().expect("Could not get app directory. (Required to read settings)"); settings_path.push("settings.json"); - let settings_file = File::open(settings_path).expect("Could not open settings file, does it exists?"); - let settings: Settings = from_reader(settings_file).expect("Could not parse settings, please check syntax."); + let settings_file = + File::open(settings_path).expect("Could not open settings file, does it exists?"); + let settings: Settings = + from_reader(settings_file).expect("Could not parse settings, please check syntax."); settings } @@ -504,26 +646,32 @@ fn main() { #[tokio::main] async fn run() { let settings = get_settings(); - let db_pool = Arc::new(Mutex::new(SqlitePool::connect(&settings.database_url).await.expect("Could not connect to database. Make sure the url is correct."))); - let bind_address: SocketAddr = SocketAddr::from_str(&settings.bind_address).expect("Could not parse bind address."); - let listener = TcpListener::bind(bind_address).await.expect("Could not bind to address."); + let db_pool = Arc::new(Mutex::new( + SqlitePool::connect(&settings.database_url) + .await + .expect("Could not connect to database. Make sure the url is correct."), + )); + let bind_address: SocketAddr = + SocketAddr::from_str(&settings.bind_address).expect("Could not parse bind address."); + let listener = TcpListener::bind(bind_address) + .await + .expect("Could not bind to address."); loop { - let (stream, _) = listener.accept().await.expect("Could not accept incoming stream."); + let (stream, _) = listener + .accept() + .await + .expect("Could not accept incoming stream."); let io = TokioIo::new(stream); let db = db_pool.clone(); - let service = service_fn( - move |req| { - service(req, db.clone()) + let service = service_fn(move |req| service(req, db.clone())); + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new() + .timer(TokioTimer::new()) + .serve_connection(io, service) + .await + { + println!("Failed to serve connection: {:?}", err); } - ); - tokio::task::spawn( - async move { - if let Err(err) = http1::Builder::new() - .timer(TokioTimer::new()) - .serve_connection(io, service).await { - println!("Failed to serve connection: {:?}", err); - } - } - ); + }); } }