mirror of
https://github.com/dani-garcia/vaultwarden
synced 2024-12-14 17:43:46 +01:00
Implemented better errors for JWT
This commit is contained in:
parent
250a2b340f
commit
2bb0b15e04
4 changed files with 37 additions and 41 deletions
|
@ -76,10 +76,8 @@ fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
|
||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
None => err!("No valid invite token")
|
None => err!("No valid invite token")
|
||||||
};
|
};
|
||||||
let claims: InviteJWTClaims = match decode_invite_jwt(&token) {
|
|
||||||
Ok(claims) => claims,
|
let claims: InviteJWTClaims = decode_invite_jwt(&token)?;
|
||||||
Err(msg) => err!("Invalid claim: {:#?}", msg),
|
|
||||||
};
|
|
||||||
if &claims.email == &data.Email {
|
if &claims.email == &data.Email {
|
||||||
user
|
user
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -522,10 +522,7 @@ fn accept_invite(_org_id: String, _org_user_id: String, data: JsonUpcase<AcceptD
|
||||||
// The web-vault passes org_id and org_user_id in the URL, but we are just reading them from the JWT instead
|
// The web-vault passes org_id and org_user_id in the URL, but we are just reading them from the JWT instead
|
||||||
let data: AcceptData = data.into_inner().data;
|
let data: AcceptData = data.into_inner().data;
|
||||||
let token = &data.Token;
|
let token = &data.Token;
|
||||||
let claims: InviteJWTClaims = match decode_invite_jwt(&token) {
|
let claims: InviteJWTClaims = decode_invite_jwt(&token)?;
|
||||||
Ok(claims) => claims,
|
|
||||||
Err(msg) => err!("Invalid claim: {:#?}", msg),
|
|
||||||
};
|
|
||||||
|
|
||||||
match User::find_by_mail(&claims.email, &conn) {
|
match User::find_by_mail(&claims.email, &conn) {
|
||||||
Some(_) => {
|
Some(_) => {
|
||||||
|
|
43
src/auth.rs
43
src/auth.rs
|
@ -7,6 +7,7 @@ use chrono::Duration;
|
||||||
use jsonwebtoken::{self, Algorithm, Header};
|
use jsonwebtoken::{self, Algorithm, Header};
|
||||||
use serde::ser::Serialize;
|
use serde::ser::Serialize;
|
||||||
|
|
||||||
|
use crate::error::{Error, MapResult};
|
||||||
use crate::CONFIG;
|
use crate::CONFIG;
|
||||||
|
|
||||||
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
|
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
|
||||||
|
@ -31,11 +32,11 @@ lazy_static! {
|
||||||
pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
|
pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
|
||||||
match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) {
|
match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) {
|
||||||
Ok(token) => token,
|
Ok(token) => token,
|
||||||
Err(e) => panic!("Error encoding jwt {}", e)
|
Err(e) => panic!("Error encoding jwt {}", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode_jwt(token: &str) -> Result<JWTClaims, String> {
|
pub fn decode_jwt(token: &str) -> Result<JWTClaims, Error> {
|
||||||
let validation = jsonwebtoken::Validation {
|
let validation = jsonwebtoken::Validation {
|
||||||
leeway: 30, // 30 seconds
|
leeway: 30, // 30 seconds
|
||||||
validate_exp: true,
|
validate_exp: true,
|
||||||
|
@ -47,16 +48,12 @@ pub fn decode_jwt(token: &str) -> Result<JWTClaims, String> {
|
||||||
algorithms: vec![JWT_ALGORITHM],
|
algorithms: vec![JWT_ALGORITHM],
|
||||||
};
|
};
|
||||||
|
|
||||||
match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) {
|
jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation)
|
||||||
Ok(decoded) => Ok(decoded.claims),
|
.map(|d| d.claims)
|
||||||
Err(msg) => {
|
.map_res("Error decoding login JWT")
|
||||||
error!("Error validating jwt - {:#?}", msg);
|
|
||||||
Err(msg.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode_invite_jwt(token: &str) -> Result<InviteJWTClaims, String> {
|
pub fn decode_invite_jwt(token: &str) -> Result<InviteJWTClaims, Error> {
|
||||||
let validation = jsonwebtoken::Validation {
|
let validation = jsonwebtoken::Validation {
|
||||||
leeway: 30, // 30 seconds
|
leeway: 30, // 30 seconds
|
||||||
validate_exp: true,
|
validate_exp: true,
|
||||||
|
@ -68,13 +65,9 @@ pub fn decode_invite_jwt(token: &str) -> Result<InviteJWTClaims, String> {
|
||||||
algorithms: vec![JWT_ALGORITHM],
|
algorithms: vec![JWT_ALGORITHM],
|
||||||
};
|
};
|
||||||
|
|
||||||
match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) {
|
jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation)
|
||||||
Ok(decoded) => Ok(decoded.claims),
|
.map(|d| d.claims)
|
||||||
Err(msg) => {
|
.map_res("Error decoding invite JWT")
|
||||||
error!("Error validating jwt - {:#?}", msg);
|
|
||||||
Err(msg.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
@ -150,7 +143,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
|
||||||
CONFIG.domain.clone()
|
CONFIG.domain.clone()
|
||||||
} else if let Some(referer) = headers.get_one("Referer") {
|
} else if let Some(referer) = headers.get_one("Referer") {
|
||||||
referer.to_string()
|
referer.to_string()
|
||||||
} else {
|
} else {
|
||||||
// Try to guess from the headers
|
// Try to guess from the headers
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
|
@ -185,7 +178,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
|
||||||
// Check JWT token is valid and get device and user from it
|
// Check JWT token is valid and get device and user from it
|
||||||
let claims: JWTClaims = match decode_jwt(access_token) {
|
let claims: JWTClaims = match decode_jwt(access_token) {
|
||||||
Ok(claims) => claims,
|
Ok(claims) => claims,
|
||||||
Err(_) => err_handler!("Invalid claim")
|
Err(_) => err_handler!("Invalid claim"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let device_uuid = claims.device;
|
let device_uuid = claims.device;
|
||||||
|
@ -193,17 +186,17 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
|
||||||
|
|
||||||
let conn = match request.guard::<DbConn>() {
|
let conn = match request.guard::<DbConn>() {
|
||||||
Outcome::Success(conn) => conn,
|
Outcome::Success(conn) => conn,
|
||||||
_ => err_handler!("Error getting DB")
|
_ => err_handler!("Error getting DB"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let device = match Device::find_by_uuid(&device_uuid, &conn) {
|
let device = match Device::find_by_uuid(&device_uuid, &conn) {
|
||||||
Some(device) => device,
|
Some(device) => device,
|
||||||
None => err_handler!("Invalid device id")
|
None => err_handler!("Invalid device id"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let user = match User::find_by_uuid(&user_uuid, &conn) {
|
let user = match User::find_by_uuid(&user_uuid, &conn) {
|
||||||
Some(user) => user,
|
Some(user) => user,
|
||||||
None => err_handler!("Device has no user associated")
|
None => err_handler!("Device has no user associated"),
|
||||||
};
|
};
|
||||||
|
|
||||||
if user.security_stamp != claims.sstamp {
|
if user.security_stamp != claims.sstamp {
|
||||||
|
@ -248,11 +241,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
|
||||||
None => err_handler!("The current user isn't member of the organization")
|
None => err_handler!("The current user isn't member of the organization")
|
||||||
};
|
};
|
||||||
|
|
||||||
Outcome::Success(Self{
|
Outcome::Success(Self {
|
||||||
host: headers.host,
|
host: headers.host,
|
||||||
device: headers.device,
|
device: headers.device,
|
||||||
user: headers.user,
|
user: headers.user,
|
||||||
org_user_type: {
|
org_user_type: {
|
||||||
if let Some(org_usr_type) = UserOrgType::from_i32(org_user.type_) {
|
if let Some(org_usr_type) = UserOrgType::from_i32(org_user.type_) {
|
||||||
org_usr_type
|
org_usr_type
|
||||||
} else { // This should only happen if the DB is corrupted
|
} else { // This should only happen if the DB is corrupted
|
||||||
|
@ -260,7 +253,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
},
|
}
|
||||||
_ => err_handler!("Error getting the organization id"),
|
_ => err_handler!("Error getting the organization id"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
24
src/error.rs
24
src/error.rs
|
@ -44,14 +44,15 @@ macro_rules! make_error {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
use diesel::result::{Error as DieselError, QueryResult};
|
use diesel::result::Error as DieselError;
|
||||||
use serde_json::{Value, Error as SerError};
|
use jsonwebtoken::errors::Error as JwtError;
|
||||||
|
use serde_json::{Error as SerError, Value};
|
||||||
use u2f::u2ferror::U2fError as U2fErr;
|
use u2f::u2ferror::U2fError as U2fErr;
|
||||||
|
|
||||||
// Error struct
|
// Error struct
|
||||||
// Each variant has two elements, the first is an error of different types, used for logging purposes
|
// Each variant has two elements, the first is an error of different types, used for logging purposes
|
||||||
// The second is a String, and it's contents are displayed to the user when the error occurs. Inside the macro, this is represented as _
|
// The second is a String, and it's contents are displayed to the user when the error occurs. Inside the macro, this is represented as _
|
||||||
//
|
//
|
||||||
// After the variant itself, there are two expressions. The first one is a bool to indicate whether the error cause will be printed to the log.
|
// After the variant itself, there are two expressions. The first one is a bool to indicate whether the error cause will be printed to the log.
|
||||||
// The second one contains the function used to obtain the response sent to the client
|
// The second one contains the function used to obtain the response sent to the client
|
||||||
make_error! {
|
make_error! {
|
||||||
|
@ -63,6 +64,7 @@ make_error! {
|
||||||
DbError(DieselError, _): true, _api_error,
|
DbError(DieselError, _): true, _api_error,
|
||||||
U2fError(U2fErr, _): true, _api_error,
|
U2fError(U2fErr, _): true, _api_error,
|
||||||
SerdeError(SerError, _): true, _api_error,
|
SerdeError(SerError, _): true, _api_error,
|
||||||
|
JWTError(JwtError, _): true, _api_error,
|
||||||
//WsError(ws::Error, _): true, _api_error,
|
//WsError(ws::Error, _): true, _api_error,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,19 +75,25 @@ impl Error {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait MapResult<S, E> {
|
pub trait MapResult<S, E> {
|
||||||
fn map_res(self, msg: &str) -> Result<(), E>;
|
fn map_res(self, msg: &str) -> Result<S, E>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MapResult<(), Error> for QueryResult<usize> {
|
impl<S, E: Into<Error>> MapResult<S, Error> for Result<S, E> {
|
||||||
|
fn map_res(self, msg: &str) -> Result<S, Error> {
|
||||||
|
self.map_err(Into::into).map_err(|e| e.with_msg(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: Into<Error>> MapResult<(), Error> for Result<usize, E> {
|
||||||
fn map_res(self, msg: &str) -> Result<(), Error> {
|
fn map_res(self, msg: &str) -> Result<(), Error> {
|
||||||
self.and(Ok(())).map_err(Error::from).map_err(|e| e.with_msg(msg))
|
self.and(Ok(())).map_res(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use std::any::Any;
|
use std::any::Any;
|
||||||
|
|
||||||
fn _serialize(e: &impl Serialize, _: &impl Any) -> String {
|
fn _serialize(e: &impl Serialize, _msg: &str) -> String {
|
||||||
serde_json::to_string(e).unwrap()
|
serde_json::to_string(e).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +110,7 @@ fn _api_error(_: &impl Any, msg: &str) -> String {
|
||||||
"Object": "error"
|
"Object": "error"
|
||||||
});
|
});
|
||||||
|
|
||||||
_serialize(&json, &false)
|
_serialize(&json, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
Loading…
Reference in a new issue