Skip to content
This repository has been archived by the owner on May 18, 2021. It is now read-only.

fix: invalid request when assume_role_ttl >1h used from upstream source_profile #215

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cmd/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ func loadDurationFlagFromEnv(cmd *cobra.Command, flagName string, envVar string,
}

func updateDurationFromConfigProfile(profiles lib.Profiles, profile string, val *time.Duration) error {
fromProfile, _, err := profiles.GetValue(profile, "assume_role_ttl")
// When role chaining, AWS sets a hard 1h limit on the assume role TTL.
// So we require this value to be set on the profile directly.
// See: https://github.com/awsdocs/iam-user-guide/blob/8d78057/doc_source/id_roles_terms-and-concepts.md
fromProfile, _, err := profiles.GetValue(profile, "assume_role_ttl", false)
if err != nil {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func updateMfaConfig(cmd *cobra.Command, profiles lib.Profiles, profile string,
if ok {
config.Provider = mfaProvider
} else {
mfaProvider, _, err := profiles.GetValue(profile, "mfa_provider")
mfaProvider, _, err := profiles.GetValue(profile, "mfa_provider", true)
if err == nil {
config.Provider = mfaProvider
}
Expand All @@ -156,7 +156,7 @@ func updateMfaConfig(cmd *cobra.Command, profiles lib.Profiles, profile string,
if ok {
config.FactorType = mfaFactorType
} else {
mfaFactorType, _, err := profiles.GetValue(profile, "mfa_factor_type")
mfaFactorType, _, err := profiles.GetValue(profile, "mfa_factor_type", true)
if err == nil {
config.FactorType = mfaFactorType
}
Expand Down
6 changes: 5 additions & 1 deletion lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ func sourceProfile(p string, from Profiles) string {
return p
}

func (p Profiles) GetValue(profile string, config_key string) (string, string, error) {
func (p Profiles) GetValue(profile string, config_key string, recursive bool) (string, string, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started to make this change, but ended up creating a GetDirectValue function instead. Here's why:

I was calling the existing (before this PR) GetValue behavior recursive. Upon revisiting this, it occurred to me that it was a bit misleading because it doesn't descend past the immediate source_profile.

To address the underlying issue, I needed a function that only looked at settings set directly on the given profile, which is what GetDirectValue does. My most recent commit restores the original GetValue behavior, but I did make it use GetDirectValue under-the-hood to DRY things up a little. Happy to restore the old implementation if you'd prefer less indirection over DRY here.

config_value, ok := p[profile][config_key]
if ok {
return config_value, profile, nil
}

if !recursive {
return "", "", fmt.Errorf("Could not find %s in %s", config_key, profile)
}

// Lookup from the `source_profile`, if it exists
profile, ok = p[profile]["source_profile"]
if ok {
Expand Down
70 changes: 57 additions & 13 deletions lib/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@ import "testing"
func TestGetConfigValue(t *testing.T) {
config_profiles := make(Profiles)

t.Run("empty profile", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key")
t.Run("empty profile recursive search", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key", true)
if found_error == nil {
t.Error("Searching an empty profile set should return an error")
t.Error("Recursive search of an empty profile set should return an error")
}
})

t.Run("empty profile non-recursive search", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key", false)
if found_error == nil {
t.Error("Non-recursive search of an empty profile set should return an error")
}
})

Expand All @@ -34,17 +41,24 @@ func TestGetConfigValue(t *testing.T) {
"key_f": "f-c",
}

t.Run("missing key", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key")
t.Run("missing key recursive search", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key", true)
if found_error == nil {
t.Error("Recursive search for a missing key should return an error")
}
})

t.Run("missing key non-recursive search", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key", false)
if found_error == nil {
t.Error("Searching for a missing key should return an error")
t.Error("Non-recursive search for a missing key should return an error")
}
})

t.Run("fallback to okta", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_a", "key_a")
t.Run("fallback to okta on recursive search", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_a", "key_a", true)
if found_error != nil {
t.Error("Error when searching for key_a")
t.Error("Error when recursively searching for key_a")
}

if found_profile != "okta" {
Expand All @@ -56,8 +70,38 @@ func TestGetConfigValue(t *testing.T) {
}
})

t.Run("found in current profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_d")
t.Run("no fallback to okta on non-recursive search", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_a", "key_a", false)
if found_error == nil {
t.Error("Non-recursive search for key missing from top-level should return an error")
}

if found_profile != "" {
t.Error("key_a should not have been found in any profile")
}

if found_value != "" {
t.Error("No value should have been found for `key_a`")
}
})

t.Run("recursive search for item found in current profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_d", true)
if found_error != nil {
t.Error("Error when searching for key_d")
}

if found_profile != "profile_b" {
t.Error("key_d should have come from `profile_b`")
}

if found_value != "d-b" {
t.Error("The proper value for `key_d` should be `d-b`")
}
})

t.Run("non-recursive search for item found in current profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_d", false)
if found_error != nil {
t.Error("Error when searching for key_d")
}
Expand All @@ -72,7 +116,7 @@ func TestGetConfigValue(t *testing.T) {
})

t.Run("traversing from child profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_a")
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_a", true)
if found_error != nil {
t.Error("Error when searching for key_a")
}
Expand All @@ -87,7 +131,7 @@ func TestGetConfigValue(t *testing.T) {
})

t.Run("recursive traversing from child profile", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_c", "key_c")
_, _, found_error := config_profiles.GetValue("profile_c", "key_c", true)
if found_error == nil {
t.Error("Recursive searching should not work")
}
Expand Down
4 changes: 2 additions & 2 deletions lib/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (p *Provider) Retrieve() (credentials.Value, error) {
}

func (p *Provider) getSamlURL() (string, error) {
oktaAwsSAMLUrl, profile, err := p.profiles.GetValue(p.profile, "aws_saml_url")
oktaAwsSAMLUrl, profile, err := p.profiles.GetValue(p.profile, "aws_saml_url", true)
if err != nil {
return "", errors.New("aws_saml_url missing from ~/.aws/config")
}
Expand All @@ -192,7 +192,7 @@ func (p *Provider) getSamlURL() (string, error) {
}

func (p *Provider) getOktaSessionCookieKey() string {
oktaSessionCookieKey, profile, err := p.profiles.GetValue(p.profile, "okta_session_cookie_key")
oktaSessionCookieKey, profile, err := p.profiles.GetValue(p.profile, "okta_session_cookie_key", true)
if err != nil {
return "okta-session-cookie"
}
Expand Down