From 60813de8cb21f836e4942cbe0c12520db642fe9d Mon Sep 17 00:00:00 2001 From: Jackz Date: Sun, 20 Apr 2025 15:03:24 -0500 Subject: [PATCH 1/4] very rough SSO code --- Cargo.lock | 383 +++++++++++++++++++++++++++++++++++- Cargo.toml | 4 +- config.sample.toml | 5 + src/consts.rs | 7 +- src/main.rs | 1 + src/models/user.rs | 2 +- src/routes/ui/auth.rs | 2 + src/routes/ui/auth/login.rs | 2 +- src/routes/ui/auth/sso.rs | 155 +++++++++++++++ 9 files changed, 546 insertions(+), 15 deletions(-) create mode 100644 src/routes/ui/auth/sso.rs diff --git a/Cargo.lock b/Cargo.lock index 599f814..8cfa55e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,6 +53,17 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "async-lock" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" +dependencies = [ + "event-listener", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -110,6 +121,12 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.4.0" @@ -304,6 +321,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -343,6 +370,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-queue" version = "0.3.12" @@ -646,6 +682,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -717,6 +763,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -793,6 +854,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -814,6 +886,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -832,7 +905,20 @@ dependencies = [ "libc", "log", "rustversion", - "windows", + "windows 0.48.0", +] + +[[package]] +name = "generator" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bd114ceda131d3b1d665eba35788690ad37f5916457286b32ab6fd3c438dd" +dependencies = [ + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.58.0", ] [[package]] @@ -915,6 +1001,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.3.1", + "indexmap 2.9.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "handlebars" version = "5.1.2" @@ -1093,7 +1198,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -1116,6 +1221,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2 0.4.9", "http 1.3.1", "http-body 1.0.1", "httparse", @@ -1144,6 +1250,22 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.11" @@ -1176,7 +1298,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core", + "windows-core 0.61.0", ] [[package]] @@ -1541,7 +1663,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff50ecb28bb86013e935fb6683ab1f6d3a20016f123c76fd4c27470076ac30f5" dependencies = [ "cfg-if", - "generator", + "generator 0.7.5", "scoped-tls", "serde", "serde_json", @@ -1549,6 +1671,19 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator 0.8.4", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1612,6 +1747,28 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "moka" +version = "0.12.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "event-listener", + "futures-util", + "loom 0.7.2", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "thiserror 1.0.69", + "uuid", +] + [[package]] name = "multer" version = "3.1.0" @@ -1631,6 +1788,23 @@ dependencies = [ "version_check", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "normpath" version = "1.3.0" @@ -1798,6 +1972,50 @@ dependencies = [ "url", ] +[[package]] +name = "openssl" +version = "0.10.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" +dependencies = [ + "bitflags 2.9.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "ordered-float" version = "2.10.1" @@ -1988,6 +2206,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + [[package]] name = "powerfmt" version = "0.2.0" @@ -2244,18 +2468,22 @@ checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", + "h2 0.4.9", "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper 1.6.0", "hyper-rustls", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -2267,7 +2495,9 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper", + "system-configuration", "tokio", + "tokio-native-tls", "tokio-rustls", "tower", "tower-service", @@ -2531,6 +2761,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -2557,6 +2796,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.9.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.26" @@ -3005,7 +3267,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b8c4a4445d81357df8b1a650d0d0d6fbbbfe99d064aa5e02f3e4022061476d8" dependencies = [ - "loom", + "loom 0.5.6", ] [[package]] @@ -3025,8 +3287,10 @@ dependencies = [ "humanize-bytes", "int-enum", "log", + "moka", "openidconnect", "rand 0.9.0", + "reqwest", "rocket", "rocket-session-store", "rocket_dyn_templates", @@ -3092,6 +3356,33 @@ dependencies = [ "syn", ] +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.9.0", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tempfile" version = "3.19.1" @@ -3240,6 +3531,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.2" @@ -3715,19 +4016,53 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core 0.58.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement 0.58.0", + "windows-interface 0.58.0", + "windows-result 0.2.0", + "windows-strings 0.1.0", + "windows-targets 0.52.6", +] + [[package]] name = "windows-core" version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" dependencies = [ - "windows-implement", - "windows-interface", + "windows-implement 0.60.0", + "windows-interface 0.59.1", "windows-link", - "windows-result", + "windows-result 0.3.2", "windows-strings 0.4.0", ] +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-implement" version = "0.60.0" @@ -3739,6 +4074,17 @@ dependencies = [ "syn", ] +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-interface" version = "0.59.1" @@ -3762,11 +4108,20 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ - "windows-result", + "windows-result 0.3.2", "windows-strings 0.3.1", "windows-targets 0.53.0", ] +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-result" version = "0.3.2" @@ -3776,6 +4131,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result 0.2.0", + "windows-targets 0.52.6", +] + [[package]] name = "windows-strings" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index d585bd5..cccb443 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,4 +23,6 @@ rocket-session-store = "0.2.1" uuid = { version = "1.16.0", features = ["v4"] } rand = { version = "0.9.0", features = ["thread_rng"] } bcrypt = "0.17.0" -openidconnect = "4.0.0" \ No newline at end of file +openidconnect = "4.0.0" +reqwest = "0.12.15" +moka = { version = "0.12.10", features = ["future"] } diff --git a/config.sample.toml b/config.sample.toml index 51e1eba..0de2414 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -1,6 +1,11 @@ [general] listen_ip = "0.0.0.0" listen_port = 80 +# if under reverse proxy +#public_url = "https://storage.example.com" +#public_port = 443 +public_url = "http://localhost:8080" +public_port = 80 [backends.local] path = "/var/tmp/test" diff --git a/src/consts.rs b/src/consts.rs index 3f33246..61c6f7d 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -1,6 +1,6 @@ use std::cell::OnceCell; use std::env; -use std::sync::OnceLock; +use std::sync::{LazyLock, OnceLock}; use std::time::Duration; use rocket::data::ByteUnit; use rocket::serde::Serialize; @@ -29,7 +29,8 @@ pub const FILE_CONSTANTS: FileConstants = FileConstants { /// Disables CSRF & password verification for login /// Used for development due to no session persistence -pub static DISABLE_LOGIN_CHECK: OnceLock = OnceLock::new(); +pub static DISABLE_LOGIN_CHECK: LazyLock = LazyLock::new(|| { + env::var("DANGER_DISABLE_LOGIN_CHECKS").is_ok() +}); pub fn init_statics() { - DISABLE_LOGIN_CHECK.set(env::var("DANGER_DISABLE_LOGIN_CHECKS").is_ok()).unwrap(); } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index d4475d9..65c6e48 100644 --- a/src/main.rs +++ b/src/main.rs @@ -144,6 +144,7 @@ async fn rocket() -> _ { .mount("/", routes![ ui::auth::logout, ui::auth::login::page, ui::auth::login::handler, ui::auth::register::page, ui::auth::register::handler, + ui::auth::sso::page, ui::auth::sso::callback, ui::auth::forgot_password::page, ui::auth::forgot_password::handler, ]) .mount("/", routes![ diff --git a/src/models/user.rs b/src/models/user.rs index b9fd904..de413c8 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -124,7 +124,7 @@ pub async fn validate_user(pool: &DB, email_or_usrname: &str, password: &str) -> return Err(UserAuthError::UserNotFound); }; if let Some(db_password) = user.password { - if !DISABLE_LOGIN_CHECK.get().unwrap() || bcrypt::verify(password, &db_password).map_err(|e| UserAuthError::EncryptionError(e))? { + if !*DISABLE_LOGIN_CHECK || bcrypt::verify(password, &db_password).map_err(|e| UserAuthError::EncryptionError(e))? { return Ok(UserModel { id: user.id, email: user.email, diff --git a/src/routes/ui/auth.rs b/src/routes/ui/auth.rs index 6404591..5467faf 100644 --- a/src/routes/ui/auth.rs +++ b/src/routes/ui/auth.rs @@ -20,6 +20,8 @@ pub mod forgot_password; pub mod login; pub mod register; +pub mod sso; + #[get("/logout")] pub async fn logout(session: Session<'_, SessionData>, user: AuthUser) -> Redirect { session.remove().await.unwrap(); diff --git a/src/routes/ui/auth/login.rs b/src/routes/ui/auth/login.rs index 045a1f0..dc07445 100644 --- a/src/routes/ui/auth/login.rs +++ b/src/routes/ui/auth/login.rs @@ -61,7 +61,7 @@ pub async fn handler( return_to: Option, ) -> Result { trace!("handler"); - if !DISABLE_LOGIN_CHECK.get().unwrap() { + if !*DISABLE_LOGIN_CHECK { validate_csrf_form(&mut form.context, &session).await; } let user = validate_user_form(&mut form.context, &pool).await; diff --git a/src/routes/ui/auth/sso.rs b/src/routes/ui/auth/sso.rs new file mode 100644 index 0000000..c515425 --- /dev/null +++ b/src/routes/ui/auth/sso.rs @@ -0,0 +1,155 @@ +use std::env::var; +use std::net::IpAddr; +use std::sync::{LazyLock, OnceLock}; +use std::time::Duration; +use anyhow::anyhow; +use moka::future::Cache; +use rocket::{get, post, uri}; +use rocket::response::Redirect; +use rocket_session_store::Session; +use crate::guards::AuthUser; +use crate::SessionData; +use openidconnect::{reqwest, AccessTokenHash, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, HttpClientError, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, ProviderMetadata, RedirectUrl, Scope, StandardErrorResponse, TokenResponse}; +use openidconnect::core::{CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreTokenResponse, CoreUserInfoClaims}; + +// TODO: not have this lazy somehow, move to OnceLock and have fn to refresh it? (own module?) +// and/or also move to State<> + +static HTTP_CLIENT: LazyLock = LazyLock::new(|| { + reqwest::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build") +}); +#[derive(Clone)] +struct SSOSessionData { + pkce_challenge: String, + nonce: Nonce, + csrf_token: CsrfToken + // ip: IpAddr, +} +static SSO_SESSION_CACHE: LazyLock> = LazyLock::new(|| Cache::builder() + .time_to_live(Duration::from_secs(120)) + .max_capacity(100) + .build()); +#[get("/auth/sso")] +pub async fn page(session: Session<'_, SessionData>, ip: IpAddr) -> Redirect { + let s = session.get().await.unwrap().unwrap(); + let http_client = HTTP_CLIENT.clone(); + // FIXME: temp, remove + let provider_metadata = CoreProviderMetadata::discover_async( + /* TODO: pull from config */ + IssuerUrl::new(var("SSO_ISSUER_URL").expect("dev: missing sso url")).expect("bad issuer url"), + &http_client, + ).await.expect("discovery failed"); + let client = + CoreClient::from_provider_metadata( + provider_metadata, + // TODO: pull from config + ClientId::new(var("SSO_CLIENT_ID").expect("dev: sso client id missing")), + Some(ClientSecret::new(var("SSO_CLIENT_SECRET").expect("dev sso client secret missing") + .to_string())), + ).set_redirect_uri(RedirectUrl::new("http://localhost:8080/auth/sso/cb".to_string()).unwrap()); + + // Generate a PKCE challenge. + // TODO: store in hashmap for request ip? leaky bucket? + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + // Generate the full authorization URL. + let (auth_url, csrf_token, nonce) = client + .authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ) + // Set the desired scopes. + // TODO: change scopes + .add_scope(Scope::new("read".to_string())) + .add_scope(Scope::new("write".to_string())) + // Set the PKCE code challenge. + .set_pkce_challenge(pkce_challenge) + .url(); + SSO_SESSION_CACHE.insert(ip, SSOSessionData { + nonce: nonce, + pkce_challenge: pkce_verifier.into_secret(), + csrf_token + }).await; + + Redirect::to(auth_url.to_string()) + + // This is the URL you should redirect the user to, in order to trigger the authorization + // process. +} + +#[post("/auth/sso/cb?&")] +pub async fn callback(session: Session<'_, SessionData>, ip: IpAddr, code: String, state: String) -> Result { + let session_data = SSO_SESSION_CACHE.remove(&ip).await.ok_or_else(|| "no sso session started".to_string())?; + // Now you can exchange it for an access token and ID token. + if &state != session_data.csrf_token.secret() { + return Err(format!("csrf validation failed {}", state)); + } + + // FIXME: temp, remove + let http_client = HTTP_CLIENT.clone(); + let provider_metadata = CoreProviderMetadata::discover_async( + /* TODO: pull from config */ + IssuerUrl::new(var("SSO_ISSUER_URL").expect("dev: missing sso url")).expect("bad issuer url"), + &http_client, + ).await.expect("discovery failed"); + let client = + CoreClient::from_provider_metadata( + provider_metadata, + // TODO: pull from config + ClientId::new(var("SSO_CLIENT_ID").expect("dev: sso client id missing")), + Some(ClientSecret::new(var("SSO_CLIENT_SECRET").expect("dev sso client secret missing") + .to_string())), + ).set_redirect_uri(RedirectUrl::new("http://localhost:8080/auth/sso/cb".to_string()).unwrap()); + + let token_response = + client + .exchange_code(AuthorizationCode::new(code)).expect("bad auth code") + // Set the PKCE code verifier. + .set_pkce_verifier(PkceCodeVerifier::new(session_data.pkce_challenge)) // TODO: somehow have this?? + .request_async(&http_client).await.expect("token exchange error"); + + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response + .id_token() + .ok_or_else(|| "Server did not return an ID token".to_string())?; + let id_token_verifier = client.id_token_verifier(); + let claims = id_token.claims(&id_token_verifier, &session_data.nonce).expect("bad claims"); // TODO: and this? + + // Verify the access token hash to ensure that the access token hasn't been substituted for + // another user's. + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = AccessTokenHash::from_token( + token_response.access_token(), + id_token.signing_alg().expect("signing failed (alg)"), + id_token.signing_key(&id_token_verifier).expect("signing failed (key)"), + ).expect("access token resolve error"); + if actual_access_token_hash != *expected_access_token_hash { + return Err("Invalid access token".to_string()); + } + } + + // The authenticated user's identity is now available. See the IdTokenClaims struct for a + // complete listing of the available claims. + println!( + "User {} with e-mail address {} has authenticated successfully", + claims.subject().as_str(), + claims.email().map(|email| email.as_str()).unwrap_or(""), + ); + + // If available, we can use the user info endpoint to request additional information. + + // The user_info request uses the AccessToken returned in the token response. To parse custom + // claims, use UserInfoClaims directly (with the desired type parameters) rather than using the + // CoreUserInfoClaims type alias. + let userinfo: CoreUserInfoClaims = client + .user_info(token_response.access_token().to_owned(), None).expect("user info missing") + .request_async(&http_client) + .await + .map_err(|err| format!("Failed requesting user info: {}", err))?; + Ok(format!("user={:?}\nemail={:?}\nname={:?}", userinfo.subject(), userinfo.email(), userinfo.name())) +} From 97424ca524b03c57d2936d03ec7c628a06e536b7 Mon Sep 17 00:00:00 2001 From: Jackz Date: Sun, 20 Apr 2025 16:08:31 -0500 Subject: [PATCH 2/4] Working SSO test --- config.sample.toml | 2 ++ src/routes/ui/auth/sso.rs | 33 ++++++++++++++++++++++----------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/config.sample.toml b/config.sample.toml index 0de2414..1a42fe1 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -17,5 +17,7 @@ openid_enabled = true openid_issuer_url = "https://accounts.example.com" openid_client_id = "" openid_client_secret = "" +openid_claims = [] + [smtp] # TODO: diff --git a/src/routes/ui/auth/sso.rs b/src/routes/ui/auth/sso.rs index c515425..0012109 100644 --- a/src/routes/ui/auth/sso.rs +++ b/src/routes/ui/auth/sso.rs @@ -3,24 +3,36 @@ use std::net::IpAddr; use std::sync::{LazyLock, OnceLock}; use std::time::Duration; use anyhow::anyhow; +use log::warn; use moka::future::Cache; use rocket::{get, post, uri}; use rocket::response::Redirect; use rocket_session_store::Session; use crate::guards::AuthUser; use crate::SessionData; -use openidconnect::{reqwest, AccessTokenHash, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, HttpClientError, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, ProviderMetadata, RedirectUrl, Scope, StandardErrorResponse, TokenResponse}; +use openidconnect::{reqwest, AccessTokenHash, AsyncHttpClient, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, HttpClientError, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, ProviderMetadata, RedirectUrl, Scope, StandardErrorResponse, TokenResponse}; use openidconnect::core::{CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreTokenResponse, CoreUserInfoClaims}; - +use openidconnect::http::HeaderValue; +use reqwest::header::HeaderMap; // TODO: not have this lazy somehow, move to OnceLock and have fn to refresh it? (own module?) // and/or also move to State<> static HTTP_CLIENT: LazyLock = LazyLock::new(|| { - reqwest::ClientBuilder::new() + let mut headers = HeaderMap::new(); + // TODO: pull from config. + // Set referrer as some providers (authentik) block POST w/o referrer + headers.insert("Referer", HeaderValue::from_static("http://localhost:8080")); + let mut builder = reqwest::ClientBuilder::new() // Following redirects opens the client up to SSRF vulnerabilities. .redirect(reqwest::redirect::Policy::none()) - .build() - .expect("Client should build") + .default_headers(headers); + if var("DANGER_DEV_PROXY").is_ok() { + warn!("DANGER_DEV_PROXY set, requests are being proxied & ignoring certificates"); + builder = builder + .proxy(reqwest::Proxy::https("https://localhost:8082").unwrap()) + .danger_accept_invalid_certs(true) + }; + builder.build().expect("Client should build") }); #[derive(Clone)] struct SSOSessionData { @@ -34,15 +46,14 @@ static SSO_SESSION_CACHE: LazyLock> = LazyLock::ne .max_capacity(100) .build()); #[get("/auth/sso")] -pub async fn page(session: Session<'_, SessionData>, ip: IpAddr) -> Redirect { - let s = session.get().await.unwrap().unwrap(); +pub async fn page(ip: IpAddr) -> Redirect { let http_client = HTTP_CLIENT.clone(); // FIXME: temp, remove let provider_metadata = CoreProviderMetadata::discover_async( /* TODO: pull from config */ IssuerUrl::new(var("SSO_ISSUER_URL").expect("dev: missing sso url")).expect("bad issuer url"), &http_client, - ).await.expect("discovery failed"); + ).await.map_err(|e| e.to_string()).expect("discovery failed"); let client = CoreClient::from_provider_metadata( provider_metadata, @@ -65,8 +76,8 @@ pub async fn page(session: Session<'_, SessionData>, ip: IpAddr) -> Redirect { ) // Set the desired scopes. // TODO: change scopes - .add_scope(Scope::new("read".to_string())) - .add_scope(Scope::new("write".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("name".to_string())) // Set the PKCE code challenge. .set_pkce_challenge(pkce_challenge) .url(); @@ -82,7 +93,7 @@ pub async fn page(session: Session<'_, SessionData>, ip: IpAddr) -> Redirect { // process. } -#[post("/auth/sso/cb?&")] +#[get("/auth/sso/cb?&")] pub async fn callback(session: Session<'_, SessionData>, ip: IpAddr, code: String, state: String) -> Result { let session_data = SSO_SESSION_CACHE.remove(&ip).await.ok_or_else(|| "no sso session started".to_string())?; // Now you can exchange it for an access token and ID token. From cab39de312f28edcd8fe42daa8481418c5e27fb4 Mon Sep 17 00:00:00 2001 From: Jackz Date: Sun, 20 Apr 2025 16:10:57 -0500 Subject: [PATCH 3/4] Change config --- config.sample.toml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/config.sample.toml b/config.sample.toml index 1a42fe1..9d25b42 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -12,12 +12,14 @@ path = "/var/tmp/test" [auth] enable_registration = true -openid_enabled = true +oidc_enabled = true # Where the .well-known/openid-configuration exists -openid_issuer_url = "https://accounts.example.com" -openid_client_id = "" -openid_client_secret = "" -openid_claims = [] +oidc_issuer_url = "https://accounts.example.com" +oidc_client_id = "" +oidc_client_secret = "" +oidc_claims = [] +# Should an account be created if SSO user id doesn't exist already +oidc_create_account = true [smtp] # TODO: From e13f080d912430c8a5808a34f5ee4a47b36cc21e Mon Sep 17 00:00:00 2001 From: Jackz Date: Sun, 20 Apr 2025 20:56:33 -0500 Subject: [PATCH 4/4] implement config + Pull SSO data from config --- .env.sample | 3 +- Cargo.lock | 1 + Cargo.toml | 1 + config.sample.toml | 42 +++++-- src/config.rs | 89 ++++++++++++-- src/consts.rs | 8 ++ src/main.rs | 64 +++++----- src/managers.rs | 3 +- src/managers/sso.rs | 148 +++++++++++++++++++++++ src/routes/ui/auth.rs | 8 ++ src/routes/ui/auth/forgot_password.rs | 4 +- src/routes/ui/auth/login.rs | 16 +-- src/routes/ui/auth/register.rs | 5 +- src/routes/ui/auth/sso.rs | 167 +++++++++----------------- src/routes/ui/help.rs | 5 +- src/util.rs | 58 +++++---- templates/errors/500.html.hbs | 27 +++++ 17 files changed, 440 insertions(+), 209 deletions(-) create mode 100644 src/managers/sso.rs create mode 100644 templates/errors/500.html.hbs diff --git a/.env.sample b/.env.sample index a18472d..ada1a2c 100644 --- a/.env.sample +++ b/.env.sample @@ -1 +1,2 @@ -DATABASE_URL=postgresql://server:5432/database?user=user&password=password&connectTimeout=30¤tSchema=storage; \ No newline at end of file +DATABASE_URL=postgresql://server:5432/database?user=user&password=password&connectTimeout=30¤tSchema=storage; +STORAGE_AUTH_OIDC_ISSUER_URL = "" \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8cfa55e..f4bdb66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3284,6 +3284,7 @@ dependencies = [ "bcrypt", "chrono", "dotenvy", + "figment", "humanize-bytes", "int-enum", "log", diff --git a/Cargo.toml b/Cargo.toml index cccb443..61551a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,4 @@ bcrypt = "0.17.0" openidconnect = "4.0.0" reqwest = "0.12.15" moka = { version = "0.12.10", features = ["future"] } +figment = "0.10.19" diff --git a/config.sample.toml b/config.sample.toml index 9d25b42..478d66e 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -1,25 +1,41 @@ [general] listen_ip = "0.0.0.0" -listen_port = 80 -# if under reverse proxy +listen_port = 8080 + +# The public facing url, this is where users will access the app +# Used for OIDC callbacks +# - if under reverse proxy (nginx, traefik, caddy, etc): #public_url = "https://storage.example.com" -#public_port = 443 public_url = "http://localhost:8080" -public_port = 80 [backends.local] path = "/var/tmp/test" [auth] -enable_registration = true -oidc_enabled = true -# Where the .well-known/openid-configuration exists -oidc_issuer_url = "https://accounts.example.com" -oidc_client_id = "" -oidc_client_secret = "" -oidc_claims = [] +# Is account registration disabled? Users will not be able to create +# a new account with email/username + pass +disable_registration = false +[auth.oidc] +enabled = true +# The url the .well-known/openid-configuration exists, this can be a subpath +# Example, for authentik: https://sso.example.com/application/o/YOURAPPSLUG +issuer_url = "" +client_id = "" +client_secret = "" +claims = ["email", "profile"] # Should an account be created if SSO user id doesn't exist already -oidc_create_account = true +create_account = true +# Should normal login (username/email+pass) be disabled, forcing users to use sso? +disable_normal_login = false [smtp] -# TODO: +enabled = false +hostname = "smtp.example.com" +port = 587 +username = "" +password = "" +# Name to be used for emails, defaults to public_url's domain +#from_name = "" +# The email address to send as, defaults to username +#from_address = "" +tls = "none" # "none", "starttls" or "tls" diff --git a/src/config.rs b/src/config.rs index 2a1cc2c..7705a22 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,27 +1,92 @@ -use rocket::serde::{Serialize,Deserialize}; +use std::collections::HashMap; +use std::env::var; +use figment::Figment; +use figment::providers::{Env, Format, Toml}; +use log::error; +use openidconnect::core::{CoreClient, CoreProviderMetadata}; +use openidconnect::IssuerUrl; +use openidconnect::url::Url; +use rocket::serde::{Serialize, Deserialize}; #[derive(Debug, Serialize, Deserialize)] -pub struct Config { - general: GeneralConfig, - auth: AuthConfig, - smtp: EmailConfig +#[serde(rename_all = "kebab-case")] +pub struct AppConfig { + pub general: GeneralConfig, + pub auth: AuthConfig, + pub smtp: Option } +pub fn get_settings() -> AppConfig { + let f = Figment::new() + .merge(Toml::file("config.toml")) + .merge(Env::prefixed("STORAGE_") + .map(|f| f.as_str().replace("__", "-").into())) + .extract(); + match f { + Ok(settings) => settings, + Err(e) => { + error!("Failed to read configuration"); + error!("{}", e); + std::process::exit(1); + } + } +} + + + + #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] pub struct GeneralConfig { pub listen_ip: Option, - pub listen_port: Option + pub listen_port: Option, + pub public_url: String, + pub database_url: Option, +} +impl GeneralConfig { + pub fn get_public_url(&self) -> Url { + self.public_url.parse().expect("failed to parse general.public-url") + } } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] pub struct AuthConfig { pub disable_registration: bool, - pub openid_enabled: Option, - pub openid_issuer_url: Option, - pub openid_client_id: Option, - pub openid_client_secret: Option + pub oidc: Option, +} +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "kebab-case")] +pub struct OidcConfig { + #[serde(default)] + pub enabled: bool, + pub issuer_url: String, + pub client_id: String, + pub client_secret: String, + #[serde(default)] + pub claims: Vec, + #[serde(default)] + pub create_account: bool, + #[serde(default)] + pub disable_normal_login: bool } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SmtpEncryption { + None, + StartTls, + Tls +} +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] pub struct EmailConfig { - -} \ No newline at end of file + #[serde(default)] + pub enabled: bool, + pub hostname: String, + pub port: u16, + pub username: String, + pub password: String, + pub tls: Option, + pub from_name: Option, + pub from_email: Option, +} diff --git a/src/consts.rs b/src/consts.rs index 61c6f7d..27b0b8e 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -4,6 +4,7 @@ use std::sync::{LazyLock, OnceLock}; use std::time::Duration; use rocket::data::ByteUnit; use rocket::serde::Serialize; +use crate::GlobalMetadata; /// The maximum amount of bytes that can be uploaded at once pub const MAX_UPLOAD_SIZE: ByteUnit = ByteUnit::Mebibyte(100_000); @@ -32,5 +33,12 @@ pub const FILE_CONSTANTS: FileConstants = FileConstants { pub static DISABLE_LOGIN_CHECK: LazyLock = LazyLock::new(|| { env::var("DANGER_DISABLE_LOGIN_CHECKS").is_ok() }); +pub static APP_METADATA: LazyLock = LazyLock::new(|| { + GlobalMetadata { + app_name: env!("CARGO_PKG_NAME").to_string(), + app_version: env!("CARGO_PKG_VERSION").to_string(), + repo_url: env!("CARGO_PKG_REPOSITORY").to_string(), + } +}); pub fn init_statics() { } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 65c6e48..aa4426b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; use log::{debug, error, info, trace, warn}; @@ -22,9 +22,11 @@ use tracing_subscriber::fmt::writer::MakeWriterExt; use crate::managers::libraries::LibraryManager; use crate::managers::repos::RepoManager; use crate::objs::library::Library; -use crate::util::{setup_logger, JsonErrorResponse, ResponseError}; +use crate::util::{setup_db, setup_logger, setup_session_store, JsonErrorResponse, ResponseError}; use routes::api; +use crate::config::{get_settings, AppConfig}; use crate::consts::{init_statics, SESSION_COOKIE_NAME, SESSION_LIFETIME_SECONDS}; +use crate::managers::sso::{SSOState, SSO}; use crate::models::user::UserModel; use crate::routes::ui; @@ -78,16 +80,25 @@ async fn rocket() -> _ { warn!("warn"); error!("error"); - // TODO: move to own fn - let pool = PgPoolOptions::new() - .max_connections(5) - .connect(std::env::var("DATABASE_URL").unwrap().as_str()) - .await - .unwrap(); - - migrate!("./migrations") - .run(&pool) - .await.unwrap(); + let settings: AppConfig = get_settings(); + info!("Auth | Registration={} Login={} | OIDC={} CreateAccount={}", + if settings.auth.disable_registration { "N" } else { "Y" }, + settings.auth.oidc.as_ref().map(|oidc| if oidc.disable_normal_login { "N" } else { "Y" } ).unwrap_or("Y"), + settings.auth.oidc.as_ref().map(|oidc| if oidc.enabled { "Y" } else { "N" } ).unwrap_or("N"), + settings.auth.oidc.as_ref().map(|oidc| if oidc.create_account { "Y" } else { "N" }).unwrap_or("-"), + ); + let listen_ip: IpAddr = settings.general.listen_ip.as_ref() + .map(|s| s.to_string()) + .unwrap_or_else(||"0.0.0.0".to_string()) + .parse().expect("bad listen ip"); + let listen_addr = SocketAddr::new(listen_ip, settings.general.listen_port.unwrap_or(8080)); + info!("Listening on {} | Public URL: {}", listen_addr, settings.general.public_url); + if let Some(ref smtp) = settings.smtp { + if smtp.enabled { + info!("SMTP Enabled"); + } + } + let pool = setup_db().await; let repo_manager = { let mut manager = RepoManager::new(pool.clone()); @@ -100,32 +111,21 @@ async fn rocket() -> _ { }; // TODO: move to own func - let memory_store: MemoryStore:: = MemoryStore::default(); - let store: SessionStore = SessionStore { - store: Box::new(memory_store), - name: SESSION_COOKIE_NAME.into(), - duration: Duration::from_secs(SESSION_LIFETIME_SECONDS), - // The cookie builder is used to set the cookie's path and other options. - // Name and value don't matter, they'll be overridden on each request. - cookie_builder: CookieBuilder::new("", "") - // Most web apps will want to use "/", but if your app is served from - // `example.com/myapp/` for example you may want to use "/myapp/" (note the trailing - // slash which prevents the cookie from being sent for `example.com/myapp2/`). - .path("/") + let store = setup_session_store(); + let sso: SSOState = { + if settings.auth.oidc.is_some() { Some(Arc::new(Mutex::new(SSO::create(&settings).await)) ) } else { None } }; - // TODO: move to constants - let metadata = GlobalMetadata { - app_name: env!("CARGO_PKG_NAME").to_string(), - app_version: env!("CARGO_PKG_VERSION").to_string(), - repo_url: env!("CARGO_PKG_REPOSITORY").to_string(), - }; + let figment = rocket::Config::figment() + .merge(("port", listen_addr.port())) + .merge(("address", listen_addr.ip())); - rocket::build() + rocket::custom(figment) .manage(pool) .manage(repo_manager) .manage(libraries_manager) - .manage(metadata) + .manage(settings) + .manage(sso) .attach(store.fairing()) .attach(Template::custom(|engines| { diff --git a/src/managers.rs b/src/managers.rs index 6d63f86..a43655c 100644 --- a/src/managers.rs +++ b/src/managers.rs @@ -1,2 +1,3 @@ pub mod repos; -pub mod libraries; \ No newline at end of file +pub mod libraries; +pub mod sso; \ No newline at end of file diff --git a/src/managers/sso.rs b/src/managers/sso.rs new file mode 100644 index 0000000..2d1d51c --- /dev/null +++ b/src/managers/sso.rs @@ -0,0 +1,148 @@ +use std::env::var; +use std::net::IpAddr; +use std::sync::{Arc, LazyLock}; +use std::time::Duration; +use anyhow::anyhow; +use log::warn; +use moka::future::Cache; +use openidconnect::core::{CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenResponse}; +use openidconnect::http::{HeaderMap, HeaderValue}; +use openidconnect::{Client, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce, ProviderMetadata, RedirectUrl, Scope, StandardErrorResponse}; +use openidconnect::url::ParseError; +use rocket::yansi::Paint; +use tokio::sync::Mutex; +use crate::config::{AppConfig, OidcConfig}; + +pub struct SSO { + http_client: reqwest::Client, + issuer_url: IssuerUrl, + client_id: ClientId, + client_secret: Option, + public_url: String, + scopes: Vec, + cache: Cache, +} +pub struct HttpProxySettings { + url: String, + disable_cert_check: bool +} + +#[derive(Clone)] +pub struct SSOSessionData { + pub pkce_challenge: String, + pub nonce: Nonce, + pub csrf_token: CsrfToken, + pub return_to: Option + // ip: IpAddr, +} +pub type SSOState = Option>>; +impl SSO { + pub async fn create(config: &AppConfig) -> Self { + let oidc_config = config.auth.oidc.as_ref().expect("OIDC config not provided"); + let referer = config.general.get_public_url().domain().map(|s| s.to_string()); + let http_client = SSO::setup_http_client(referer, None); + let issuer_url = IssuerUrl::new(oidc_config.issuer_url.to_string()).expect("bad issuer url"); + let client_id = ClientId::new(oidc_config.client_id.to_string()); + let client_secret = Some(ClientSecret::new(oidc_config.client_secret.to_string())); + let cache = Self::setup_cache(); + Self { + http_client, + issuer_url, + client_id, + client_secret, + cache, + scopes: oidc_config.claims.to_owned(), + public_url: config.general.public_url.to_string(), + } + } + + fn setup_cache() -> Cache { + Cache::builder() + .time_to_live(Duration::from_secs(120)) + .max_capacity(100) + .build() + } + + fn setup_http_client(referer: Option, proxy_settings: Option) -> reqwest::Client { + let mut headers = HeaderMap::new(); + // TODO: pull from config. + // Set referrer as some providers (authentik) block POST w/o referrer + if let Some(ref referer) = referer { + headers.insert("Referer", referer.parse().expect("bad referer")); + } + let mut builder = reqwest::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .default_headers(headers); + if let Some(proxy) = proxy_settings { + warn!("DANGER_DEV_PROXY set, requests are being proxied & ignoring certificates"); + builder = builder + .proxy(reqwest::Proxy::https(proxy.url).unwrap()) + .danger_accept_invalid_certs(proxy.disable_cert_check); + }; + builder.build().expect("Client should build") + } + pub fn http_client(&self) -> &reqwest::Client { + &self.http_client + } + + pub async fn create_client(&self) -> Result { + let provider_metadata = CoreProviderMetadata::discover_async( + /* TODO: pull from config */ + self.issuer_url.clone(), + &self.http_client, + ).await.map_err(|e| anyhow!(e.to_string()))?; + Ok(CoreClient::from_provider_metadata( + provider_metadata, + // TODO: pull from config + self.client_id.clone(), + self.client_secret.clone(), + )) + } + + pub async fn create_client_redirect(&self) -> Result { + let redirect_url = RedirectUrl::new( format!("{}/auth/sso/cb", self.public_url)) + .map_err(|e: ParseError | anyhow!(e))?; + let client = self.create_client().await?; + Ok(client.set_redirect_uri(redirect_url)) + } + + pub fn scopes(&self) -> Vec { + self.scopes.iter().map(|c| Scope::new(c.to_string())).collect() + } + + pub async fn cache_set(&mut self, ip: IpAddr, data: SSOSessionData) { + self.cache.insert(ip, data).await; + } + + pub async fn cache_take(&mut self, ip: IpAddr) -> Option { + self.cache.remove(&ip).await + } +} +// From https://github.com/IgnisDa/ryot/blob/75a1379f743b412df0e42fc88177d18cd34d48d7/crates/utils/application/src/lib.rs#L141C31-L148C2 +pub type OidcClient< + HasAuthUrl = EndpointSet, + HasDeviceAuthUrl = EndpointNotSet, + HasIntrospectionUrl = EndpointNotSet, + HasRevocationUrl = EndpointNotSet, + HasTokenUrl = EndpointMaybeSet, + HasUserInfoUrl = EndpointMaybeSet, +> = Client< + EmptyAdditionalClaims, + CoreAuthDisplay, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJsonWebKey, + CoreAuthPrompt, + StandardErrorResponse, + CoreTokenResponse, + CoreTokenIntrospectionResponse, + CoreRevocableToken, + CoreRevocationErrorResponse, + HasAuthUrl, + HasDeviceAuthUrl, + HasIntrospectionUrl, + HasRevocationUrl, + HasTokenUrl, + HasUserInfoUrl, +>; \ No newline at end of file diff --git a/src/routes/ui/auth.rs b/src/routes/ui/auth.rs index 5467faf..906d922 100644 --- a/src/routes/ui/auth.rs +++ b/src/routes/ui/auth.rs @@ -22,9 +22,17 @@ pub mod register; pub mod sso; +#[derive(Responder)] +#[response(status = 302)] +struct HackyRedirectBecauseRocketBug { + inner: String, + location: Header<'static>, +} + #[get("/logout")] pub async fn logout(session: Session<'_, SessionData>, user: AuthUser) -> Redirect { session.remove().await.unwrap(); Redirect::to(uri!(login::page(_, Some(true)))) } + diff --git a/src/routes/ui/auth/forgot_password.rs b/src/routes/ui/auth/forgot_password.rs index 81b3b3d..2811831 100644 --- a/src/routes/ui/auth/forgot_password.rs +++ b/src/routes/ui/auth/forgot_password.rs @@ -3,13 +3,13 @@ use rocket::form::{Context, Contextual, Form}; use rocket_dyn_templates::{context, Template}; use rocket_session_store::Session; use crate::{GlobalMetadata, SessionData}; +use crate::consts::APP_METADATA; use crate::util::set_csrf; #[get("/auth/forgot-password?")] pub async fn page( route: &Route, session: Session<'_, SessionData>, - meta: &State, return_to: Option, ) -> Template { // TODO: redirect if already logged in @@ -19,7 +19,7 @@ pub async fn page( csrf_token: csrf_token, form: &Context::default(), return_to, - meta: meta.inner() + meta: APP_METADATA.clone() }) } diff --git a/src/routes/ui/auth/login.rs b/src/routes/ui/auth/login.rs index dc07445..a7e1568 100644 --- a/src/routes/ui/auth/login.rs +++ b/src/routes/ui/auth/login.rs @@ -6,15 +6,15 @@ use rocket::http::{Header, Status}; use rocket_dyn_templates::{context, Template}; use rocket_session_store::Session; use crate::{GlobalMetadata, LoginSessionData, SessionData, DB}; -use crate::consts::DISABLE_LOGIN_CHECK; +use crate::consts::{APP_METADATA, DISABLE_LOGIN_CHECK}; use crate::models::user::validate_user_form; +use crate::routes::ui::auth::HackyRedirectBecauseRocketBug; use crate::util::{set_csrf, validate_csrf_form}; #[get("/auth/login?&")] pub async fn page( route: &Route, session: Session<'_, SessionData>, - meta: &State, return_to: Option, logged_out: Option ) -> Template { @@ -26,7 +26,7 @@ pub async fn page( form: &Context::default(), return_to, logged_out, - meta: meta.inner() + meta: APP_METADATA.clone() }) } @@ -43,20 +43,12 @@ struct LoginForm<'r> { } -#[derive(Responder)] -#[response(status = 302)] -struct HackyRedirectBecauseRocketBug { - inner: String, - location: Header<'static>, -} - #[post("/auth/login?", data = "
")] pub async fn handler( pool: &State, route: &Route, ip_addr: IpAddr, session: Session<'_, SessionData>, - meta: &State, mut form: Form>>, return_to: Option, ) -> Result { @@ -95,7 +87,7 @@ pub async fn handler( csrf_token: csrf_token, form: &form.context, return_to, - meta: meta.inner() + meta: APP_METADATA.clone() }; Err(Template::render("auth/login", &ctx)) } \ No newline at end of file diff --git a/src/routes/ui/auth/register.rs b/src/routes/ui/auth/register.rs index f870441..83875f0 100644 --- a/src/routes/ui/auth/register.rs +++ b/src/routes/ui/auth/register.rs @@ -2,15 +2,16 @@ use rocket::{get, post, Route, State}; use rocket_dyn_templates::{context, Template}; use rocket_session_store::Session; use crate::{GlobalMetadata, SessionData}; +use crate::consts::APP_METADATA; use crate::util::set_csrf; #[get("/auth/register")] -pub async fn page(route: &Route, session: Session<'_, SessionData>, meta: &State) -> Template { +pub async fn page(route: &Route, session: Session<'_, SessionData>) -> Template { let csrf_token = set_csrf(&session).await; Template::render("auth/register", context! { route: route.uri.path(), csrf_token: csrf_token, - meta: meta.inner() + meta: APP_METADATA.clone() }) } diff --git a/src/routes/ui/auth/sso.rs b/src/routes/ui/auth/sso.rs index 0012109..1f932f8 100644 --- a/src/routes/ui/auth/sso.rs +++ b/src/routes/ui/auth/sso.rs @@ -2,10 +2,10 @@ use std::env::var; use std::net::IpAddr; use std::sync::{LazyLock, OnceLock}; use std::time::Duration; -use anyhow::anyhow; -use log::warn; +use anyhow::{anyhow, Error}; +use log::{debug, warn}; use moka::future::Cache; -use rocket::{get, post, uri}; +use rocket::{get, post, uri, State}; use rocket::response::Redirect; use rocket_session_store::Session; use crate::guards::AuthUser; @@ -14,60 +14,16 @@ use openidconnect::{reqwest, AccessTokenHash, AsyncHttpClient, AuthenticationFlo use openidconnect::core::{CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreTokenResponse, CoreUserInfoClaims}; use openidconnect::http::HeaderValue; use reqwest::header::HeaderMap; -// TODO: not have this lazy somehow, move to OnceLock and have fn to refresh it? (own module?) -// and/or also move to State<> +use rocket::http::{Header, Status}; +use rocket_dyn_templates::{context, Template}; +use tokio::sync::MutexGuard; +use crate::managers::sso::{SSOSessionData, SSOState, SSO}; +use crate::routes::ui::auth::HackyRedirectBecauseRocketBug; -static HTTP_CLIENT: LazyLock = LazyLock::new(|| { - let mut headers = HeaderMap::new(); - // TODO: pull from config. - // Set referrer as some providers (authentik) block POST w/o referrer - headers.insert("Referer", HeaderValue::from_static("http://localhost:8080")); - let mut builder = reqwest::ClientBuilder::new() - // Following redirects opens the client up to SSRF vulnerabilities. - .redirect(reqwest::redirect::Policy::none()) - .default_headers(headers); - if var("DANGER_DEV_PROXY").is_ok() { - warn!("DANGER_DEV_PROXY set, requests are being proxied & ignoring certificates"); - builder = builder - .proxy(reqwest::Proxy::https("https://localhost:8082").unwrap()) - .danger_accept_invalid_certs(true) - }; - builder.build().expect("Client should build") -}); -#[derive(Clone)] -struct SSOSessionData { - pkce_challenge: String, - nonce: Nonce, - csrf_token: CsrfToken - // ip: IpAddr, -} -static SSO_SESSION_CACHE: LazyLock> = LazyLock::new(|| Cache::builder() - .time_to_live(Duration::from_secs(120)) - .max_capacity(100) - .build()); -#[get("/auth/sso")] -pub async fn page(ip: IpAddr) -> Redirect { - let http_client = HTTP_CLIENT.clone(); - // FIXME: temp, remove - let provider_metadata = CoreProviderMetadata::discover_async( - /* TODO: pull from config */ - IssuerUrl::new(var("SSO_ISSUER_URL").expect("dev: missing sso url")).expect("bad issuer url"), - &http_client, - ).await.map_err(|e| e.to_string()).expect("discovery failed"); - let client = - CoreClient::from_provider_metadata( - provider_metadata, - // TODO: pull from config - ClientId::new(var("SSO_CLIENT_ID").expect("dev: sso client id missing")), - Some(ClientSecret::new(var("SSO_CLIENT_SECRET").expect("dev sso client secret missing") - .to_string())), - ).set_redirect_uri(RedirectUrl::new("http://localhost:8080/auth/sso/cb".to_string()).unwrap()); - - // Generate a PKCE challenge. - // TODO: store in hashmap for request ip? leaky bucket? +async fn page_handler(sso: &State, ip: IpAddr, return_to: Option) -> Result { + let mut sso = sso.as_ref().ok_or_else(|| anyhow!("SSO is not configured"))?.lock().await; + let client = sso.create_client_redirect().await?; let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - - // Generate the full authorization URL. let (auth_url, csrf_token, nonce) = client .authorize_url( CoreAuthenticationFlow::AuthorizationCode, @@ -75,92 +31,83 @@ pub async fn page(ip: IpAddr) -> Redirect { Nonce::new_random, ) // Set the desired scopes. - // TODO: change scopes - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("name".to_string())) + .add_scopes(sso.scopes()) // Set the PKCE code challenge. .set_pkce_challenge(pkce_challenge) .url(); - SSO_SESSION_CACHE.insert(ip, SSOSessionData { + sso.cache_set(ip, SSOSessionData { nonce: nonce, pkce_challenge: pkce_verifier.into_secret(), - csrf_token + csrf_token, + return_to }).await; - - Redirect::to(auth_url.to_string()) - - // This is the URL you should redirect the user to, in order to trigger the authorization - // process. + Ok(Redirect::to(auth_url.to_string())) +} +#[get("/auth/sso?")] +pub async fn page(ip: IpAddr, sso: &State, return_to: Option) -> Result { + page_handler(sso, ip, return_to).await + .map_err(|e| (Status::InternalServerError, Template::render("errors/500", context! { + error: e.to_string() + }))) } -#[get("/auth/sso/cb?&")] -pub async fn callback(session: Session<'_, SessionData>, ip: IpAddr, code: String, state: String) -> Result { - let session_data = SSO_SESSION_CACHE.remove(&ip).await.ok_or_else(|| "no sso session started".to_string())?; - // Now you can exchange it for an access token and ID token. - if &state != session_data.csrf_token.secret() { - return Err(format!("csrf validation failed {}", state)); +async fn callback_handler(sso: &State, ip: IpAddr, code: String, state: String) -> Result<(CoreUserInfoClaims, Option), anyhow::Error> { + let mut sso = sso.as_ref().ok_or_else(||anyhow!("SSO is not configured"))?.lock().await; + let sess_data = sso.cache_take(ip).await.ok_or_else(|| anyhow!("No valid sso started"))?; + if &state != sess_data.csrf_token.secret() { + return Err(anyhow!("CSRF verification failed")); } - - // FIXME: temp, remove - let http_client = HTTP_CLIENT.clone(); - let provider_metadata = CoreProviderMetadata::discover_async( - /* TODO: pull from config */ - IssuerUrl::new(var("SSO_ISSUER_URL").expect("dev: missing sso url")).expect("bad issuer url"), - &http_client, - ).await.expect("discovery failed"); - let client = - CoreClient::from_provider_metadata( - provider_metadata, - // TODO: pull from config - ClientId::new(var("SSO_CLIENT_ID").expect("dev: sso client id missing")), - Some(ClientSecret::new(var("SSO_CLIENT_SECRET").expect("dev sso client secret missing") - .to_string())), - ).set_redirect_uri(RedirectUrl::new("http://localhost:8080/auth/sso/cb".to_string()).unwrap()); - + let client = sso.create_client_redirect().await?; let token_response = client - .exchange_code(AuthorizationCode::new(code)).expect("bad auth code") + .exchange_code(AuthorizationCode::new(code)).map_err(|e| anyhow!("oidc code is invalid"))? // Set the PKCE code verifier. - .set_pkce_verifier(PkceCodeVerifier::new(session_data.pkce_challenge)) // TODO: somehow have this?? - .request_async(&http_client).await.expect("token exchange error"); + .set_pkce_verifier(PkceCodeVerifier::new(sess_data.pkce_challenge)) // TODO: somehow have this?? + .request_async(sso.http_client()).await + .map_err(|e| anyhow!("OIDC Token exchange error"))?; // Extract the ID token claims after verifying its authenticity and nonce. let id_token = token_response .id_token() - .ok_or_else(|| "Server did not return an ID token".to_string())?; + .ok_or_else(|| anyhow!("Server did not return an ID token"))?; let id_token_verifier = client.id_token_verifier(); - let claims = id_token.claims(&id_token_verifier, &session_data.nonce).expect("bad claims"); // TODO: and this? + let claims = id_token.claims(&id_token_verifier, &sess_data.nonce).map_err(|e| anyhow!("OIDC Token claims error: {}", e))?; - // Verify the access token hash to ensure that the access token hasn't been substituted for - // another user's. + // Verify the access token hash to ensure that the access token hasn't been substituted for another user's. if let Some(expected_access_token_hash) = claims.access_token_hash() { let actual_access_token_hash = AccessTokenHash::from_token( token_response.access_token(), - id_token.signing_alg().expect("signing failed (alg)"), - id_token.signing_key(&id_token_verifier).expect("signing failed (key)"), + id_token.signing_alg().map_err(|e| anyhow!("OIDC token signature error: {}", e))?, + id_token.signing_key(&id_token_verifier).map_err(|e| anyhow!("OIDC token signature error: {}", e))? ).expect("access token resolve error"); if actual_access_token_hash != *expected_access_token_hash { - return Err("Invalid access token".to_string()); + return Err(anyhow!("Invalid access token")) } } - // The authenticated user's identity is now available. See the IdTokenClaims struct for a - // complete listing of the available claims. - println!( - "User {} with e-mail address {} has authenticated successfully", - claims.subject().as_str(), - claims.email().map(|email| email.as_str()).unwrap_or(""), - ); - // If available, we can use the user info endpoint to request additional information. // The user_info request uses the AccessToken returned in the token response. To parse custom // claims, use UserInfoClaims directly (with the desired type parameters) rather than using the // CoreUserInfoClaims type alias. let userinfo: CoreUserInfoClaims = client - .user_info(token_response.access_token().to_owned(), None).expect("user info missing") - .request_async(&http_client) + .user_info(token_response.access_token().to_owned(), None).map_err(|_| anyhow!("could not acquire user data"))? + .request_async(sso.http_client()) .await - .map_err(|err| format!("Failed requesting user info: {}", err))?; - Ok(format!("user={:?}\nemail={:?}\nname={:?}", userinfo.subject(), userinfo.email(), userinfo.name())) + .map_err(|_| anyhow!("could not acquire user data"))?; + Ok((userinfo, sess_data.return_to)) +} + +#[get("/auth/sso/cb?&")] +pub async fn callback(session: Session<'_, SessionData>, ip: IpAddr, sso: &State, code: String, state: String) -> Result { + let (userinfo, return_to) = callback_handler(sso, ip, code, state).await + .map_err(|e| (Status::InternalServerError, Template::render("errors/500", context! { + error: e.to_string() + })))?; + debug!("user={:?}\nemail={:?}\nname={:?}", userinfo.subject(), userinfo.email(), userinfo.name()); + let return_to = return_to.unwrap_or("/".to_string()); + Ok(HackyRedirectBecauseRocketBug { + inner: "Login successful, redirecting...".to_string(), + location: Header::new("Location", return_to), + }) } diff --git a/src/routes/ui/help.rs b/src/routes/ui/help.rs index 71a79fc..8674121 100644 --- a/src/routes/ui/help.rs +++ b/src/routes/ui/help.rs @@ -5,10 +5,11 @@ use rocket_session_store::{Session, SessionResult}; use serde::Serialize; use crate::models::user::UserModel; use crate::{GlobalMetadata, SessionData}; +use crate::consts::APP_METADATA; #[get("/help/about")] -pub fn about(route: &Route, meta: &State) -> Template { - Template::render("about", context! { route: route.uri.path(), meta: meta.inner() }) +pub fn about(route: &Route) -> Template { + Template::render("about", context! { route: route.uri.path(), meta: APP_METADATA.clone() }) } diff --git a/src/util.rs b/src/util.rs index 4dd0c6a..8910db3 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,5 +1,6 @@ use std::fs; use std::io::Cursor; +use std::time::Duration; use log::trace; use rand::rngs::OsRng; use rand::{rng, Rng, TryRngCore}; @@ -9,14 +10,18 @@ use rocket::{form, response, Request, Response}; use rocket::form::Context; use rocket::form::error::Entity; use rocket::fs::relative; +use rocket::http::private::cookie::CookieBuilder; use rocket::response::Responder; use rocket::serde::Serialize; use rocket_dyn_templates::handlebars::Handlebars; -use rocket_session_store::{Session, SessionError, SessionResult}; -use sqlx::Error; +use rocket_session_store::{Session, SessionError, SessionResult, SessionStore}; +use rocket_session_store::memory::MemoryStore; +use sqlx::{migrate, Error, Pool, Postgres}; +use sqlx::postgres::PgPoolOptions; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use uuid::Uuid; +use crate::consts::{SESSION_COOKIE_NAME, SESSION_LIFETIME_SECONDS}; use crate::models::user::{UserAuthError,}; use crate::SessionData; use crate::util::ResponseError::DatabaseError; @@ -31,6 +36,35 @@ pub(crate) fn setup_logger() { .init(); } +pub async fn setup_db() -> Pool { + let pool = PgPoolOptions::new() + .max_connections(5) + .connect(std::env::var("DATABASE_URL").unwrap().as_str()) + .await + .unwrap(); + + migrate!("./migrations") + .run(&pool) + .await.unwrap(); + pool +} + +pub fn setup_session_store() -> SessionStore { + let memory_store: MemoryStore:: = MemoryStore::default(); + SessionStore { + store: Box::new(memory_store), + name: SESSION_COOKIE_NAME.into(), + duration: Duration::from_secs(SESSION_LIFETIME_SECONDS), + // The cookie builder is used to set the cookie's path and other options. + // Name and value don't matter, they'll be overridden on each request. + cookie_builder: CookieBuilder::new("", "") + // Most web apps will want to use "/", but if your app is served from + // `example.com/myapp/` for example you may want to use "/myapp/" (note the trailing + // slash which prevents the cookie from being sent for `example.com/myapp2/`). + .path("/") + } +} + pub async fn set_csrf(session: &Session<'_, SessionData>) -> String { let token = gen_csrf_token(); trace!("set_csrf token={}", token); @@ -72,26 +106,6 @@ pub fn gen_csrf_token() -> String { .collect() } -// pub(crate) fn setup_template_engine() -> Handlebars<'static> { -// let mut hb = Handlebars::new(); -// #[cfg(debug_assertions)] -// hb.set_dev_mode(true); -// -// let templates = fs::read_dir(relative!("templates")).unwrap(); -// let mut ok = true; -// for file in templates { -// let file = file.unwrap(); -// if let Err(e) = hb.register_template_file(file.path().to_str().unwrap(), ) { -// error!(template, path = %path.display(), -// "failed to register Handlebars template: {e}"); -// -// ok = false; -// } -// } -// -// hb -// } - #[derive(Debug, Clone, Serialize)] pub struct JsonErrorResponse { pub(crate) code: String, diff --git a/templates/errors/500.html.hbs b/templates/errors/500.html.hbs new file mode 100644 index 0000000..1e84d5a --- /dev/null +++ b/templates/errors/500.html.hbs @@ -0,0 +1,27 @@ +{{#> layouts/default body-class="has-background-white-ter login-bg" }} +

+
+

storage-app

+
+

500 Internal Server Error

+

An internal error occurred while procesing your request

+

Error: {{ error }}

+ +
+ +

Return home

+
+
+{{/layouts/default}} + + \ No newline at end of file