diff --git a/src/managers/user.rs b/src/managers/user.rs index d91caa3..37354e2 100644 --- a/src/managers/user.rs +++ b/src/managers/user.rs @@ -1,4 +1,5 @@ use std::hash::{DefaultHasher, Hash, Hasher}; +use std::net::IpAddr; use std::sync::Arc; use anyhow::anyhow; use rocket::futures::TryStreamExt; @@ -10,7 +11,7 @@ use sqlx::{query, query_as, Pool, QueryBuilder}; use uuid::Uuid; use crate::config::AppConfig; use crate::consts::ENCRYPTION_ROUNDS; -use crate::{SessionData, DB}; +use crate::{LoginSessionData, SessionData, DB}; use crate::models::user::{UserAuthError, UserModel}; pub struct UserManager { @@ -80,17 +81,17 @@ impl UserManager { .map_err(|e| anyhow!(e)) } /// Returns user's id - pub async fn create_normal_user(&self, user: CreateUserOptions, plain_password: String) -> Result { + pub async fn create_normal_user(&self, user: CreateUserOptions, plain_password: String) -> Result { let password = bcrypt::hash(plain_password, ENCRYPTION_ROUNDS) .map_err(|e| anyhow!(e))?; let id = Self::generate_id(None); self.create_user(id, user, Some(password)).await } /// Returns user's id - pub async fn create_sso_user(&self, user: CreateUserOptions, id: String) -> Result { + pub async fn create_sso_user(&self, user: CreateUserOptions, id: String) -> Result { self.create_user(id, user, None).await } - async fn create_user(&self, id: String, user: CreateUserOptions, encrypted_password: Option) -> Result { + async fn create_user(&self, id: String, user: CreateUserOptions, encrypted_password: Option) -> Result { query!( "INSERT INTO storage.users (id, name, password, email, username) VALUES ($1, $2, $3, $4, $5)", id, @@ -101,6 +102,22 @@ impl UserManager { ) .execute(&self.pool) .await?; - Ok(id) + Ok(UserModel { + id, + username: user.username, + email: user.email, + created_at: Default::default(), + name: user.name, + }) + } + + pub async fn login_user(&self, user: UserModel, ip_address: IpAddr, sessions: Session<'_, SessionData>) { + sessions.set(SessionData { + csrf_token: None, + login: Some(LoginSessionData { + user, + ip_address + }), + }).await.unwrap(); } } \ No newline at end of file diff --git a/src/routes/ui/auth/sso.rs b/src/routes/ui/auth/sso.rs index 0f88a18..b461f27 100644 --- a/src/routes/ui/auth/sso.rs +++ b/src/routes/ui/auth/sso.rs @@ -101,7 +101,7 @@ async fn callback_handler(sso: &State, ip: IpAddr, code: String, state } #[get("/auth/sso/cb?&")] -pub async fn callback(config: &State, users: &State, ip: IpAddr, sso: &State, code: String, state: String) -> Result { +pub async fn callback(sessions: Session<'_, SessionData>, config: &State, users: &State, ip: IpAddr, sso: &State, code: String, state: String) -> Result { let (userinfo, provider_id, return_to) = callback_handler(sso, ip, code, state).await .map_err(|e| (Status::InternalServerError, Template::render("errors/500", context! { error: e.to_string() @@ -120,7 +120,7 @@ pub async fn callback(config: &State, users: &State, ip: error: "Provider did not provide an username" })))?.to_string(); let search_options = vec![FindUserOption::Id(uid.clone()), FindUserOption::Email(email.clone()), FindUserOption::Username(username.clone())]; - let user = users.fetch_user(&search_options).await.map_err(|e|(Status::InternalServerError, Template::render("errors/500", context! { + let mut user = users.fetch_user(&search_options).await.map_err(|e|(Status::InternalServerError, Template::render("errors/500", context! { error: format!("Failed to find user: {}", e) })))?; debug!("existing user = {:?}", user); @@ -130,13 +130,18 @@ pub async fn callback(config: &State, users: &State, ip: error: "No account found linked to oidc provider and account creation has been disabled" }))); } - let id = users.create_sso_user(CreateUserOptions { - email, - username, - name: userinfo.name().unwrap().get(None).map(|s| s.to_string()), - }, uid).await.expect("later i fix"); - debug!("new user = {}", id); + user = { + let u = users.create_sso_user(CreateUserOptions { + email, + username, + name: userinfo.name().unwrap().get(None).map(|s| s.to_string()), + }, uid).await.expect("later i fix"); + debug!("new user = {}", u.id); + Some(u) + } } + let user = user.unwrap(); + users.login_user(user, ip, sessions).await; debug!("user={:?}\nemail={:?}\nname={:?}", userinfo.subject(), userinfo.email(), userinfo.name()); // TODO: login user to session, prob through UserManager/users let return_to = return_to.unwrap_or("/".to_string());