diff --git a/server/Cargo.lock b/server/Cargo.lock index 80af018..2561a34 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -2323,6 +2323,7 @@ dependencies = [ "humanize-bytes", "int-enum", "log", + "rand 0.9.0", "rocket", "rocket-session-store", "rocket_dyn_templates", @@ -2331,6 +2332,7 @@ dependencies = [ "sqlx", "tokio", "tracing-subscriber", + "uuid", ] [[package]] @@ -2745,6 +2747,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" dependencies = [ + "getrandom 0.3.2", "serde", ] diff --git a/server/Cargo.toml b/server/Cargo.toml index 6ec9b8b..35bf3a7 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -17,4 +17,6 @@ int-enum = "1.2.0" dotenvy = "0.15.7" rocket_dyn_templates = { version = "0.2.0", features = ["handlebars"] } humanize-bytes = "1.0.6" -rocket-session-store = "0.2.1" \ No newline at end of file +rocket-session-store = "0.2.1" +uuid = { version = "1.16.0", features = ["v4"] } +rand = { version = "0.9.0", features = ["thread_rng"] } \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs index f732a99..963a20c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,3 +1,4 @@ +use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; use log::{debug, error, info, trace, warn}; @@ -38,11 +39,22 @@ pub type DB = Pool; const MAX_UPLOAD_SIZE: ByteUnit = ByteUnit::Mebibyte(100_000); -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Debug, Serialize, Default)] struct SessionData { - user: UserModel, + csrf_token: Option, + login: Option, +} +#[derive(Clone, Debug, Serialize)] +struct LoginSessionData { + user: UserModel, + ip_address: IpAddr, +} +#[derive(Clone, Debug, Serialize)] +struct SessionUser { + id: String, + name: String, + email: String } - #[launch] async fn rocket() -> _ { setup_logger(); @@ -113,7 +125,7 @@ async fn rocket() -> _ { .mount("/", routes![ ui::help::about, ui::user::index, ui::user::redirect_list_library_files, ui::user::list_library_files, ui::user::get_library_file, - ui::help::test_get, ui::help::test_set + ui::help::test_get ]) } diff --git a/server/src/models/user.rs b/server/src/models/user.rs index 0ba0969..2d6900c 100644 --- a/server/src/models/user.rs +++ b/server/src/models/user.rs @@ -8,6 +8,8 @@ use crate::models::repo::RepoModel; #[derive(Serialize, Clone, Debug)] pub struct UserModel { pub id: Uuid, + // email + // password pub created_at: NaiveDateTime, pub name: String } diff --git a/server/src/routes/ui/auth.rs b/server/src/routes/ui/auth.rs index 698e0d3..2c38c33 100644 --- a/server/src/routes/ui/auth.rs +++ b/server/src/routes/ui/auth.rs @@ -1,26 +1,63 @@ -use rocket::{get, post, Route}; +use std::net::IpAddr; +use rocket::{get, post, FromForm, Route}; +use rocket::form::Form; use rocket_dyn_templates::{context, Template}; +use rocket_session_store::Session; +use crate::models::user::UserModel; +use crate::{LoginSessionData, SessionData}; +use crate::util::{gen_csrf_token, set_csrf, validate_csrf}; #[get("/login")] -pub async fn login(route: &Route) -> Template { - Template::render("auth/login", context! { route: route.uri.path() }) - +pub async fn login(route: &Route, session: Session<'_, SessionData>) -> Template { + let csrf_token = set_csrf(&session).await; + Template::render("auth/login", context! { + route: route.uri.path(), + csrf_token: csrf_token + }) } -#[post("/login")] -pub async fn login_handler(route: &Route) -> Template { - Template::render("auth/login", context! { route: route.uri.path() }) + +#[derive(FromForm)] +struct LoginForm<'r> { + _csrf: &'r str, + username: &'r str, + password: &'r str, + #[field(default = false)] + remember_me: bool +} +#[post("/login", data = "
")] +pub async fn login_handler(route: &Route, ip_addr: IpAddr, form: Form>, session: Session<'_, SessionData>) -> String { + if let Ok(true) = validate_csrf(&session, &form._csrf).await { + if let Ok(sess) = session.get().await.map(|s| s.unwrap_or_default()) { + session.set(SessionData { + csrf_token: None, + login: Some(LoginSessionData { + user: UserModel { + id: Default::default(), + created_at: Default::default(), + name: form.username.to_string(), + }, + ip_address: ip_addr, + }), + }).await.unwrap(); + return format!("login success") + } + } + format!("login bad. csrf failed!") } #[get("/register")] -pub async fn register(route: &Route) -> Template { - Template::render("auth/register", context! { route: route.uri.path() }) - +pub async fn register(route: &Route, session: Session<'_, SessionData>) -> Template { + let csrf_token = set_csrf(&session).await; + Template::render("auth/register", context! { + route: route.uri.path(), + csrf_token: csrf_token + }) } #[post("/register")] -pub async fn register_handler(route: &Route) -> Template { +pub async fn register_handler(route: &Route, session: Session<'_, SessionData>) -> Template { Template::render("auth/register", context! { route: route.uri.path() }) } \ No newline at end of file diff --git a/server/src/routes/ui/help.rs b/server/src/routes/ui/help.rs index 6b7fb63..dd283d5 100644 --- a/server/src/routes/ui/help.rs +++ b/server/src/routes/ui/help.rs @@ -11,18 +11,6 @@ pub fn about(route: &Route) -> Template { Template::render("about", context! { route: route.uri.path() }) } -// TODO: temp remove when not needed -#[get("/test/set")] -pub async fn test_set(session: Session<'_, SessionData>) -> &str { - session.set(SessionData { - user: UserModel { - id: Default::default(), - created_at: Default::default(), - name: "Jackie".to_string(), - }, - }).await; - "set." -} #[get("/test/get")] pub async fn test_get(session: Session<'_, SessionData>) -> Result, String> { diff --git a/server/src/util.rs b/server/src/util.rs index 6f612e5..2558b6c 100644 --- a/server/src/util.rs +++ b/server/src/util.rs @@ -1,14 +1,21 @@ use std::fs; use std::io::Cursor; +use log::trace; +use rand::rngs::OsRng; +use rand::{rng, Rng, TryRngCore}; +use rand::distr::Alphanumeric; use rocket::http::{ContentType, Status}; use rocket::{response, Request, Response}; use rocket::fs::relative; use rocket::response::Responder; use rocket::serde::Serialize; use rocket_dyn_templates::handlebars::Handlebars; +use rocket_session_store::{Session, SessionError, SessionResult}; use sqlx::Error; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; +use uuid::Uuid; +use crate::SessionData; use crate::util::ResponseError::DatabaseError; pub(crate) fn setup_logger() { @@ -21,6 +28,41 @@ pub(crate) fn setup_logger() { .init(); } +pub async fn set_csrf(session: &Session<'_, SessionData>) -> String { + let token = gen_csrf_token(); + trace!("set_csrf token={}", token); + let mut sess = session.get().await.expect("failed to get session data") + .unwrap_or_else(|| SessionData { + csrf_token: None, + login: None, + }); + sess.csrf_token = Some(token.clone()); + session.set(sess).await.unwrap(); + token +} + +pub async fn validate_csrf(session: &Session<'_, SessionData>, form_csrf_token: &str) -> Result { + if let Some(mut sess) = session.get().await? { + if let Some(sess_token) = sess.csrf_token { + let success = sess_token == form_csrf_token; + if success { + sess.csrf_token = None; + session.set(sess).await?; + return Ok(true) + } + } + } + Ok(false) +} + +pub fn gen_csrf_token() -> String { + rng() + .sample_iter(&Alphanumeric) + .map(char::from) // map added here + .take(30) + .collect() +} + // pub(crate) fn setup_template_engine() -> Handlebars<'static> { // let mut hb = Handlebars::new(); // #[cfg(debug_assertions)] diff --git a/server/templates/auth/login.html.hbs b/server/templates/auth/login.html.hbs index a1fdbff..a00623a 100644 --- a/server/templates/auth/login.html.hbs +++ b/server/templates/auth/login.html.hbs @@ -5,6 +5,7 @@

Login

+
@@ -28,7 +29,7 @@
diff --git a/server/templates/auth/register.html.hbs b/server/templates/auth/register.html.hbs index e2c5292..5b276c1 100644 --- a/server/templates/auth/register.html.hbs +++ b/server/templates/auth/register.html.hbs @@ -6,6 +6,7 @@

Register

{{#if can_register }} +