mirror of
https://github.com/Jackzmc/storage.git
synced 2025-05-07 14:53:20 +00:00
179 lines
No EOL
5.9 KiB
Rust
179 lines
No EOL
5.9 KiB
Rust
use std::fs;
|
|
use std::io::Cursor;
|
|
use log::trace;
|
|
use rand::rngs::OsRng;
|
|
use rand::{rng, Rng, TryRngCore};
|
|
use rand::distr::Alphanumeric;
|
|
use rocket::http::{ContentType, Status};
|
|
use rocket::{form, response, Request, Response};
|
|
use rocket::form::Context;
|
|
use rocket::form::error::Entity;
|
|
use rocket::fs::relative;
|
|
use rocket::response::Responder;
|
|
use rocket::serde::Serialize;
|
|
use rocket_dyn_templates::handlebars::Handlebars;
|
|
use rocket_session_store::{Session, SessionError, SessionResult};
|
|
use sqlx::Error;
|
|
use tracing_subscriber::layer::SubscriberExt;
|
|
use tracing_subscriber::util::SubscriberInitExt;
|
|
use uuid::Uuid;
|
|
use crate::models::user::{UserAuthError,};
|
|
use crate::SessionData;
|
|
use crate::util::ResponseError::DatabaseError;
|
|
|
|
pub(crate) fn setup_logger() {
|
|
tracing_subscriber::registry()
|
|
.with(
|
|
tracing_subscriber::filter::EnvFilter::try_from_default_env()
|
|
.unwrap_or_else(|_| format!("warn,rocket=warn,storage_server=trace").into()),
|
|
)
|
|
.with(tracing_subscriber::fmt::layer())
|
|
.init();
|
|
}
|
|
|
|
pub async fn set_csrf(session: &Session<'_, SessionData>) -> String {
|
|
let token = gen_csrf_token();
|
|
trace!("set_csrf token={}", token);
|
|
let mut sess = session.get().await.expect("failed to get session data")
|
|
.unwrap_or_else(|| SessionData {
|
|
csrf_token: None,
|
|
login: None,
|
|
});
|
|
sess.csrf_token = Some(token.clone());
|
|
session.set(sess).await.unwrap();
|
|
token
|
|
}
|
|
pub(crate) async fn validate_csrf_form(ctx: &mut Context<'_>, session: &Session<'_, SessionData>) -> bool {
|
|
if let Some(form_token) = ctx.field_value("_csrf") {
|
|
if let Some(mut sess) = session.get().await.unwrap() {
|
|
if let Some(sess_token) = sess.csrf_token {
|
|
let success = sess_token == form_token;
|
|
if success {
|
|
sess.csrf_token = None;
|
|
session.set(sess).await.unwrap();
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
ctx.push_error(form::Error::validation("_csrf token invalid").with_entity(Entity::Form));
|
|
return false
|
|
}
|
|
// CSRF failed, set error
|
|
ctx.push_error(rocket::form::Error::validation("CSRF Validation failed").with_entity(Entity::Form));
|
|
false
|
|
}
|
|
|
|
pub fn gen_csrf_token() -> String {
|
|
rng()
|
|
.sample_iter(&Alphanumeric)
|
|
.map(char::from) // map added here
|
|
.take(30)
|
|
.collect()
|
|
}
|
|
|
|
// pub(crate) fn setup_template_engine() -> Handlebars<'static> {
|
|
// let mut hb = Handlebars::new();
|
|
// #[cfg(debug_assertions)]
|
|
// hb.set_dev_mode(true);
|
|
//
|
|
// let templates = fs::read_dir(relative!("templates")).unwrap();
|
|
// let mut ok = true;
|
|
// for file in templates {
|
|
// let file = file.unwrap();
|
|
// if let Err(e) = hb.register_template_file(file.path().to_str().unwrap(), ) {
|
|
// error!(template, path = %path.display(),
|
|
// "failed to register Handlebars template: {e}");
|
|
//
|
|
// ok = false;
|
|
// }
|
|
// }
|
|
//
|
|
// hb
|
|
// }
|
|
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct JsonErrorResponse {
|
|
pub(crate) code: String,
|
|
pub(crate) message: String
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum ResponseError {
|
|
NotFound(JsonErrorResponse),
|
|
GenericError,
|
|
InternalServerError(JsonErrorResponse),
|
|
DatabaseError(JsonErrorResponse),
|
|
AuthError(UserAuthError),
|
|
CSRFError
|
|
}
|
|
|
|
impl ResponseError {
|
|
fn get_http_status(&self) -> Status {
|
|
match self {
|
|
ResponseError::InternalServerError(_) => Status::InternalServerError,
|
|
ResponseError::GenericError => Status::InternalServerError,
|
|
ResponseError::NotFound(_) => Status::NotFound,
|
|
ResponseError::DatabaseError(_) => Status::InternalServerError,
|
|
ResponseError::AuthError(e) => e.get_response_code(),
|
|
ResponseError::CSRFError => Status::Unauthorized,
|
|
_ => Status::BadRequest,
|
|
}
|
|
}
|
|
|
|
fn into_res_err(self) -> JsonErrorResponse {
|
|
match self {
|
|
ResponseError::NotFound(e) => e,
|
|
ResponseError::GenericError => {
|
|
JsonErrorResponse {
|
|
code: "INTERNAL_SERVER_ERROR".to_string(),
|
|
message: "An unknown error occurred".to_string(),
|
|
}
|
|
},
|
|
ResponseError::InternalServerError(e) => e,
|
|
DatabaseError(e) => e,
|
|
ResponseError::AuthError(e) => e.into_response_err(),
|
|
ResponseError::CSRFError => {
|
|
JsonErrorResponse {
|
|
code: "CSRF_VALIDATION_FAILED".to_string(),
|
|
message: "CSRF Token is invalid / expired or does not exist. Reload the form and try again".to_string(),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
impl From<sqlx::Error> for ResponseError {
|
|
fn from(value: Error) -> Self {
|
|
let err = value.into_database_error().unwrap();
|
|
DatabaseError(JsonErrorResponse {
|
|
code: err.code().map(|s| s.to_string()).unwrap_or_else(|| "UNKNOWN".to_string()),
|
|
message: err.message().to_string(),
|
|
})
|
|
}
|
|
}
|
|
|
|
impl From<UserAuthError> for ResponseError {
|
|
fn from(value: UserAuthError) -> Self {
|
|
ResponseError::AuthError(value)
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for ResponseError {
|
|
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
|
|
write!(fmt, "Error {}.", self.get_http_status())
|
|
}
|
|
}
|
|
|
|
impl<'r> Responder<'r, 'static> for ResponseError {
|
|
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
|
|
// serialize struct into json string
|
|
let status = self.get_http_status();
|
|
let err_response = serde_json::to_string(&self.into_res_err()).unwrap();
|
|
|
|
Response::build()
|
|
.status(status)
|
|
.header(ContentType::JSON)
|
|
.sized_body(err_response.len(), Cursor::new(err_response))
|
|
.ok()
|
|
}
|
|
} |