Skip to content

Commit

Permalink
feat(users): Decision manager flow changes for SSO (#4995)
Browse files Browse the repository at this point in the history
Co-authored-by: hyperswitch-bot[bot] <148525504+hyperswitch-bot[bot]@users.noreply.github.com>
  • Loading branch information
ThisIsMani and hyperswitch-bot[bot] committed Jun 24, 2024
1 parent 9600461 commit 8ceaaa9
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 17 deletions.
4 changes: 4 additions & 0 deletions crates/common_enums/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2758,6 +2758,10 @@ pub enum BankHolderType {
#[strum(serialize_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum TokenPurpose {
AuthSelect,
#[serde(rename = "sso")]
#[strum(serialize = "sso")]
SSO,
#[serde(rename = "totp")]
#[strum(serialize = "totp")]
TOTP,
Expand Down
7 changes: 3 additions & 4 deletions crates/router/src/core/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ pub async fn accept_invite_from_email_token_only_flow(
.map_err(|e| logger::error!(?e));

let current_flow = domain::CurrentFlow::new(
user_token.origin,
user_token,
domain::SPTFlow::AcceptInvitationFromEmail.into(),
)?;
let next_flow = current_flow.next(user_from_db.clone(), &state).await?;
Expand Down Expand Up @@ -1502,8 +1502,7 @@ pub async fn verify_email_token_only_flow(
.await
.map_err(|e| logger::error!(?e));

let current_flow =
domain::CurrentFlow::new(user_token.origin, domain::SPTFlow::VerifyEmail.into())?;
let current_flow = domain::CurrentFlow::new(user_token, domain::SPTFlow::VerifyEmail.into())?;
let next_flow = current_flow.next(user_from_db, &state).await?;
let token = next_flow.get_token(&state).await?;

Expand Down Expand Up @@ -1959,7 +1958,7 @@ pub async fn terminate_two_factor_auth(
}
}

let current_flow = domain::CurrentFlow::new(user_token.origin, domain::SPTFlow::TOTP.into())?;
let current_flow = domain::CurrentFlow::new(user_token, domain::SPTFlow::TOTP.into())?;
let next_flow = current_flow.next(user_from_db, &state).await?;
let token = next_flow.get_token(&state).await?;

Expand Down
2 changes: 1 addition & 1 deletion crates/router/src/core/user_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ pub async fn merchant_select_token_only_flow(
.into();

let current_flow =
domain::CurrentFlow::new(user_token.origin, domain::SPTFlow::MerchantSelect.into())?;
domain::CurrentFlow::new(user_token, domain::SPTFlow::MerchantSelect.into())?;
let next_flow = current_flow.next(user_from_db.clone(), &state).await?;

let token = next_flow
Expand Down
5 changes: 5 additions & 0 deletions crates/router/src/services/authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ impl AuthenticationType {
pub struct UserFromSinglePurposeToken {
pub user_id: String,
pub origin: domain::Origin,
pub path: Vec<TokenPurpose>,
}

#[cfg(feature = "olap")]
Expand All @@ -132,6 +133,7 @@ pub struct SinglePurposeToken {
pub user_id: String,
pub purpose: TokenPurpose,
pub origin: domain::Origin,
pub path: Vec<TokenPurpose>,
pub exp: u64,
}

Expand All @@ -142,6 +144,7 @@ impl SinglePurposeToken {
purpose: TokenPurpose,
origin: domain::Origin,
settings: &Settings,
path: Vec<TokenPurpose>,
) -> UserResult<String> {
let exp_duration =
std::time::Duration::from_secs(consts::SINGLE_PURPOSE_TOKEN_TIME_IN_SECS);
Expand All @@ -151,6 +154,7 @@ impl SinglePurposeToken {
purpose,
origin,
exp,
path,
};
jwt::generate_jwt(&token_payload, settings).await
}
Expand Down Expand Up @@ -356,6 +360,7 @@ where
UserFromSinglePurposeToken {
user_id: payload.user_id.clone(),
origin: payload.origin.clone(),
path: payload.path,
},
AuthenticationType::SinglePurposeJwt {
user_id: payload.user_id,
Expand Down
1 change: 1 addition & 0 deletions crates/router/src/types/domain/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ impl SignInWithMultipleRolesStrategy {
TokenPurpose::AcceptInvite,
Origin::SignIn,
&state.conf,
vec![],
)
.await?
.into(),
Expand Down
67 changes: 55 additions & 12 deletions crates/router/src/types/domain/user/decision_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@ pub enum UserFlow {
}

impl UserFlow {
async fn is_required(&self, user: &UserFromStorage, state: &SessionState) -> UserResult<bool> {
async fn is_required(
&self,
user: &UserFromStorage,
path: &[TokenPurpose],
state: &SessionState,
) -> UserResult<bool> {
match self {
Self::SPTFlow(flow) => flow.is_required(user, state).await,
Self::SPTFlow(flow) => flow.is_required(user, path, state).await,
Self::JWTFlow(flow) => flow.is_required(user, state).await,
}
}
}

#[derive(Eq, PartialEq, Clone, Copy)]
pub enum SPTFlow {
AuthSelect,
SSO,
TOTP,
VerifyEmail,
AcceptInvitationFromEmail,
Expand All @@ -36,15 +43,26 @@ pub enum SPTFlow {
}

impl SPTFlow {
async fn is_required(&self, user: &UserFromStorage, state: &SessionState) -> UserResult<bool> {
async fn is_required(
&self,
user: &UserFromStorage,
path: &[TokenPurpose],
state: &SessionState,
) -> UserResult<bool> {
match self {
// Auth
// AuthSelect and SSO flow are not enabled, once the terminate SSO API is ready, we can enable these flows
Self::AuthSelect => Ok(false),
Self::SSO => Ok(false),
// TOTP
Self::TOTP => Ok(true),
Self::TOTP => Ok(!path.contains(&TokenPurpose::SSO)),
// Main email APIs
Self::AcceptInvitationFromEmail | Self::ResetPassword => Ok(true),
Self::VerifyEmail => Ok(true),
// Final Checks
Self::ForceSetPassword => user.is_password_rotate_required(state),
Self::ForceSetPassword => user
.is_password_rotate_required(state)
.map(|rotate_required| rotate_required && !path.contains(&TokenPurpose::SSO)),
Self::MerchantSelect => user
.get_roles_from_db(state)
.await
Expand All @@ -62,6 +80,7 @@ impl SPTFlow {
self.into(),
next_flow.origin.clone(),
&state.conf,
next_flow.path.to_vec(),
)
.await
.map(|token| token.into())
Expand Down Expand Up @@ -103,6 +122,8 @@ impl JWTFlow {
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum Origin {
#[serde(rename = "sign_in_with_sso")]
SignInWithSSO,
SignIn,
SignUp,
MagicLink,
Expand All @@ -114,6 +135,7 @@ pub enum Origin {
impl Origin {
fn get_flows(&self) -> &'static [UserFlow] {
match self {
Self::SignInWithSSO => &SIGNIN_WITH_SSO_FLOW,
Self::SignIn => &SIGNIN_FLOW,
Self::SignUp => &SIGNUP_FLOW,
Self::VerifyEmail => &VERIFY_EMAIL_FLOW,
Expand All @@ -124,6 +146,11 @@ impl Origin {
}
}

const SIGNIN_WITH_SSO_FLOW: [UserFlow; 2] = [
UserFlow::SPTFlow(SPTFlow::MerchantSelect),
UserFlow::JWTFlow(JWTFlow::UserInfo),
];

const SIGNIN_FLOW: [UserFlow; 4] = [
UserFlow::SPTFlow(SPTFlow::TOTP),
UserFlow::SPTFlow(SPTFlow::ForceSetPassword),
Expand Down Expand Up @@ -154,7 +181,9 @@ const VERIFY_EMAIL_FLOW: [UserFlow; 5] = [
UserFlow::JWTFlow(JWTFlow::UserInfo),
];

const ACCEPT_INVITATION_FROM_EMAIL_FLOW: [UserFlow; 4] = [
const ACCEPT_INVITATION_FROM_EMAIL_FLOW: [UserFlow; 6] = [
UserFlow::SPTFlow(SPTFlow::AuthSelect),
UserFlow::SPTFlow(SPTFlow::SSO),
UserFlow::SPTFlow(SPTFlow::TOTP),
UserFlow::SPTFlow(SPTFlow::AcceptInvitationFromEmail),
UserFlow::SPTFlow(SPTFlow::ForceSetPassword),
Expand All @@ -169,31 +198,40 @@ const RESET_PASSWORD_FLOW: [UserFlow; 2] = [
pub struct CurrentFlow {
origin: Origin,
current_flow_index: usize,
path: Vec<TokenPurpose>,
}

impl CurrentFlow {
pub fn new(origin: Origin, current_flow: UserFlow) -> UserResult<Self> {
let flows = origin.get_flows();
pub fn new(
token: auth::UserFromSinglePurposeToken,
current_flow: UserFlow,
) -> UserResult<Self> {
let flows = token.origin.get_flows();
let index = flows
.iter()
.position(|flow| flow == &current_flow)
.ok_or(UserErrors::InternalServerError)?;
let mut path = token.path;
path.push(current_flow.into());

Ok(Self {
origin,
origin: token.origin,
current_flow_index: index,
path,
})
}

pub async fn next(&self, user: UserFromStorage, state: &SessionState) -> UserResult<NextFlow> {
pub async fn next(self, user: UserFromStorage, state: &SessionState) -> UserResult<NextFlow> {
let flows = self.origin.get_flows();
let remaining_flows = flows.iter().skip(self.current_flow_index + 1);

for flow in remaining_flows {
if flow.is_required(&user, state).await? {
if flow.is_required(&user, &self.path, state).await? {
return Ok(NextFlow {
origin: self.origin.clone(),
next_flow: *flow,
user,
path: self.path,
});
}
}
Expand All @@ -205,6 +243,7 @@ pub struct NextFlow {
origin: Origin,
next_flow: UserFlow,
user: UserFromStorage,
path: Vec<TokenPurpose>,
}

impl NextFlow {
Expand All @@ -214,12 +253,14 @@ impl NextFlow {
state: &SessionState,
) -> UserResult<Self> {
let flows = origin.get_flows();
let path = vec![];
for flow in flows {
if flow.is_required(&user, state).await? {
if flow.is_required(&user, &path, state).await? {
return Ok(Self {
origin,
next_flow: *flow,
user,
path,
});
}
}
Expand Down Expand Up @@ -284,6 +325,8 @@ impl From<UserFlow> for TokenPurpose {
impl From<SPTFlow> for TokenPurpose {
fn from(value: SPTFlow) -> Self {
match value {
SPTFlow::AuthSelect => Self::AuthSelect,
SPTFlow::SSO => Self::SSO,
SPTFlow::TOTP => Self::TOTP,
SPTFlow::VerifyEmail => Self::VerifyEmail,
SPTFlow::AcceptInvitationFromEmail => Self::AcceptInvitationFromEmail,
Expand Down

0 comments on commit 8ceaaa9

Please sign in to comment.