implement config + Pull SSO data from config

This commit is contained in:
Jackzie 2025-04-20 20:56:33 -05:00
parent cab39de312
commit e13f080d91
17 changed files with 440 additions and 209 deletions

View file

@ -1 +1,2 @@
DATABASE_URL=postgresql://server:5432/database?user=user&password=password&connectTimeout=30&currentSchema=storage; DATABASE_URL=postgresql://server:5432/database?user=user&password=password&connectTimeout=30&currentSchema=storage;
STORAGE_AUTH_OIDC_ISSUER_URL = ""

1
Cargo.lock generated
View file

@ -3284,6 +3284,7 @@ dependencies = [
"bcrypt", "bcrypt",
"chrono", "chrono",
"dotenvy", "dotenvy",
"figment",
"humanize-bytes", "humanize-bytes",
"int-enum", "int-enum",
"log", "log",

View file

@ -26,3 +26,4 @@ bcrypt = "0.17.0"
openidconnect = "4.0.0" openidconnect = "4.0.0"
reqwest = "0.12.15" reqwest = "0.12.15"
moka = { version = "0.12.10", features = ["future"] } moka = { version = "0.12.10", features = ["future"] }
figment = "0.10.19"

View file

@ -1,25 +1,41 @@
[general] [general]
listen_ip = "0.0.0.0" listen_ip = "0.0.0.0"
listen_port = 80 listen_port = 8080
# if under reverse proxy
# The public facing url, this is where users will access the app
# Used for OIDC callbacks
# - if under reverse proxy (nginx, traefik, caddy, etc):
#public_url = "https://storage.example.com" #public_url = "https://storage.example.com"
#public_port = 443
public_url = "http://localhost:8080" public_url = "http://localhost:8080"
public_port = 80
[backends.local] [backends.local]
path = "/var/tmp/test" path = "/var/tmp/test"
[auth] [auth]
enable_registration = true # Is account registration disabled? Users will not be able to create
oidc_enabled = true # a new account with email/username + pass
# Where the .well-known/openid-configuration exists disable_registration = false
oidc_issuer_url = "https://accounts.example.com" [auth.oidc]
oidc_client_id = "" enabled = true
oidc_client_secret = "" # The url the .well-known/openid-configuration exists, this can be a subpath
oidc_claims = [] # Example, for authentik: https://sso.example.com/application/o/YOURAPPSLUG
issuer_url = ""
client_id = ""
client_secret = ""
claims = ["email", "profile"]
# Should an account be created if SSO user id doesn't exist already # Should an account be created if SSO user id doesn't exist already
oidc_create_account = true create_account = true
# Should normal login (username/email+pass) be disabled, forcing users to use sso?
disable_normal_login = false
[smtp] [smtp]
# TODO: enabled = false
hostname = "smtp.example.com"
port = 587
username = ""
password = ""
# Name to be used for emails, defaults to public_url's domain
#from_name = ""
# The email address to send as, defaults to username
#from_address = ""
tls = "none" # "none", "starttls" or "tls"

View file

@ -1,27 +1,92 @@
use rocket::serde::{Serialize,Deserialize}; use std::collections::HashMap;
use std::env::var;
use figment::Figment;
use figment::providers::{Env, Format, Toml};
use log::error;
use openidconnect::core::{CoreClient, CoreProviderMetadata};
use openidconnect::IssuerUrl;
use openidconnect::url::Url;
use rocket::serde::{Serialize, Deserialize};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Config { #[serde(rename_all = "kebab-case")]
general: GeneralConfig, pub struct AppConfig {
auth: AuthConfig, pub general: GeneralConfig,
smtp: EmailConfig pub auth: AuthConfig,
pub smtp: Option<EmailConfig>
} }
pub fn get_settings() -> AppConfig {
let f = Figment::new()
.merge(Toml::file("config.toml"))
.merge(Env::prefixed("STORAGE_")
.map(|f| f.as_str().replace("__", "-").into()))
.extract();
match f {
Ok(settings) => settings,
Err(e) => {
error!("Failed to read configuration");
error!("{}", e);
std::process::exit(1);
}
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct GeneralConfig { pub struct GeneralConfig {
pub listen_ip: Option<String>, pub listen_ip: Option<String>,
pub listen_port: Option<u32> pub listen_port: Option<u16>,
pub public_url: String,
pub database_url: Option<String>,
}
impl GeneralConfig {
pub fn get_public_url(&self) -> Url {
self.public_url.parse().expect("failed to parse general.public-url")
}
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct AuthConfig { pub struct AuthConfig {
pub disable_registration: bool, pub disable_registration: bool,
pub openid_enabled: Option<bool>, pub oidc: Option<OidcConfig>,
pub openid_issuer_url: Option<String>, }
pub openid_client_id: Option<String>,
pub openid_client_secret: Option<String>
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct OidcConfig {
#[serde(default)]
pub enabled: bool,
pub issuer_url: String,
pub client_id: String,
pub client_secret: String,
#[serde(default)]
pub claims: Vec<String>,
#[serde(default)]
pub create_account: bool,
#[serde(default)]
pub disable_normal_login: bool
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SmtpEncryption {
None,
StartTls,
Tls
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct EmailConfig { pub struct EmailConfig {
#[serde(default)]
} pub enabled: bool,
pub hostname: String,
pub port: u16,
pub username: String,
pub password: String,
pub tls: Option<SmtpEncryption>,
pub from_name: Option<String>,
pub from_email: Option<String>,
}

View file

@ -4,6 +4,7 @@ use std::sync::{LazyLock, OnceLock};
use std::time::Duration; use std::time::Duration;
use rocket::data::ByteUnit; use rocket::data::ByteUnit;
use rocket::serde::Serialize; use rocket::serde::Serialize;
use crate::GlobalMetadata;
/// The maximum amount of bytes that can be uploaded at once /// The maximum amount of bytes that can be uploaded at once
pub const MAX_UPLOAD_SIZE: ByteUnit = ByteUnit::Mebibyte(100_000); pub const MAX_UPLOAD_SIZE: ByteUnit = ByteUnit::Mebibyte(100_000);
@ -32,5 +33,12 @@ pub const FILE_CONSTANTS: FileConstants = FileConstants {
pub static DISABLE_LOGIN_CHECK: LazyLock<bool> = LazyLock::new(|| { pub static DISABLE_LOGIN_CHECK: LazyLock<bool> = LazyLock::new(|| {
env::var("DANGER_DISABLE_LOGIN_CHECKS").is_ok() env::var("DANGER_DISABLE_LOGIN_CHECKS").is_ok()
}); });
pub static APP_METADATA: LazyLock<GlobalMetadata> = LazyLock::new(|| {
GlobalMetadata {
app_name: env!("CARGO_PKG_NAME").to_string(),
app_version: env!("CARGO_PKG_VERSION").to_string(),
repo_url: env!("CARGO_PKG_REPOSITORY").to_string(),
}
});
pub fn init_statics() { pub fn init_statics() {
} }

View file

@ -1,4 +1,4 @@
use std::net::IpAddr; use std::net::{IpAddr, SocketAddr};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace, warn};
@ -22,9 +22,11 @@ use tracing_subscriber::fmt::writer::MakeWriterExt;
use crate::managers::libraries::LibraryManager; use crate::managers::libraries::LibraryManager;
use crate::managers::repos::RepoManager; use crate::managers::repos::RepoManager;
use crate::objs::library::Library; use crate::objs::library::Library;
use crate::util::{setup_logger, JsonErrorResponse, ResponseError}; use crate::util::{setup_db, setup_logger, setup_session_store, JsonErrorResponse, ResponseError};
use routes::api; use routes::api;
use crate::config::{get_settings, AppConfig};
use crate::consts::{init_statics, SESSION_COOKIE_NAME, SESSION_LIFETIME_SECONDS}; use crate::consts::{init_statics, SESSION_COOKIE_NAME, SESSION_LIFETIME_SECONDS};
use crate::managers::sso::{SSOState, SSO};
use crate::models::user::UserModel; use crate::models::user::UserModel;
use crate::routes::ui; use crate::routes::ui;
@ -78,16 +80,25 @@ async fn rocket() -> _ {
warn!("warn"); warn!("warn");
error!("error"); error!("error");
// TODO: move to own fn let settings: AppConfig = get_settings();
let pool = PgPoolOptions::new() info!("Auth | Registration={} Login={} | OIDC={} CreateAccount={}",
.max_connections(5) if settings.auth.disable_registration { "N" } else { "Y" },
.connect(std::env::var("DATABASE_URL").unwrap().as_str()) settings.auth.oidc.as_ref().map(|oidc| if oidc.disable_normal_login { "N" } else { "Y" } ).unwrap_or("Y"),
.await settings.auth.oidc.as_ref().map(|oidc| if oidc.enabled { "Y" } else { "N" } ).unwrap_or("N"),
.unwrap(); settings.auth.oidc.as_ref().map(|oidc| if oidc.create_account { "Y" } else { "N" }).unwrap_or("-"),
);
migrate!("./migrations") let listen_ip: IpAddr = settings.general.listen_ip.as_ref()
.run(&pool) .map(|s| s.to_string())
.await.unwrap(); .unwrap_or_else(||"0.0.0.0".to_string())
.parse().expect("bad listen ip");
let listen_addr = SocketAddr::new(listen_ip, settings.general.listen_port.unwrap_or(8080));
info!("Listening on {} | Public URL: {}", listen_addr, settings.general.public_url);
if let Some(ref smtp) = settings.smtp {
if smtp.enabled {
info!("SMTP Enabled");
}
}
let pool = setup_db().await;
let repo_manager = { let repo_manager = {
let mut manager = RepoManager::new(pool.clone()); let mut manager = RepoManager::new(pool.clone());
@ -100,32 +111,21 @@ async fn rocket() -> _ {
}; };
// TODO: move to own func // TODO: move to own func
let memory_store: MemoryStore::<SessionData> = MemoryStore::default(); let store = setup_session_store();
let store: SessionStore<SessionData> = SessionStore { let sso: SSOState = {
store: Box::new(memory_store), if settings.auth.oidc.is_some() { Some(Arc::new(Mutex::new(SSO::create(&settings).await)) ) } else { None }
name: SESSION_COOKIE_NAME.into(),
duration: Duration::from_secs(SESSION_LIFETIME_SECONDS),
// The cookie builder is used to set the cookie's path and other options.
// Name and value don't matter, they'll be overridden on each request.
cookie_builder: CookieBuilder::new("", "")
// Most web apps will want to use "/", but if your app is served from
// `example.com/myapp/` for example you may want to use "/myapp/" (note the trailing
// slash which prevents the cookie from being sent for `example.com/myapp2/`).
.path("/")
}; };
// TODO: move to constants let figment = rocket::Config::figment()
let metadata = GlobalMetadata { .merge(("port", listen_addr.port()))
app_name: env!("CARGO_PKG_NAME").to_string(), .merge(("address", listen_addr.ip()));
app_version: env!("CARGO_PKG_VERSION").to_string(),
repo_url: env!("CARGO_PKG_REPOSITORY").to_string(),
};
rocket::build() rocket::custom(figment)
.manage(pool) .manage(pool)
.manage(repo_manager) .manage(repo_manager)
.manage(libraries_manager) .manage(libraries_manager)
.manage(metadata) .manage(settings)
.manage(sso)
.attach(store.fairing()) .attach(store.fairing())
.attach(Template::custom(|engines| { .attach(Template::custom(|engines| {

View file

@ -1,2 +1,3 @@
pub mod repos; pub mod repos;
pub mod libraries; pub mod libraries;
pub mod sso;

148
src/managers/sso.rs Normal file
View file

@ -0,0 +1,148 @@
use std::env::var;
use std::net::IpAddr;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use anyhow::anyhow;
use log::warn;
use moka::future::Cache;
use openidconnect::core::{CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenResponse};
use openidconnect::http::{HeaderMap, HeaderValue};
use openidconnect::{Client, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce, ProviderMetadata, RedirectUrl, Scope, StandardErrorResponse};
use openidconnect::url::ParseError;
use rocket::yansi::Paint;
use tokio::sync::Mutex;
use crate::config::{AppConfig, OidcConfig};
pub struct SSO {
http_client: reqwest::Client,
issuer_url: IssuerUrl,
client_id: ClientId,
client_secret: Option<ClientSecret>,
public_url: String,
scopes: Vec<String>,
cache: Cache<IpAddr, SSOSessionData>,
}
pub struct HttpProxySettings {
url: String,
disable_cert_check: bool
}
#[derive(Clone)]
pub struct SSOSessionData {
pub pkce_challenge: String,
pub nonce: Nonce,
pub csrf_token: CsrfToken,
pub return_to: Option<String>
// ip: IpAddr,
}
pub type SSOState = Option<Arc<Mutex<SSO>>>;
impl SSO {
pub async fn create(config: &AppConfig) -> Self {
let oidc_config = config.auth.oidc.as_ref().expect("OIDC config not provided");
let referer = config.general.get_public_url().domain().map(|s| s.to_string());
let http_client = SSO::setup_http_client(referer, None);
let issuer_url = IssuerUrl::new(oidc_config.issuer_url.to_string()).expect("bad issuer url");
let client_id = ClientId::new(oidc_config.client_id.to_string());
let client_secret = Some(ClientSecret::new(oidc_config.client_secret.to_string()));
let cache = Self::setup_cache();
Self {
http_client,
issuer_url,
client_id,
client_secret,
cache,
scopes: oidc_config.claims.to_owned(),
public_url: config.general.public_url.to_string(),
}
}
fn setup_cache() -> Cache<IpAddr, SSOSessionData> {
Cache::builder()
.time_to_live(Duration::from_secs(120))
.max_capacity(100)
.build()
}
fn setup_http_client(referer: Option<String>, proxy_settings: Option<HttpProxySettings>) -> reqwest::Client {
let mut headers = HeaderMap::new();
// TODO: pull from config.
// Set referrer as some providers (authentik) block POST w/o referrer
if let Some(ref referer) = referer {
headers.insert("Referer", referer.parse().expect("bad referer"));
}
let mut builder = reqwest::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.default_headers(headers);
if let Some(proxy) = proxy_settings {
warn!("DANGER_DEV_PROXY set, requests are being proxied & ignoring certificates");
builder = builder
.proxy(reqwest::Proxy::https(proxy.url).unwrap())
.danger_accept_invalid_certs(proxy.disable_cert_check);
};
builder.build().expect("Client should build")
}
pub fn http_client(&self) -> &reqwest::Client {
&self.http_client
}
pub async fn create_client(&self) -> Result<OidcClient, anyhow::Error> {
let provider_metadata = CoreProviderMetadata::discover_async(
/* TODO: pull from config */
self.issuer_url.clone(),
&self.http_client,
).await.map_err(|e| anyhow!(e.to_string()))?;
Ok(CoreClient::from_provider_metadata(
provider_metadata,
// TODO: pull from config
self.client_id.clone(),
self.client_secret.clone(),
))
}
pub async fn create_client_redirect(&self) -> Result<OidcClient, anyhow::Error> {
let redirect_url = RedirectUrl::new( format!("{}/auth/sso/cb", self.public_url))
.map_err(|e: ParseError | anyhow!(e))?;
let client = self.create_client().await?;
Ok(client.set_redirect_uri(redirect_url))
}
pub fn scopes(&self) -> Vec<Scope> {
self.scopes.iter().map(|c| Scope::new(c.to_string())).collect()
}
pub async fn cache_set(&mut self, ip: IpAddr, data: SSOSessionData) {
self.cache.insert(ip, data).await;
}
pub async fn cache_take(&mut self, ip: IpAddr) -> Option<SSOSessionData> {
self.cache.remove(&ip).await
}
}
// From https://github.com/IgnisDa/ryot/blob/75a1379f743b412df0e42fc88177d18cd34d48d7/crates/utils/application/src/lib.rs#L141C31-L148C2
pub type OidcClient<
HasAuthUrl = EndpointSet,
HasDeviceAuthUrl = EndpointNotSet,
HasIntrospectionUrl = EndpointNotSet,
HasRevocationUrl = EndpointNotSet,
HasTokenUrl = EndpointMaybeSet,
HasUserInfoUrl = EndpointMaybeSet,
> = Client<
EmptyAdditionalClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
CoreTokenResponse,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
HasAuthUrl,
HasDeviceAuthUrl,
HasIntrospectionUrl,
HasRevocationUrl,
HasTokenUrl,
HasUserInfoUrl,
>;

View file

@ -22,9 +22,17 @@ pub mod register;
pub mod sso; pub mod sso;
#[derive(Responder)]
#[response(status = 302)]
struct HackyRedirectBecauseRocketBug {
inner: String,
location: Header<'static>,
}
#[get("/logout")] #[get("/logout")]
pub async fn logout(session: Session<'_, SessionData>, user: AuthUser) -> Redirect { pub async fn logout(session: Session<'_, SessionData>, user: AuthUser) -> Redirect {
session.remove().await.unwrap(); session.remove().await.unwrap();
Redirect::to(uri!(login::page(_, Some(true)))) Redirect::to(uri!(login::page(_, Some(true))))
} }

View file

@ -3,13 +3,13 @@ use rocket::form::{Context, Contextual, Form};
use rocket_dyn_templates::{context, Template}; use rocket_dyn_templates::{context, Template};
use rocket_session_store::Session; use rocket_session_store::Session;
use crate::{GlobalMetadata, SessionData}; use crate::{GlobalMetadata, SessionData};
use crate::consts::APP_METADATA;
use crate::util::set_csrf; use crate::util::set_csrf;
#[get("/auth/forgot-password?<return_to>")] #[get("/auth/forgot-password?<return_to>")]
pub async fn page( pub async fn page(
route: &Route, route: &Route,
session: Session<'_, SessionData>, session: Session<'_, SessionData>,
meta: &State<GlobalMetadata>,
return_to: Option<String>, return_to: Option<String>,
) -> Template { ) -> Template {
// TODO: redirect if already logged in // TODO: redirect if already logged in
@ -19,7 +19,7 @@ pub async fn page(
csrf_token: csrf_token, csrf_token: csrf_token,
form: &Context::default(), form: &Context::default(),
return_to, return_to,
meta: meta.inner() meta: APP_METADATA.clone()
}) })
} }

View file

@ -6,15 +6,15 @@ use rocket::http::{Header, Status};
use rocket_dyn_templates::{context, Template}; use rocket_dyn_templates::{context, Template};
use rocket_session_store::Session; use rocket_session_store::Session;
use crate::{GlobalMetadata, LoginSessionData, SessionData, DB}; use crate::{GlobalMetadata, LoginSessionData, SessionData, DB};
use crate::consts::DISABLE_LOGIN_CHECK; use crate::consts::{APP_METADATA, DISABLE_LOGIN_CHECK};
use crate::models::user::validate_user_form; use crate::models::user::validate_user_form;
use crate::routes::ui::auth::HackyRedirectBecauseRocketBug;
use crate::util::{set_csrf, validate_csrf_form}; use crate::util::{set_csrf, validate_csrf_form};
#[get("/auth/login?<return_to>&<logged_out>")] #[get("/auth/login?<return_to>&<logged_out>")]
pub async fn page( pub async fn page(
route: &Route, route: &Route,
session: Session<'_, SessionData>, session: Session<'_, SessionData>,
meta: &State<GlobalMetadata>,
return_to: Option<String>, return_to: Option<String>,
logged_out: Option<bool> logged_out: Option<bool>
) -> Template { ) -> Template {
@ -26,7 +26,7 @@ pub async fn page(
form: &Context::default(), form: &Context::default(),
return_to, return_to,
logged_out, logged_out,
meta: meta.inner() meta: APP_METADATA.clone()
}) })
} }
@ -43,20 +43,12 @@ struct LoginForm<'r> {
} }
#[derive(Responder)]
#[response(status = 302)]
struct HackyRedirectBecauseRocketBug {
inner: String,
location: Header<'static>,
}
#[post("/auth/login?<return_to>", data = "<form>")] #[post("/auth/login?<return_to>", data = "<form>")]
pub async fn handler( pub async fn handler(
pool: &State<DB>, pool: &State<DB>,
route: &Route, route: &Route,
ip_addr: IpAddr, ip_addr: IpAddr,
session: Session<'_, SessionData>, session: Session<'_, SessionData>,
meta: &State<GlobalMetadata>,
mut form: Form<Contextual<'_, LoginForm<'_>>>, mut form: Form<Contextual<'_, LoginForm<'_>>>,
return_to: Option<String>, return_to: Option<String>,
) -> Result<HackyRedirectBecauseRocketBug, Template> { ) -> Result<HackyRedirectBecauseRocketBug, Template> {
@ -95,7 +87,7 @@ pub async fn handler(
csrf_token: csrf_token, csrf_token: csrf_token,
form: &form.context, form: &form.context,
return_to, return_to,
meta: meta.inner() meta: APP_METADATA.clone()
}; };
Err(Template::render("auth/login", &ctx)) Err(Template::render("auth/login", &ctx))
} }

View file

@ -2,15 +2,16 @@ use rocket::{get, post, Route, State};
use rocket_dyn_templates::{context, Template}; use rocket_dyn_templates::{context, Template};
use rocket_session_store::Session; use rocket_session_store::Session;
use crate::{GlobalMetadata, SessionData}; use crate::{GlobalMetadata, SessionData};
use crate::consts::APP_METADATA;
use crate::util::set_csrf; use crate::util::set_csrf;
#[get("/auth/register")] #[get("/auth/register")]
pub async fn page(route: &Route, session: Session<'_, SessionData>, meta: &State<GlobalMetadata>) -> Template { pub async fn page(route: &Route, session: Session<'_, SessionData>) -> Template {
let csrf_token = set_csrf(&session).await; let csrf_token = set_csrf(&session).await;
Template::render("auth/register", context! { Template::render("auth/register", context! {
route: route.uri.path(), route: route.uri.path(),
csrf_token: csrf_token, csrf_token: csrf_token,
meta: meta.inner() meta: APP_METADATA.clone()
}) })
} }

View file

@ -2,10 +2,10 @@ use std::env::var;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::{LazyLock, OnceLock}; use std::sync::{LazyLock, OnceLock};
use std::time::Duration; use std::time::Duration;
use anyhow::anyhow; use anyhow::{anyhow, Error};
use log::warn; use log::{debug, warn};
use moka::future::Cache; use moka::future::Cache;
use rocket::{get, post, uri}; use rocket::{get, post, uri, State};
use rocket::response::Redirect; use rocket::response::Redirect;
use rocket_session_store::Session; use rocket_session_store::Session;
use crate::guards::AuthUser; use crate::guards::AuthUser;
@ -14,60 +14,16 @@ use openidconnect::{reqwest, AccessTokenHash, AsyncHttpClient, AuthenticationFlo
use openidconnect::core::{CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreTokenResponse, CoreUserInfoClaims}; use openidconnect::core::{CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreTokenResponse, CoreUserInfoClaims};
use openidconnect::http::HeaderValue; use openidconnect::http::HeaderValue;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
// TODO: not have this lazy somehow, move to OnceLock and have fn to refresh it? (own module?) use rocket::http::{Header, Status};
// and/or also move to State<> use rocket_dyn_templates::{context, Template};
use tokio::sync::MutexGuard;
use crate::managers::sso::{SSOSessionData, SSOState, SSO};
use crate::routes::ui::auth::HackyRedirectBecauseRocketBug;
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| { async fn page_handler(sso: &State<SSOState>, ip: IpAddr, return_to: Option<String>) -> Result<Redirect, anyhow::Error> {
let mut headers = HeaderMap::new(); let mut sso = sso.as_ref().ok_or_else(|| anyhow!("SSO is not configured"))?.lock().await;
// TODO: pull from config. let client = sso.create_client_redirect().await?;
// Set referrer as some providers (authentik) block POST w/o referrer
headers.insert("Referer", HeaderValue::from_static("http://localhost:8080"));
let mut builder = reqwest::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.default_headers(headers);
if var("DANGER_DEV_PROXY").is_ok() {
warn!("DANGER_DEV_PROXY set, requests are being proxied & ignoring certificates");
builder = builder
.proxy(reqwest::Proxy::https("https://localhost:8082").unwrap())
.danger_accept_invalid_certs(true)
};
builder.build().expect("Client should build")
});
#[derive(Clone)]
struct SSOSessionData {
pkce_challenge: String,
nonce: Nonce,
csrf_token: CsrfToken
// ip: IpAddr,
}
static SSO_SESSION_CACHE: LazyLock<Cache<IpAddr, SSOSessionData>> = LazyLock::new(|| Cache::builder()
.time_to_live(Duration::from_secs(120))
.max_capacity(100)
.build());
#[get("/auth/sso")]
pub async fn page(ip: IpAddr) -> Redirect {
let http_client = HTTP_CLIENT.clone();
// FIXME: temp, remove
let provider_metadata = CoreProviderMetadata::discover_async(
/* TODO: pull from config */
IssuerUrl::new(var("SSO_ISSUER_URL").expect("dev: missing sso url")).expect("bad issuer url"),
&http_client,
).await.map_err(|e| e.to_string()).expect("discovery failed");
let client =
CoreClient::from_provider_metadata(
provider_metadata,
// TODO: pull from config
ClientId::new(var("SSO_CLIENT_ID").expect("dev: sso client id missing")),
Some(ClientSecret::new(var("SSO_CLIENT_SECRET").expect("dev sso client secret missing")
.to_string())),
).set_redirect_uri(RedirectUrl::new("http://localhost:8080/auth/sso/cb".to_string()).unwrap());
// Generate a PKCE challenge.
// TODO: store in hashmap for request ip? leaky bucket?
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
// Generate the full authorization URL.
let (auth_url, csrf_token, nonce) = client let (auth_url, csrf_token, nonce) = client
.authorize_url( .authorize_url(
CoreAuthenticationFlow::AuthorizationCode, CoreAuthenticationFlow::AuthorizationCode,
@ -75,92 +31,83 @@ pub async fn page(ip: IpAddr) -> Redirect {
Nonce::new_random, Nonce::new_random,
) )
// Set the desired scopes. // Set the desired scopes.
// TODO: change scopes .add_scopes(sso.scopes())
.add_scope(Scope::new("email".to_string()))
.add_scope(Scope::new("name".to_string()))
// Set the PKCE code challenge. // Set the PKCE code challenge.
.set_pkce_challenge(pkce_challenge) .set_pkce_challenge(pkce_challenge)
.url(); .url();
SSO_SESSION_CACHE.insert(ip, SSOSessionData { sso.cache_set(ip, SSOSessionData {
nonce: nonce, nonce: nonce,
pkce_challenge: pkce_verifier.into_secret(), pkce_challenge: pkce_verifier.into_secret(),
csrf_token csrf_token,
return_to
}).await; }).await;
Ok(Redirect::to(auth_url.to_string()))
Redirect::to(auth_url.to_string()) }
#[get("/auth/sso?<return_to>")]
// This is the URL you should redirect the user to, in order to trigger the authorization pub async fn page(ip: IpAddr, sso: &State<SSOState>, return_to: Option<String>) -> Result<Redirect, (Status, Template)> {
// process. page_handler(sso, ip, return_to).await
.map_err(|e| (Status::InternalServerError, Template::render("errors/500", context! {
error: e.to_string()
})))
} }
#[get("/auth/sso/cb?<code>&<state>")] async fn callback_handler(sso: &State<SSOState>, ip: IpAddr, code: String, state: String) -> Result<(CoreUserInfoClaims, Option<String>), anyhow::Error> {
pub async fn callback(session: Session<'_, SessionData>, ip: IpAddr, code: String, state: String) -> Result<String, String> { let mut sso = sso.as_ref().ok_or_else(||anyhow!("SSO is not configured"))?.lock().await;
let session_data = SSO_SESSION_CACHE.remove(&ip).await.ok_or_else(|| "no sso session started".to_string())?; let sess_data = sso.cache_take(ip).await.ok_or_else(|| anyhow!("No valid sso started"))?;
// Now you can exchange it for an access token and ID token. if &state != sess_data.csrf_token.secret() {
if &state != session_data.csrf_token.secret() { return Err(anyhow!("CSRF verification failed"));
return Err(format!("csrf validation failed {}", state));
} }
let client = sso.create_client_redirect().await?;
// FIXME: temp, remove
let http_client = HTTP_CLIENT.clone();
let provider_metadata = CoreProviderMetadata::discover_async(
/* TODO: pull from config */
IssuerUrl::new(var("SSO_ISSUER_URL").expect("dev: missing sso url")).expect("bad issuer url"),
&http_client,
).await.expect("discovery failed");
let client =
CoreClient::from_provider_metadata(
provider_metadata,
// TODO: pull from config
ClientId::new(var("SSO_CLIENT_ID").expect("dev: sso client id missing")),
Some(ClientSecret::new(var("SSO_CLIENT_SECRET").expect("dev sso client secret missing")
.to_string())),
).set_redirect_uri(RedirectUrl::new("http://localhost:8080/auth/sso/cb".to_string()).unwrap());
let token_response = let token_response =
client client
.exchange_code(AuthorizationCode::new(code)).expect("bad auth code") .exchange_code(AuthorizationCode::new(code)).map_err(|e| anyhow!("oidc code is invalid"))?
// Set the PKCE code verifier. // Set the PKCE code verifier.
.set_pkce_verifier(PkceCodeVerifier::new(session_data.pkce_challenge)) // TODO: somehow have this?? .set_pkce_verifier(PkceCodeVerifier::new(sess_data.pkce_challenge)) // TODO: somehow have this??
.request_async(&http_client).await.expect("token exchange error"); .request_async(sso.http_client()).await
.map_err(|e| anyhow!("OIDC Token exchange error"))?;
// Extract the ID token claims after verifying its authenticity and nonce. // Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response let id_token = token_response
.id_token() .id_token()
.ok_or_else(|| "Server did not return an ID token".to_string())?; .ok_or_else(|| anyhow!("Server did not return an ID token"))?;
let id_token_verifier = client.id_token_verifier(); let id_token_verifier = client.id_token_verifier();
let claims = id_token.claims(&id_token_verifier, &session_data.nonce).expect("bad claims"); // TODO: and this? let claims = id_token.claims(&id_token_verifier, &sess_data.nonce).map_err(|e| anyhow!("OIDC Token claims error: {}", e))?;
// Verify the access token hash to ensure that the access token hasn't been substituted for // Verify the access token hash to ensure that the access token hasn't been substituted for another user's.
// another user's.
if let Some(expected_access_token_hash) = claims.access_token_hash() { if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token( let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(), token_response.access_token(),
id_token.signing_alg().expect("signing failed (alg)"), id_token.signing_alg().map_err(|e| anyhow!("OIDC token signature error: {}", e))?,
id_token.signing_key(&id_token_verifier).expect("signing failed (key)"), id_token.signing_key(&id_token_verifier).map_err(|e| anyhow!("OIDC token signature error: {}", e))?
).expect("access token resolve error"); ).expect("access token resolve error");
if actual_access_token_hash != *expected_access_token_hash { if actual_access_token_hash != *expected_access_token_hash {
return Err("Invalid access token".to_string()); return Err(anyhow!("Invalid access token"))
} }
} }
// The authenticated user's identity is now available. See the IdTokenClaims struct for a
// complete listing of the available claims.
println!(
"User {} with e-mail address {} has authenticated successfully",
claims.subject().as_str(),
claims.email().map(|email| email.as_str()).unwrap_or("<not provided>"),
);
// If available, we can use the user info endpoint to request additional information. // If available, we can use the user info endpoint to request additional information.
// The user_info request uses the AccessToken returned in the token response. To parse custom // The user_info request uses the AccessToken returned in the token response. To parse custom
// claims, use UserInfoClaims directly (with the desired type parameters) rather than using the // claims, use UserInfoClaims directly (with the desired type parameters) rather than using the
// CoreUserInfoClaims type alias. // CoreUserInfoClaims type alias.
let userinfo: CoreUserInfoClaims = client let userinfo: CoreUserInfoClaims = client
.user_info(token_response.access_token().to_owned(), None).expect("user info missing") .user_info(token_response.access_token().to_owned(), None).map_err(|_| anyhow!("could not acquire user data"))?
.request_async(&http_client) .request_async(sso.http_client())
.await .await
.map_err(|err| format!("Failed requesting user info: {}", err))?; .map_err(|_| anyhow!("could not acquire user data"))?;
Ok(format!("user={:?}\nemail={:?}\nname={:?}", userinfo.subject(), userinfo.email(), userinfo.name())) Ok((userinfo, sess_data.return_to))
}
#[get("/auth/sso/cb?<code>&<state>")]
pub async fn callback(session: Session<'_, SessionData>, ip: IpAddr, sso: &State<SSOState>, code: String, state: String) -> Result<HackyRedirectBecauseRocketBug, (Status, Template)> {
let (userinfo, return_to) = callback_handler(sso, ip, code, state).await
.map_err(|e| (Status::InternalServerError, Template::render("errors/500", context! {
error: e.to_string()
})))?;
debug!("user={:?}\nemail={:?}\nname={:?}", userinfo.subject(), userinfo.email(), userinfo.name());
let return_to = return_to.unwrap_or("/".to_string());
Ok(HackyRedirectBecauseRocketBug {
inner: "Login successful, redirecting...".to_string(),
location: Header::new("Location", return_to),
})
} }

View file

@ -5,10 +5,11 @@ use rocket_session_store::{Session, SessionResult};
use serde::Serialize; use serde::Serialize;
use crate::models::user::UserModel; use crate::models::user::UserModel;
use crate::{GlobalMetadata, SessionData}; use crate::{GlobalMetadata, SessionData};
use crate::consts::APP_METADATA;
#[get("/help/about")] #[get("/help/about")]
pub fn about(route: &Route, meta: &State<GlobalMetadata>) -> Template { pub fn about(route: &Route) -> Template {
Template::render("about", context! { route: route.uri.path(), meta: meta.inner() }) Template::render("about", context! { route: route.uri.path(), meta: APP_METADATA.clone() })
} }

View file

@ -1,5 +1,6 @@
use std::fs; use std::fs;
use std::io::Cursor; use std::io::Cursor;
use std::time::Duration;
use log::trace; use log::trace;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rand::{rng, Rng, TryRngCore}; use rand::{rng, Rng, TryRngCore};
@ -9,14 +10,18 @@ use rocket::{form, response, Request, Response};
use rocket::form::Context; use rocket::form::Context;
use rocket::form::error::Entity; use rocket::form::error::Entity;
use rocket::fs::relative; use rocket::fs::relative;
use rocket::http::private::cookie::CookieBuilder;
use rocket::response::Responder; use rocket::response::Responder;
use rocket::serde::Serialize; use rocket::serde::Serialize;
use rocket_dyn_templates::handlebars::Handlebars; use rocket_dyn_templates::handlebars::Handlebars;
use rocket_session_store::{Session, SessionError, SessionResult}; use rocket_session_store::{Session, SessionError, SessionResult, SessionStore};
use sqlx::Error; use rocket_session_store::memory::MemoryStore;
use sqlx::{migrate, Error, Pool, Postgres};
use sqlx::postgres::PgPoolOptions;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use uuid::Uuid; use uuid::Uuid;
use crate::consts::{SESSION_COOKIE_NAME, SESSION_LIFETIME_SECONDS};
use crate::models::user::{UserAuthError,}; use crate::models::user::{UserAuthError,};
use crate::SessionData; use crate::SessionData;
use crate::util::ResponseError::DatabaseError; use crate::util::ResponseError::DatabaseError;
@ -31,6 +36,35 @@ pub(crate) fn setup_logger() {
.init(); .init();
} }
pub async fn setup_db() -> Pool<Postgres> {
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(std::env::var("DATABASE_URL").unwrap().as_str())
.await
.unwrap();
migrate!("./migrations")
.run(&pool)
.await.unwrap();
pool
}
pub fn setup_session_store() -> SessionStore<SessionData> {
let memory_store: MemoryStore::<SessionData> = MemoryStore::default();
SessionStore {
store: Box::new(memory_store),
name: SESSION_COOKIE_NAME.into(),
duration: Duration::from_secs(SESSION_LIFETIME_SECONDS),
// The cookie builder is used to set the cookie's path and other options.
// Name and value don't matter, they'll be overridden on each request.
cookie_builder: CookieBuilder::new("", "")
// Most web apps will want to use "/", but if your app is served from
// `example.com/myapp/` for example you may want to use "/myapp/" (note the trailing
// slash which prevents the cookie from being sent for `example.com/myapp2/`).
.path("/")
}
}
pub async fn set_csrf(session: &Session<'_, SessionData>) -> String { pub async fn set_csrf(session: &Session<'_, SessionData>) -> String {
let token = gen_csrf_token(); let token = gen_csrf_token();
trace!("set_csrf token={}", token); trace!("set_csrf token={}", token);
@ -72,26 +106,6 @@ pub fn gen_csrf_token() -> String {
.collect() .collect()
} }
// pub(crate) fn setup_template_engine() -> Handlebars<'static> {
// let mut hb = Handlebars::new();
// #[cfg(debug_assertions)]
// hb.set_dev_mode(true);
//
// let templates = fs::read_dir(relative!("templates")).unwrap();
// let mut ok = true;
// for file in templates {
// let file = file.unwrap();
// if let Err(e) = hb.register_template_file(file.path().to_str().unwrap(), ) {
// error!(template, path = %path.display(),
// "failed to register Handlebars template: {e}");
//
// ok = false;
// }
// }
//
// hb
// }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct JsonErrorResponse { pub struct JsonErrorResponse {
pub(crate) code: String, pub(crate) code: String,

View file

@ -0,0 +1,27 @@
{{#> layouts/default body-class="has-background-white-ter login-bg" }}
<br><br>
<div class="container py-6" style="width:20%"> <!-- TODO: fix width on mobile -->
<h1 class="title is-1 has-text-centered">storage-app</h1>
<div class="box is-radiusless">
<h4 class="title is-4 has-text-centered">500 Internal Server Error</h4>
<p>An internal error occurred while procesing your request</p>
<p><b>Error: </b><code>{{ error }}</code></p>
<br>
<!-- Hide go back unless javascript enabled -->
<p><span id="backlink" style="display:none"><a href="">Go Back</a> | </span><a href="/">Return home</a></p>
</div>
</div>
{{/layouts/default}}
<script>
// Enable 'go back' link:
const element = document.querySelector('#backlink');
element.style.display = "inline"
const elementLink = document.querySelector('#backlink a');
elementLink.setAttribute('href', document.referrer);
elementLink.onclick = function() {
history.back();
return false;
}
</script>