add handler code and basic csrf protection

This commit is contained in:
Jackzie 2025-04-15 21:36:45 -05:00
parent 1c561e58f7
commit dc21e50a8a
9 changed files with 117 additions and 29 deletions

3
server/Cargo.lock generated
View file

@ -2323,6 +2323,7 @@ dependencies = [
"humanize-bytes", "humanize-bytes",
"int-enum", "int-enum",
"log", "log",
"rand 0.9.0",
"rocket", "rocket",
"rocket-session-store", "rocket-session-store",
"rocket_dyn_templates", "rocket_dyn_templates",
@ -2331,6 +2332,7 @@ dependencies = [
"sqlx", "sqlx",
"tokio", "tokio",
"tracing-subscriber", "tracing-subscriber",
"uuid",
] ]
[[package]] [[package]]
@ -2745,6 +2747,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
dependencies = [ dependencies = [
"getrandom 0.3.2",
"serde", "serde",
] ]

View file

@ -17,4 +17,6 @@ int-enum = "1.2.0"
dotenvy = "0.15.7" dotenvy = "0.15.7"
rocket_dyn_templates = { version = "0.2.0", features = ["handlebars"] } rocket_dyn_templates = { version = "0.2.0", features = ["handlebars"] }
humanize-bytes = "1.0.6" humanize-bytes = "1.0.6"
rocket-session-store = "0.2.1" rocket-session-store = "0.2.1"
uuid = { version = "1.16.0", features = ["v4"] }
rand = { version = "0.9.0", features = ["thread_rng"] }

View file

@ -1,3 +1,4 @@
use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace, warn};
@ -38,11 +39,22 @@ pub type DB = Pool<Postgres>;
const MAX_UPLOAD_SIZE: ByteUnit = ByteUnit::Mebibyte(100_000); const MAX_UPLOAD_SIZE: ByteUnit = ByteUnit::Mebibyte(100_000);
#[derive(Clone, Debug, Serialize)] #[derive(Clone, Debug, Serialize, Default)]
struct SessionData { struct SessionData {
user: UserModel, csrf_token: Option<String>,
login: Option<LoginSessionData>,
}
#[derive(Clone, Debug, Serialize)]
struct LoginSessionData {
user: UserModel,
ip_address: IpAddr,
}
#[derive(Clone, Debug, Serialize)]
struct SessionUser {
id: String,
name: String,
email: String
} }
#[launch] #[launch]
async fn rocket() -> _ { async fn rocket() -> _ {
setup_logger(); setup_logger();
@ -113,7 +125,7 @@ async fn rocket() -> _ {
.mount("/", routes![ .mount("/", routes![
ui::help::about, ui::help::about,
ui::user::index, ui::user::redirect_list_library_files, ui::user::list_library_files, ui::user::get_library_file, ui::user::index, ui::user::redirect_list_library_files, ui::user::list_library_files, ui::user::get_library_file,
ui::help::test_get, ui::help::test_set ui::help::test_get
]) ])
} }

View file

@ -8,6 +8,8 @@ use crate::models::repo::RepoModel;
#[derive(Serialize, Clone, Debug)] #[derive(Serialize, Clone, Debug)]
pub struct UserModel { pub struct UserModel {
pub id: Uuid, pub id: Uuid,
// email
// password
pub created_at: NaiveDateTime, pub created_at: NaiveDateTime,
pub name: String pub name: String
} }

View file

@ -1,26 +1,63 @@
use rocket::{get, post, Route}; use std::net::IpAddr;
use rocket::{get, post, FromForm, Route};
use rocket::form::Form;
use rocket_dyn_templates::{context, Template}; use rocket_dyn_templates::{context, Template};
use rocket_session_store::Session;
use crate::models::user::UserModel;
use crate::{LoginSessionData, SessionData};
use crate::util::{gen_csrf_token, set_csrf, validate_csrf};
#[get("/login")] #[get("/login")]
pub async fn login(route: &Route) -> Template { pub async fn login(route: &Route, session: Session<'_, SessionData>) -> Template {
Template::render("auth/login", context! { route: route.uri.path() }) let csrf_token = set_csrf(&session).await;
Template::render("auth/login", context! {
route: route.uri.path(),
csrf_token: csrf_token
})
} }
#[post("/login")]
pub async fn login_handler(route: &Route) -> Template {
Template::render("auth/login", context! { route: route.uri.path() })
#[derive(FromForm)]
struct LoginForm<'r> {
_csrf: &'r str,
username: &'r str,
password: &'r str,
#[field(default = false)]
remember_me: bool
}
#[post("/login", data = "<form>")]
pub async fn login_handler(route: &Route, ip_addr: IpAddr, form: Form<LoginForm<'_>>, session: Session<'_, SessionData>) -> String {
if let Ok(true) = validate_csrf(&session, &form._csrf).await {
if let Ok(sess) = session.get().await.map(|s| s.unwrap_or_default()) {
session.set(SessionData {
csrf_token: None,
login: Some(LoginSessionData {
user: UserModel {
id: Default::default(),
created_at: Default::default(),
name: form.username.to_string(),
},
ip_address: ip_addr,
}),
}).await.unwrap();
return format!("login success")
}
}
format!("login bad. csrf failed!")
} }
#[get("/register")] #[get("/register")]
pub async fn register(route: &Route) -> Template { pub async fn register(route: &Route, session: Session<'_, SessionData>) -> Template {
Template::render("auth/register", context! { route: route.uri.path() }) let csrf_token = set_csrf(&session).await;
Template::render("auth/register", context! {
route: route.uri.path(),
csrf_token: csrf_token
})
} }
#[post("/register")] #[post("/register")]
pub async fn register_handler(route: &Route) -> Template { pub async fn register_handler(route: &Route, session: Session<'_, SessionData>) -> Template {
Template::render("auth/register", context! { route: route.uri.path() }) Template::render("auth/register", context! { route: route.uri.path() })
} }

View file

@ -11,18 +11,6 @@ pub fn about(route: &Route) -> Template {
Template::render("about", context! { route: route.uri.path() }) Template::render("about", context! { route: route.uri.path() })
} }
// TODO: temp remove when not needed
#[get("/test/set")]
pub async fn test_set(session: Session<'_, SessionData>) -> &str {
session.set(SessionData {
user: UserModel {
id: Default::default(),
created_at: Default::default(),
name: "Jackie".to_string(),
},
}).await;
"set."
}
#[get("/test/get")] #[get("/test/get")]
pub async fn test_get(session: Session<'_, SessionData>) -> Result<Json<SessionData>, String> { pub async fn test_get(session: Session<'_, SessionData>) -> Result<Json<SessionData>, String> {

View file

@ -1,14 +1,21 @@
use std::fs; use std::fs;
use std::io::Cursor; use std::io::Cursor;
use log::trace;
use rand::rngs::OsRng;
use rand::{rng, Rng, TryRngCore};
use rand::distr::Alphanumeric;
use rocket::http::{ContentType, Status}; use rocket::http::{ContentType, Status};
use rocket::{response, Request, Response}; use rocket::{response, Request, Response};
use rocket::fs::relative; use rocket::fs::relative;
use rocket::response::Responder; use rocket::response::Responder;
use rocket::serde::Serialize; use rocket::serde::Serialize;
use rocket_dyn_templates::handlebars::Handlebars; use rocket_dyn_templates::handlebars::Handlebars;
use rocket_session_store::{Session, SessionError, SessionResult};
use sqlx::Error; use sqlx::Error;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use uuid::Uuid;
use crate::SessionData;
use crate::util::ResponseError::DatabaseError; use crate::util::ResponseError::DatabaseError;
pub(crate) fn setup_logger() { pub(crate) fn setup_logger() {
@ -21,6 +28,41 @@ pub(crate) fn setup_logger() {
.init(); .init();
} }
pub async fn set_csrf(session: &Session<'_, SessionData>) -> String {
let token = gen_csrf_token();
trace!("set_csrf token={}", token);
let mut sess = session.get().await.expect("failed to get session data")
.unwrap_or_else(|| SessionData {
csrf_token: None,
login: None,
});
sess.csrf_token = Some(token.clone());
session.set(sess).await.unwrap();
token
}
pub async fn validate_csrf(session: &Session<'_, SessionData>, form_csrf_token: &str) -> Result<bool, SessionError> {
if let Some(mut sess) = session.get().await? {
if let Some(sess_token) = sess.csrf_token {
let success = sess_token == form_csrf_token;
if success {
sess.csrf_token = None;
session.set(sess).await?;
return Ok(true)
}
}
}
Ok(false)
}
pub fn gen_csrf_token() -> String {
rng()
.sample_iter(&Alphanumeric)
.map(char::from) // map added here
.take(30)
.collect()
}
// pub(crate) fn setup_template_engine() -> Handlebars<'static> { // pub(crate) fn setup_template_engine() -> Handlebars<'static> {
// let mut hb = Handlebars::new(); // let mut hb = Handlebars::new();
// #[cfg(debug_assertions)] // #[cfg(debug_assertions)]

View file

@ -5,6 +5,7 @@
<div class="box is-radiusless"> <div class="box is-radiusless">
<h4 class="title is-4 has-text-centered">Login</h4> <h4 class="title is-4 has-text-centered">Login</h4>
<form method="post" action="/auth/login"> <form method="post" action="/auth/login">
<input type="hidden" name="_csrf" value="{{ csrf_token }}">
<div class="field"> <div class="field">
<label class="label">Username / Email</label> <label class="label">Username / Email</label>
<div class="control has-icons-left"> <div class="control has-icons-left">
@ -28,7 +29,7 @@
<div class="field"> <div class="field">
<div class="control"> <div class="control">
<label class="checkbox"> <label class="checkbox">
<input name="remember" type="checkbox"> <input name="remember_me" type="checkbox">
Remember Me</a> Remember Me</a>
</label> </label>
</div> </div>

View file

@ -6,6 +6,7 @@
<h4 class="title is-4 has-text-centered">Register</h4> <h4 class="title is-4 has-text-centered">Register</h4>
{{#if can_register }} {{#if can_register }}
<form method="post" action="/auth/register"> <form method="post" action="/auth/register">
<input type="hidden" name="_csrf" value="{{ csrf_token }}">
<div class="field"> <div class="field">
<label class="label">Username</label> <label class="label">Username</label>
<div class="control has-icons-left"> <div class="control has-icons-left">