Skip to content

Commit

Permalink
aws/credentials: Remove unnecessary modification of Expiry CurrentTime (
Browse files Browse the repository at this point in the history
#2094)

Removes the unnecessary setting Expiry's CurrentTime member when
IsExpired method is called. This prevents the possibility of a data race
with the Expiry's IsExpired method.
  • Loading branch information
jasdel committed Aug 8, 2018
1 parent 598ae5f commit 8cf801c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
9 changes: 5 additions & 4 deletions aws/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,14 @@ func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {

// IsExpired returns if the credentials are expired.
func (e *Expiry) IsExpired() bool {
if e.CurrentTime == nil {
e.CurrentTime = time.Now
curTime := e.CurrentTime
if curTime == nil {
curTime = time.Now
}
return e.expiration.Before(e.CurrentTime())
return e.expiration.Before(curTime())
}

// A Credentials provides synchronous safe retrieval of AWS credentials Value.
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
Expand Down
26 changes: 26 additions & 0 deletions aws/credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package credentials

import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -62,6 +63,14 @@ func TestCredentialsExpire(t *testing.T) {
assert.True(t, c.IsExpired(), "Expected to be expired")
}

type MockProvider struct {
Expiry
}

func (*MockProvider) Retrieve() (Value, error) {
return Value{}, nil
}

func TestCredentialsGetWithProviderName(t *testing.T) {
stub := &stubProvider{}

Expand All @@ -71,3 +80,20 @@ func TestCredentialsGetWithProviderName(t *testing.T) {
assert.Nil(t, err, "Expected no error")
assert.Equal(t, creds.ProviderName, "stubProvider", "Expected provider name to match")
}

func TestCredentialsIsExpired_Race(t *testing.T) {
creds := NewChainCredentials([]Provider{&MockProvider{}})

starter := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
<-starter
for {
creds.IsExpired()
}
}()
}
close(starter)

time.Sleep(10 * time.Second)
}

0 comments on commit 8cf801c

Please sign in to comment.