diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..397d85bc532 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -3,3 +3,4 @@ ### SDK Enhancements ### SDK Bugs +* `aws/crr`: Fixed a race condition that caused concurrent calls relying on endpoint discovery to share the same `url.URL` reference in their operation's `http.Request`. diff --git a/aws/crr/cache.go b/aws/crr/cache.go index a00ab6c67eb..c07f6731ea7 100644 --- a/aws/crr/cache.go +++ b/aws/crr/cache.go @@ -34,7 +34,10 @@ func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) { return Endpoint{}, false } - c.endpoints.Store(endpointKey, endpoint) + ev := endpoint.(Endpoint) + ev.Prune() + + c.endpoints.Store(endpointKey, ev) return endpoint.(Endpoint), true } diff --git a/aws/crr/cache_test.go b/aws/crr/cache_test.go index 7e1162c5b04..63c57e7c5a3 100644 --- a/aws/crr/cache_test.go +++ b/aws/crr/cache_test.go @@ -4,6 +4,7 @@ import ( "net/url" "reflect" "testing" + "time" ) func urlParse(uri string) *url.URL { @@ -450,3 +451,42 @@ func TestCacheGet(t *testing.T) { } } } + +func TestEndpointCache_Get_prune(t *testing.T) { + c := NewEndpointCache(2) + c.Add(Endpoint{ + Key: "foo", + Addresses: []WeightedAddress{ + { + URL: &url.URL{ + Host: "foo.amazonaws.com", + }, + Expired: time.Now().Add(5 * time.Minute), + }, + { + URL: &url.URL{ + Host: "bar.amazonaws.com", + }, + Expired: time.Now().Add(5 * -time.Minute), + }, + }, + }) + + load, _ := c.endpoints.Load("foo") + if ev := load.(Endpoint); len(ev.Addresses) != 2 { + t.Errorf("expected two weighted addresses") + } + + weightedAddress, err := c.Get(nil, "foo", false) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + if e, a := "foo.amazonaws.com", weightedAddress.URL.Host; e != a { + t.Errorf("expect %v, got %v", e, a) + } + + load, _ = c.endpoints.Load("foo") + if ev := load.(Endpoint); len(ev.Addresses) != 1 { + t.Errorf("expected one weighted address") + } +} diff --git a/aws/crr/endpoint.go b/aws/crr/endpoint.go index d5599188e06..2b088bdbc74 100644 --- a/aws/crr/endpoint.go +++ b/aws/crr/endpoint.go @@ -60,12 +60,32 @@ func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) { continue } + we.URL = cloneURL(we.URL) + return we, true } return WeightedAddress{}, false } +// Prune will prune the expired addresses from the endpoint by allocating a new []WeightAddress. +// This is not concurrent safe, and should be called from a single owning thread. +func (e *Endpoint) Prune() bool { + validLen := e.Len() + if validLen == len(e.Addresses) { + return false + } + wa := make([]WeightedAddress, 0, validLen) + for i := range e.Addresses { + if e.Addresses[i].HasExpired() { + continue + } + wa = append(wa, e.Addresses[i]) + } + e.Addresses = wa + return true +} + // Discoverer is an interface used to discovery which endpoint hit. This // allows for specifics about what parameters need to be used to be contained // in the Discoverer implementor. @@ -97,3 +117,16 @@ func BuildEndpointKey(params map[string]*string) string { return strings.Join(values, ".") } + +func cloneURL(u *url.URL) (clone *url.URL) { + clone = &url.URL{} + + *clone = *u + + if u.User != nil { + user := *u.User + clone.User = &user + } + + return clone +} diff --git a/aws/crr/endpoint_test.go b/aws/crr/endpoint_test.go new file mode 100644 index 00000000000..5e9acbb5892 --- /dev/null +++ b/aws/crr/endpoint_test.go @@ -0,0 +1,126 @@ +//go:build go1.16 +// +build go1.16 + +package crr + +import ( + "net/url" + "reflect" + "strconv" + "testing" + "time" +) + +func Test_cloneURL(t *testing.T) { + tests := []struct { + value *url.URL + wantClone *url.URL + }{ + { + value: &url.URL{ + Scheme: "https", + Opaque: "foo", + User: nil, + Host: "amazonaws.com", + Path: "/", + RawPath: "/", + ForceQuery: true, + RawQuery: "thing=value", + Fragment: "1234", + RawFragment: "1234", + }, + wantClone: &url.URL{ + Scheme: "https", + Opaque: "foo", + User: nil, + Host: "amazonaws.com", + Path: "/", + RawPath: "/", + ForceQuery: true, + RawQuery: "thing=value", + Fragment: "1234", + RawFragment: "1234", + }, + }, + { + value: &url.URL{ + Scheme: "https", + Opaque: "foo", + User: url.UserPassword("NOT", "VALID"), + Host: "amazonaws.com", + Path: "/", + RawPath: "/", + ForceQuery: true, + RawQuery: "thing=value", + Fragment: "1234", + RawFragment: "1234", + }, + wantClone: &url.URL{ + Scheme: "https", + Opaque: "foo", + User: url.UserPassword("NOT", "VALID"), + Host: "amazonaws.com", + Path: "/", + RawPath: "/", + ForceQuery: true, + RawQuery: "thing=value", + Fragment: "1234", + RawFragment: "1234", + }, + }, + } + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + gotClone := cloneURL(tt.value) + if gotClone == tt.value { + t.Errorf("expct clone URL to not be same pointer address") + } + if tt.value.User != nil { + if tt.value.User == gotClone.User { + t.Errorf("expct cloned Userinfo to not be same pointer address") + } + } + if !reflect.DeepEqual(gotClone, tt.wantClone) { + t.Errorf("cloneURL() = %v, want %v", gotClone, tt.wantClone) + } + }) + } +} + +func TestEndpoint_Prune(t *testing.T) { + endpoint := Endpoint{} + + endpoint.Add(WeightedAddress{ + URL: &url.URL{}, + Expired: time.Now().Add(5 * time.Minute), + }) + + initial := endpoint.Addresses + + if e, a := false, endpoint.Prune(); e != a { + t.Errorf("expect prune %v, got %v", e, a) + } + + if e, a := &initial[0], &endpoint.Addresses[0]; e != a { + t.Errorf("expect slice address to be same") + } + + endpoint.Add(WeightedAddress{ + URL: &url.URL{}, + Expired: time.Now().Add(5 * -time.Minute), + }) + + initial = endpoint.Addresses + + if e, a := true, endpoint.Prune(); e != a { + t.Errorf("expect prune %v, got %v", e, a) + } + + if e, a := &initial[0], &endpoint.Addresses[0]; e == a { + t.Errorf("expect slice address to be different") + } + + if e, a := 1, endpoint.Len(); e != a { + t.Errorf("expect slice length %v, got %v", e, a) + } +}