Custom Response Body, is authorised returning bool, ...

This commit is contained in:
2024-10-02 19:36:50 +02:00
parent 72fd45eb64
commit 894e102322

View File

@@ -1,25 +1,22 @@
use argon2::{ use argon2::{
password_hash::{ password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
rand_core::OsRng,
PasswordHash, PasswordHasher, PasswordVerifier, SaltString,
},
Argon2, Argon2,
}; };
use bytes::Buf; use bytes::{Buf, Bytes};
use bytes::Bytes;
use chrono::{DateTime, Days, Utc}; use chrono::{DateTime, Days, Utc};
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
use daemonize::Daemonize; use daemonize::Daemonize;
use futures; use http_body_util::{BodyExt, Full};
use http_body_util::{BodyExt, Empty, Full}; use hyper::{
use hyper::body::{Body, Frame, Incoming, SizeHint}; body::{Body as HyperBody, Incoming, Frame},
use hyper::header::{LOCATION, SET_COOKIE, COOKIE}; header::{COOKIE, SET_COOKIE},
use hyper::server::conn::http1; server::conn::http1,
use hyper::service::service_fn; service::service_fn,
use hyper::{Error, Method, Request, Response, StatusCode}; Error, Method, Request, Response, StatusCode,
};
use hyper_util::rt::{TokioIo, TokioTimer}; use hyper_util::rt::{TokioIo, TokioTimer};
use rand::distributions::{Alphanumeric, DistString}; use rand::distributions::{Alphanumeric, DistString};
use serde::{Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{from_reader, Value}; use serde_json::{from_reader, Value};
use sqlx::sqlite::SqlitePool; use sqlx::sqlite::SqlitePool;
use std::collections::HashMap; use std::collections::HashMap;
@@ -33,27 +30,32 @@ use std::str::FromStr;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::SystemTime; use std::time::SystemTime;
use hyper::http::HeaderValue;
use serde::de::DeserializeOwned;
use tokio::net::TcpListener; use tokio::net::TcpListener;
// Some functions could return an empty body, will try using the following enum: enum Body {
enum ResponseBody {
Full(Full<Bytes>), Full(Full<Bytes>),
Empty, Empty,
} }
impl Body for ResponseBody { impl Body {
fn new<T>(data: T) -> Self
where
Bytes: From<T>,
{
Body::Full(Full::new(Bytes::from(data)))
}
}
impl HyperBody for Body {
type Data = Bytes; type Data = Bytes;
type Error = Infallible; type Error = Infallible;
fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
match &mut *self.get_mut() { match &mut *self.get_mut() {
Self::Full(incoming) => { Self::Full(incoming) => Pin::new(incoming).poll_frame(cx),
Pin::new(incoming).poll_frame(cx) Self::Empty => Poll::Ready(None),
},
Self::Empty => {
Poll::Ready(None)
}
} }
} }
} }
@@ -93,15 +95,21 @@ struct Settings {
bind_address: String, bind_address: String,
} }
async fn service(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn service(
req: Request<Incoming>,
db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
match req.method() { match req.method() {
&Method::GET => { get(req, db).await } &Method::GET => get(req, db).await,
&Method::POST => { post(req, db).await } &Method::POST => post(req, db).await,
_ => { Ok(Response::builder().status(StatusCode::IM_A_TEAPOT).body(ResponseBody::Empty).unwrap()) } _ => Ok(Response::builder()
.status(StatusCode::IM_A_TEAPOT)
.body(Body::Empty)
.unwrap()),
} }
} }
async fn get(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn get(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<Body>, Error> {
let path = req.uri().path(); let path = req.uri().path();
if path.starts_with("/static") { if path.starts_with("/static") {
get_file(path).await get_file(path).await
@@ -114,25 +122,38 @@ async fn get(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Respo
} }
} }
async fn get_page(req: &Request<Incoming>, path: &str, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn get_page(
let mut routes = env::current_dir().expect("Could not get app directory (Required to get routes)"); req: &Request<Incoming>,
path: &str,
db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
let mut routes =
env::current_dir().expect("Could not get app directory (Required to get routes)");
routes.push("routes.json"); routes.push("routes.json");
let file = File::open(routes).expect("Could not open routes file."); let file = File::open(routes).expect("Could not open routes file.");
let map: Value = from_reader(file).expect("Could not parse routes, please verify syntax."); let map: Value = from_reader(file).expect("Could not parse routes, please verify syntax.");
match map.get(path) { match map.get(path) {
Some(Value::Object(s)) => { Some(Value::Object(s)) => {
let perm = is_authorised(req, db).await; let authorised = is_authorised(req, db, s.get("permission").unwrap().as_u64().unwrap() as u8).await;
if s.get("permission").unwrap().as_i64().unwrap() <= perm { if authorised {
get_file(s.get("file").unwrap().as_str().unwrap()).await get_file(s.get("file").unwrap().as_str().unwrap()).await
} else { } else {
get_file(map.get("/unauthorised").unwrap().get("file").unwrap().as_str().unwrap()).await get_file(
map.get("/unauthorised")
.unwrap()
.get("file")
.unwrap()
.as_str()
.unwrap(),
)
.await
} }
}, }
_ => not_found().await _ => not_found().await,
} }
} }
async fn get_file(mut path: &str) -> Result<Response<ResponseBody>, Error> { async fn get_file(mut path: &str) -> Result<Response<Body>, Error> {
let mut file_path = env::current_dir().expect("Could not get app directory."); let mut file_path = env::current_dir().expect("Could not get app directory.");
if path.starts_with(r"/") { if path.starts_with(r"/") {
path = path.strip_prefix(r"/").unwrap(); path = path.strip_prefix(r"/").unwrap();
@@ -147,29 +168,52 @@ async fn get_file(mut path: &str) -> Result<Response<ResponseBody>, Error> {
"js" => "text/javascript", "js" => "text/javascript",
"html" => "text/html", "html" => "text/html",
"css" => "text/css", "css" => "text/css",
_ => "" _ => "",
}; };
Ok(Response::builder().header("content-type", content_type).body(ResponseBody::Full(Full::new(Bytes::from(buf)))).unwrap()) Ok(Response::builder()
.header("content-type", content_type)
.body(Body::new(buf))
.unwrap())
} }
Err(_) => not_found().await Err(_) => not_found().await,
} }
} }
async fn get_data(path: &str, req: &Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn get_data(
path: &str,
req: &Request<Incoming>,
db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
let pool = db.clone().lock().unwrap().clone(); let pool = db.clone().lock().unwrap().clone();
match path { match path {
"/data/players" => { "/data/players" => {
let items = sqlx::query!(r#"SELECT * FROM players"#).fetch_all(&pool).await.unwrap(); let items = sqlx::query!(r#"SELECT * FROM players"#)
let players: Vec<Player> = items.iter().map(|x| Player { id: x.id, name: x.name.clone() }).collect(); .fetch_all(&pool)
Ok(Response::new(ResponseBody::Full(Full::new(Bytes::from(serde_json::to_string(&players).unwrap()))))) .await
.unwrap();
let players: Vec<Player> = items
.iter()
.map(|x| Player {
id: x.id,
name: x.name.clone(),
})
.collect();
Ok(Response::new(Body::new(
serde_json::to_string(&players).unwrap(),
)))
} }
"/data/votes" => { "/data/votes" => {
let votes = get_votes(req, db).await; let votes = get_votes(req, db).await;
Ok(Response::new(ResponseBody::Full(Full::new(Bytes::from(serde_json::to_string(&votes).unwrap()))))) Ok(Response::new(Body::new(
serde_json::to_string(&votes).unwrap(),
)))
} }
"/data/results" => { "/data/results" => {
let votes = get_votes(req, db).await; let votes = get_votes(req, db).await;
let ids: Vec<(i64, i64)> = votes.iter().map(|x| (x.plus_player_id, x.minus_player_id)).collect(); let ids: Vec<(i64, i64)> = votes
.iter()
.map(|x| (x.plus_player_id, x.minus_player_id))
.collect();
let mut plus_results: HashMap<i64, i64> = HashMap::new(); let mut plus_results: HashMap<i64, i64> = HashMap::new();
let mut minus_results: HashMap<i64, i64> = HashMap::new(); let mut minus_results: HashMap<i64, i64> = HashMap::new();
@@ -191,14 +235,16 @@ async fn get_data(path: &str, req: &Request<Incoming>, db: Arc<Mutex<SqlitePool>
let mut plus_results: Vec<(i64, i64)> = plus_results.into_iter().collect(); let mut plus_results: Vec<(i64, i64)> = plus_results.into_iter().collect();
let mut minus_results: Vec<(i64, i64)> = minus_results.into_iter().collect(); let mut minus_results: Vec<(i64, i64)> = minus_results.into_iter().collect();
plus_results.sort_by(|a, b| { b.1.cmp(&a.1) }); plus_results.sort_by(|a, b| b.1.cmp(&a.1));
minus_results.sort_by(|a, b| { b.1.cmp(&a.1) }); minus_results.sort_by(|a, b| b.1.cmp(&a.1));
let sorted_results = vec![plus_results, minus_results]; let sorted_results = vec![plus_results, minus_results];
Ok(Response::new(ResponseBody::Full(Full::new(Bytes::from(serde_json::to_string(&sorted_results).unwrap()))))) Ok(Response::new(Body::new(
serde_json::to_string(&sorted_results).unwrap(),
)))
} }
_ => not_found().await _ => not_found().await,
} }
} }
@@ -215,26 +261,39 @@ async fn get_votes(req: &Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Vec<V
DateTime::from_timestamp_millis(parsed_date.unwrap()) DateTime::from_timestamp_millis(parsed_date.unwrap())
} }
} }
None => Some(DateTime::from(SystemTime::now())) None => Some(DateTime::from(SystemTime::now())),
}; };
if date.is_none() { if date.is_none() {
return Vec::new(); return Vec::new();
} }
let formatted_date = format!("{}", date.unwrap().format("%d/%m/%Y")); let formatted_date = format!("{}", date.unwrap().format("%d/%m/%Y"));
let items = sqlx::query!(r#"SELECT * FROM votes WHERE submit_date = ?1 ORDER BY id"#, formatted_date).fetch_all(&pool).await.unwrap(); let items = sqlx::query!(
items.iter().map(|x| Vote { r#"SELECT * FROM votes WHERE submit_date = ?1 ORDER BY id"#,
plus_player_id: x.plus_player_id, formatted_date
plus_nickname: x.plus_nickname.clone(), )
plus_reason: x.plus_reason.clone(), .fetch_all(&pool)
minus_player_id: x.minus_player_id, .await
minus_nickname: x.minus_nickname.clone(), .unwrap();
minus_reason: x.minus_reason.clone(), items
}).collect() .iter()
.map(|x| Vote {
plus_player_id: x.plus_player_id,
plus_nickname: x.plus_nickname.clone(),
plus_reason: x.plus_reason.clone(),
minus_player_id: x.minus_player_id,
minus_nickname: x.minus_nickname.clone(),
minus_reason: x.minus_reason.clone(),
})
.collect()
} }
async fn get_admin(req: &Request<Incoming>, path: &str, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn get_admin(
let perm = is_authorised(req, db.clone()).await; req: &Request<Incoming>,
if perm < 3 { path: &str,
db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
let authorised = is_authorised(req, db.clone(), 3).await;
if authorised {
return not_found().await; return not_found().await;
} }
if path == "/admin" { if path == "/admin" {
@@ -242,43 +301,45 @@ async fn get_admin(req: &Request<Incoming>, path: &str, db: Arc<Mutex<SqlitePool
} }
if path == "/admin/users" { if path == "/admin/users" {
let pool = db.clone().lock().unwrap().clone(); let pool = db.clone().lock().unwrap().clone();
let users = sqlx::query!(r#"SELECT username, permissions FROM users"#).fetch_all(&pool).await.unwrap(); let users = sqlx::query!(r#"SELECT username, permissions FROM users"#)
let users: Vec<(String, i64)> = users.iter().map(|x| (x.username.clone(), x.permissions)).collect(); .fetch_all(&pool)
.await
.unwrap();
let users: Vec<(String, i64)> = users
.iter()
.map(|x| (x.username.clone(), x.permissions))
.collect();
let stringed = serde_json::to_string(&users).unwrap_or("".to_string()); let stringed = serde_json::to_string(&users).unwrap_or("".to_string());
return Ok(Response::builder().body(ResponseBody::Full(Full::new(Bytes::from(stringed)))).unwrap()); return Ok(Response::builder().body(Body::new(stringed)).unwrap());
} }
not_found().await not_found().await
} }
async fn post(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn post(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<Body>, Error> {
let path = req.uri().path(); let path = req.uri().path();
if path.starts_with("/admin") { if path.starts_with("/admin") {
return post_admin(req, db).await; return post_admin(req, db).await;
} }
match path { match path {
"/vote" => { "/vote" => post_vote(req, db).await,
post_vote(req, db).await "/login" => login(req, db).await,
} "/register" => register(req, db).await,
"/login" => { "/logout" => logout().await,
login(req, db).await _ => not_found().await,
}
"/register" => {
register(req, db).await
}
"/logout" => {
logout().await
}
_ => {
not_found().await
}
} }
} }
async fn post_vote(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn post_vote(
req: Request<Incoming>,
db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
let body = req.into_body().collect().await?; let body = req.into_body().collect().await?;
let data: Result<Vote, serde_json::Error> = from_reader(body.aggregate().reader()); let data: Result<Vote, serde_json::Error> = from_reader(body.aggregate().reader());
if data.is_err() { if data.is_err() {
return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::Empty)
.unwrap());
} }
let vote = data.unwrap(); let vote = data.unwrap();
let timestamp: DateTime<Utc> = DateTime::from(SystemTime::now()); let timestamp: DateTime<Utc> = DateTime::from(SystemTime::now());
@@ -295,40 +356,57 @@ async fn post_vote(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result
vote.minus_reason, vote.minus_reason,
formatted).execute(&mut *conn).await; formatted).execute(&mut *conn).await;
if result.is_err() { if result.is_err() {
return Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(ResponseBody::Empty).unwrap()); return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::Empty)
.unwrap());
} }
Ok(Response::builder().body(ResponseBody::Full(Full::new(Bytes::new()))).unwrap()) Ok(Response::builder().body(Body::Empty).unwrap())
} }
async fn post_admin(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn post_admin(
let perm = is_authorised(&req, db.clone()).await; req: Request<Incoming>,
if perm < 3 { db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
let authorised = is_authorised(&req, db.clone(), 3).await;
if authorised {
return get_page(&req, "/unauthorised", db).await; return get_page(&req, "/unauthorised", db).await;
} }
let path = req.uri().path(); let path = req.uri().path();
match path { match path {
"/admin/post/user" => { "/admin/post/user" => {
req_json::<User>(req).await; req_json::<User>(req).await;
}, }
"/admin/post/vote" => {}, "/admin/post/vote" => {}
"/admin/post/player" => {}, "/admin/post/player" => {}
_ => {} _ => {}
} }
not_found().await not_found().await
} }
async fn login(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn login(
req: Request<Incoming>,
db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
let body = req.into_body().collect().await; let body = req.into_body().collect().await;
let data: Result<Login, serde_json::Error> = from_reader(body?.aggregate().reader()); let data: Result<Login, serde_json::Error> = from_reader(body?.aggregate().reader());
if data.is_err() { if data.is_err() {
return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::Empty)
.unwrap());
} }
let data = data.unwrap(); let data = data.unwrap();
if !check_username(&data.username) { if !check_username(&data.username) {
return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Full(Full::new(Bytes::from("Bad Request")))).unwrap()); return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::Empty)
.unwrap());
} }
let pool = db.clone().lock().unwrap().clone(); let pool = db.clone().lock().unwrap().clone();
let result = sqlx::query!(r#"SELECT * FROM users WHERE username=?1"#, data.username).fetch_optional(&pool).await; let result = sqlx::query!(r#"SELECT * FROM users WHERE username=?1"#, data.username)
.fetch_optional(&pool)
.await;
match result { match result {
Ok(Some(user)) => { Ok(Some(user)) => {
let argon = Argon2::default(); let argon = Argon2::default();
@@ -339,81 +417,139 @@ async fn login(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Res
let date = date.checked_add_days(Days::new(7)).unwrap(); let date = date.checked_add_days(Days::new(7)).unwrap();
// With server side rendering, redirect here to "/" // With server side rendering, redirect here to "/"
Ok(Response::builder() Ok(Response::builder()
.header(SET_COOKIE, .header(
format!("token={}; Expires={}; Secure; HttpOnly; SameSite=Strict", user.token, date.to_rfc2822())) SET_COOKIE,
.header(SET_COOKIE, format!("logged=true; Expires={}; Secure; SameSite=Strict", date.to_rfc2822())) format!(
.body(ResponseBody::Full(Full::new(Bytes::from("Ok")))) "token={}; Expires={}; Secure; HttpOnly; SameSite=Strict",
user.token,
date.to_rfc2822()
),
)
.header(
SET_COOKIE,
format!(
"logged=true; Expires={}; Secure; SameSite=Strict",
date.to_rfc2822()
),
)
.body(Body::Empty)
.unwrap()) .unwrap())
} }
Err(_) => { Err(_) => Ok(Response::builder()
Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()) .status(StatusCode::BAD_REQUEST)
} .body(Body::Empty)
.unwrap()),
} }
} }
Ok(None) => { Ok(None) => Ok(Response::builder()
Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()) .status(StatusCode::BAD_REQUEST)
} .body(Body::Empty)
Err(_) => { Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(ResponseBody::Empty).unwrap()) } .unwrap()),
Err(_) => Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::Empty)
.unwrap()),
} }
} }
async fn register(req: Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> Result<Response<ResponseBody>, Error> { async fn register(
req: Request<Incoming>,
db: Arc<Mutex<SqlitePool>>,
) -> Result<Response<Body>, Error> {
match req_json::<Login>(req).await { match req_json::<Login>(req).await {
None => Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()), None => Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::Empty)
.unwrap()),
Some(login) => { Some(login) => {
if !check_username(&login.username) { if !check_username(&login.username) {
return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::Empty)
.unwrap());
} }
if !check_password(&login.password) { if !check_password(&login.password) {
return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::Empty)
.unwrap());
} }
let pool = db.clone().lock().unwrap().clone(); let pool = db.clone().lock().unwrap().clone();
let mut conn = pool.acquire().await.unwrap(); let mut conn = pool.acquire().await.unwrap();
let exists = sqlx::query!(r#"SELECT id FROM users WHERE username=?1"#, login.username).fetch_optional(&mut *conn).await; let exists = sqlx::query!(r#"SELECT id FROM users WHERE username=?1"#, login.username)
.fetch_optional(&mut *conn)
.await;
if exists.unwrap().is_some() { if exists.unwrap().is_some() {
return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(ResponseBody::Empty).unwrap()); return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::Empty)
.unwrap());
} }
let argon2 = Argon2::default(); let argon2 = Argon2::default();
let hash = argon2.hash_password(login.password.as_bytes(), &SaltString::generate(&mut OsRng)).unwrap().to_string(); let hash = argon2
.hash_password(login.password.as_bytes(), &SaltString::generate(&mut OsRng))
.unwrap()
.to_string();
let token = Alphanumeric.sample_string(&mut OsRng, 256); let token = Alphanumeric.sample_string(&mut OsRng, 256);
let result = sqlx::query!(r#"INSERT INTO users ( username, saltyhash, permissions, token) VALUES ( ?1, ?2, ?3, ?4 )"#, login.username, hash, 0, token).execute(&mut *conn).await; let result = sqlx::query!(r#"INSERT INTO users ( username, saltyhash, permissions, token) VALUES ( ?1, ?2, ?3, ?4 )"#, login.username, hash, 0, token).execute(&mut *conn).await;
match result { match result {
Ok(_) => Ok(Response::builder().body(ResponseBody::Full(Full::new(Bytes::new()))).unwrap()), Ok(_) => Ok(Response::builder().body(Body::Empty).unwrap()),
Err(_) => { Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(ResponseBody::Empty).unwrap()) } Err(_) => Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::Empty)
.unwrap()),
} }
}, }
} }
} }
async fn logout() -> Result<Response<ResponseBody>, Error> { async fn logout() -> Result<Response<Body>, Error> {
let date: DateTime<Utc> = DateTime::from(SystemTime::now()); let date: DateTime<Utc> = DateTime::from(SystemTime::now());
Ok(Response::builder() Ok(Response::builder()
//.status(StatusCode::SEE_OTHER) .header(
//.header(LOCATION, "/") SET_COOKIE,
.header(SET_COOKIE, format!("token=''; Expires={}; Secure; HttpOnly; SameSite=Strict", date.to_rfc2822())) format!(
.header(SET_COOKIE, format!("logged=false; Expires={}; Secure; HttpOnly; SameSite=Strict", date.to_rfc2822())) "token=''; Expires={}; Secure; HttpOnly; SameSite=Strict",
.body(ResponseBody::Empty).unwrap()) date.to_rfc2822()
),
)
.header(
SET_COOKIE,
format!(
"logged=false; Expires={}; Secure; HttpOnly; SameSite=Strict",
date.to_rfc2822()
),
)
.body(Body::Empty)
.unwrap())
} }
async fn is_authorised(req: &Request<Incoming>, db: Arc<Mutex<SqlitePool>>) -> i64 { async fn is_authorised(req: &Request<Incoming>, db: Arc<Mutex<SqlitePool>>, level: u8) -> bool {
let cookies = req.headers().get(COOKIE); let cookies = req.headers().get(COOKIE);
let token = match cookies { let token = match cookies {
Some(cookies) => { Some(cookies) => cookies
cookies .to_str()
.to_str() .unwrap_or("")
.unwrap_or("") .split("; ")
.split("; ") .find(|x| x.starts_with("token="))
.find(|x| x.starts_with("token=")) .unwrap_or("")
.unwrap_or("") .strip_prefix("token=")
.strip_prefix("token=") .unwrap_or(""),
.unwrap_or("")}, None => "",
None => ""
}; };
let pool = db.clone().lock().unwrap().clone(); let pool = db.clone().lock().unwrap().clone();
let user = sqlx::query!(r#"SELECT permissions FROM users WHERE token=?1"#, token).fetch_optional(&pool).await; let user = sqlx::query!(r#"SELECT permissions FROM users WHERE token=?1"#, token)
.fetch_optional(&pool)
.await;
match user { match user {
Ok(Some(user)) => user.permissions, Ok(Some(user)) => {
_ => 0 let perm = user.permissions as u8;
perm >= level
},
_ => match level {
0 => true,
_ => false
},
} }
} }
@@ -451,18 +587,21 @@ fn check_password(password: &String) -> bool {
up && num && sym up && num && sym
} }
async fn not_found() -> Result<Response<ResponseBody>, Error> { async fn not_found() -> Result<Response<Body>, Error> {
let mut file_path = env::current_dir().expect("Could not get app directory."); let mut file_path = env::current_dir().expect("Could not get app directory.");
file_path.push("static/html/404.html"); file_path.push("static/html/404.html");
let mut file = File::open(file_path).unwrap(); let mut file = File::open(file_path).unwrap();
let mut buf = Vec::new(); let mut buf = Vec::new();
file.read_to_end(&mut buf).unwrap(); file.read_to_end(&mut buf).unwrap();
Ok(Response::builder().status(StatusCode::NOT_FOUND).body(ResponseBody::Full(Full::new(Bytes::from(buf)))).unwrap()) Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::new(buf))
.unwrap())
} }
async fn req_json<T>(req: Request<Incoming>) -> Option<T> async fn req_json<T>(req: Request<Incoming>) -> Option<T>
where where
T: DeserializeOwned T: DeserializeOwned,
{ {
let body = req.into_body().collect().await.unwrap(); let body = req.into_body().collect().await.unwrap();
match from_reader(body.aggregate().reader()) { match from_reader(body.aggregate().reader()) {
@@ -472,10 +611,13 @@ where
} }
fn get_settings() -> Settings { fn get_settings() -> Settings {
let mut settings_path = env::current_dir().expect("Could not get app directory. (Required to read settings)"); let mut settings_path =
env::current_dir().expect("Could not get app directory. (Required to read settings)");
settings_path.push("settings.json"); settings_path.push("settings.json");
let settings_file = File::open(settings_path).expect("Could not open settings file, does it exists?"); let settings_file =
let settings: Settings = from_reader(settings_file).expect("Could not parse settings, please check syntax."); File::open(settings_path).expect("Could not open settings file, does it exists?");
let settings: Settings =
from_reader(settings_file).expect("Could not parse settings, please check syntax.");
settings settings
} }
@@ -504,26 +646,32 @@ fn main() {
#[tokio::main] #[tokio::main]
async fn run() { async fn run() {
let settings = get_settings(); let settings = get_settings();
let db_pool = Arc::new(Mutex::new(SqlitePool::connect(&settings.database_url).await.expect("Could not connect to database. Make sure the url is correct."))); let db_pool = Arc::new(Mutex::new(
let bind_address: SocketAddr = SocketAddr::from_str(&settings.bind_address).expect("Could not parse bind address."); SqlitePool::connect(&settings.database_url)
let listener = TcpListener::bind(bind_address).await.expect("Could not bind to address."); .await
.expect("Could not connect to database. Make sure the url is correct."),
));
let bind_address: SocketAddr =
SocketAddr::from_str(&settings.bind_address).expect("Could not parse bind address.");
let listener = TcpListener::bind(bind_address)
.await
.expect("Could not bind to address.");
loop { loop {
let (stream, _) = listener.accept().await.expect("Could not accept incoming stream."); let (stream, _) = listener
.accept()
.await
.expect("Could not accept incoming stream.");
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let db = db_pool.clone(); let db = db_pool.clone();
let service = service_fn( let service = service_fn(move |req| service(req, db.clone()));
move |req| { tokio::task::spawn(async move {
service(req, db.clone()) if let Err(err) = http1::Builder::new()
.timer(TokioTimer::new())
.serve_connection(io, service)
.await
{
println!("Failed to serve connection: {:?}", err);
} }
); });
tokio::task::spawn(
async move {
if let Err(err) = http1::Builder::new()
.timer(TokioTimer::new())
.serve_connection(io, service).await {
println!("Failed to serve connection: {:?}", err);
}
}
);
} }
} }