diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..6e17bd307 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,2 @@ +# Add core contributors to all PRs by default +* @aws/aws-sdk-go-team diff --git a/.github/workflows/api_diff_check.yml b/.github/workflows/api_diff_check.yml index ccda47c39..1d57a52e8 100644 --- a/.github/workflows/api_diff_check.yml +++ b/.github/workflows/api_diff_check.yml @@ -24,7 +24,7 @@ jobs: - name: Get dependencies run: | - (cd /tmp && go get golang.org/x/exp/cmd/gorelease@latest) + (cd /tmp && go install golang.org/x/exp/cmd/gorelease@latest) - name: Check APIs run: $(go env GOPATH)/bin/gorelease diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 08a0c3abf..1660bc15c 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,12 +8,12 @@ on: jobs: unit-tests: - name: Test SDK + name: SDK Unit Tests runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go-version: [1.17, 1.16, 1.15] + go-version: [1.18, 1.17] steps: - uses: actions/checkout@v2 @@ -25,3 +25,21 @@ jobs: - name: Test run: go test -v ./... + deprecated-unit-tests: + needs: unit-tests + name: Deprecated Go version SDK Unit Tests + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + go-version: [1.16, 1.15] + steps: + - uses: actions/checkout@v2 + + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + + - name: Test + run: go test -v ./... diff --git a/CHANGELOG.md b/CHANGELOG.md index 733f31753..a608e2b63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,49 @@ +# Release (2022-09-14) + +* No change notes available for this release. + +# Release (v1.13.2) + +* No change notes available for this release. + +# Release (v1.13.1) + +* No change notes available for this release. + +# Release (v1.13.0) + +## Module Highlights +* `github.com/aws/smithy-go`: v1.13.0 + * **Feature**: Adds support for the Smithy httpBearerAuth authentication trait to smithy-go. This allows the SDK to support the bearer authentication flow for API operations decorated with httpBearerAuth. An API client will need to be provided with its own bearer.TokenProvider implementation or use the bearer.StaticTokenProvider implementation. + +# Release (v1.12.1) + +## Module Highlights +* `github.com/aws/smithy-go`: v1.12.1 + * **Bug Fix**: Fixes a bug where JSON object keys were not escaped. + +# Release (v1.12.0) + +## Module Highlights +* `github.com/aws/smithy-go`: v1.12.0 + * **Feature**: `transport/http`: Add utility for setting context metadata when operation serializer automatically assigns content-type default value. + +# Release (v1.11.3) + +## Module Highlights +* `github.com/aws/smithy-go`: v1.11.3 + * **Dependency Update**: Updates smithy-go unit test dependency go-cmp to 0.5.8. + +# Release (v1.11.2) + +* No change notes available for this release. + +# Release (v1.11.1) + +## Module Highlights +* `github.com/aws/smithy-go`: v1.11.1 + * **Bug Fix**: Updates the smithy-go HTTP Request to correctly handle building the request to an http.Request. Related to [aws/aws-sdk-go-v2#1583](https://github.com/aws/aws-sdk-go-v2/issues/1583) + # Release (v1.11.0) ## Module Highlights diff --git a/auth/bearer/docs.go b/auth/bearer/docs.go new file mode 100644 index 000000000..1c9b9715c --- /dev/null +++ b/auth/bearer/docs.go @@ -0,0 +1,3 @@ +// Package bearer provides middleware and utilities for authenticating API +// operation calls with a Bearer Token. +package bearer diff --git a/auth/bearer/middleware.go b/auth/bearer/middleware.go new file mode 100644 index 000000000..8c7d72099 --- /dev/null +++ b/auth/bearer/middleware.go @@ -0,0 +1,104 @@ +package bearer + +import ( + "context" + "fmt" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// Message is the middleware stack's request transport message value. +type Message interface{} + +// Signer provides an interface for implementations to decorate a request +// message with a bearer token. The signer is responsible for validating the +// message type is compatible with the signer. +type Signer interface { + SignWithBearerToken(context.Context, Token, Message) (Message, error) +} + +// AuthenticationMiddleware provides the Finalize middleware step for signing +// an request message with a bearer token. +type AuthenticationMiddleware struct { + signer Signer + tokenProvider TokenProvider +} + +// AddAuthenticationMiddleware helper adds the AuthenticationMiddleware to the +// middleware Stack in the Finalize step with the options provided. +func AddAuthenticationMiddleware(s *middleware.Stack, signer Signer, tokenProvider TokenProvider) error { + return s.Finalize.Add( + NewAuthenticationMiddleware(signer, tokenProvider), + middleware.After, + ) +} + +// NewAuthenticationMiddleware returns an initialized AuthenticationMiddleware. +func NewAuthenticationMiddleware(signer Signer, tokenProvider TokenProvider) *AuthenticationMiddleware { + return &AuthenticationMiddleware{ + signer: signer, + tokenProvider: tokenProvider, + } +} + +const authenticationMiddlewareID = "BearerTokenAuthentication" + +// ID returns the resolver identifier +func (m *AuthenticationMiddleware) ID() string { + return authenticationMiddlewareID +} + +// HandleFinalize implements the FinalizeMiddleware interface in order to +// update the request with bearer token authentication. +func (m *AuthenticationMiddleware) HandleFinalize( + ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler, +) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + token, err := m.tokenProvider.RetrieveBearerToken(ctx) + if err != nil { + return out, metadata, fmt.Errorf("failed AuthenticationMiddleware wrap message, %w", err) + } + + signedMessage, err := m.signer.SignWithBearerToken(ctx, token, in.Request) + if err != nil { + return out, metadata, fmt.Errorf("failed AuthenticationMiddleware sign message, %w", err) + } + + in.Request = signedMessage + return next.HandleFinalize(ctx, in) +} + +// SignHTTPSMessage provides a bearer token authentication implementation that +// will sign the message with the provided bearer token. +// +// Will fail if the message is not a smithy-go HTTP request or the request is +// not HTTPS. +type SignHTTPSMessage struct{} + +// NewSignHTTPSMessage returns an initialized signer for HTTP messages. +func NewSignHTTPSMessage() *SignHTTPSMessage { + return &SignHTTPSMessage{} +} + +// SignWithBearerToken returns a copy of the HTTP request with the bearer token +// added via the "Authorization" header, per RFC 6750, https://datatracker.ietf.org/doc/html/rfc6750. +// +// Returns an error if the request's URL scheme is not HTTPS, or the request +// message is not an smithy-go HTTP Request pointer type. +func (SignHTTPSMessage) SignWithBearerToken(ctx context.Context, token Token, message Message) (Message, error) { + req, ok := message.(*smithyhttp.Request) + if !ok { + return nil, fmt.Errorf("expect smithy-go HTTP Request, got %T", message) + } + + if !req.IsHTTPS() { + return nil, fmt.Errorf("bearer token with HTTP request requires HTTPS") + } + + reqClone := req.Clone() + reqClone.Header.Set("Authorization", "Bearer "+token.Value) + + return reqClone, nil +} diff --git a/auth/bearer/middleware_test.go b/auth/bearer/middleware_test.go new file mode 100644 index 000000000..e9604f089 --- /dev/null +++ b/auth/bearer/middleware_test.go @@ -0,0 +1,78 @@ +package bearer + +import ( + "context" + "net/http" + "net/url" + "strings" + "testing" + + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestSignHTTPSMessage(t *testing.T) { + cases := map[string]struct { + message Message + token Token + expectMessage Message + expectErr string + }{ + // Cases + "not smithyhttp.Request": { + message: struct{}{}, + expectErr: "expect smithy-go HTTP Request", + }, + "not https": { + message: func() Message { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + return r + }(), + expectErr: "requires HTTPS", + }, + "success": { + message: func() Message { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + return r + }(), + token: Token{Value: "abc123"}, + expectMessage: func() Message { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.Header.Set("Authorization", "Bearer abc123") + return r + }(), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + signer := SignHTTPSMessage{} + message, err := signer.SignWithBearerToken(ctx, c.token, c.message) + if c.expectErr != "" { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %v in error %v", e, a) + } + return + } else if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + options := []cmp.Option{ + cmpopts.IgnoreUnexported(smithyhttp.Request{}), + cmpopts.IgnoreUnexported(http.Request{}), + } + + if diff := cmp.Diff(c.expectMessage, message, options...); diff != "" { + t.Errorf("expect match\n%s", diff) + } + }) + } +} diff --git a/auth/bearer/token.go b/auth/bearer/token.go new file mode 100644 index 000000000..be260d4c7 --- /dev/null +++ b/auth/bearer/token.go @@ -0,0 +1,50 @@ +package bearer + +import ( + "context" + "time" +) + +// Token provides a type wrapping a bearer token and expiration metadata. +type Token struct { + Value string + + CanExpire bool + Expires time.Time +} + +// Expired returns if the token's Expires time is before or equal to the time +// provided. If CanExpires is false, Expired will always return false. +func (t Token) Expired(now time.Time) bool { + if !t.CanExpire { + return false + } + now = now.Round(0) + return now.Equal(t.Expires) || now.After(t.Expires) +} + +// TokenProvider provides interface for retrieving bearer tokens. +type TokenProvider interface { + RetrieveBearerToken(context.Context) (Token, error) +} + +// TokenProviderFunc provides a helper utility to wrap a function as a type +// that implements the TokenProvider interface. +type TokenProviderFunc func(context.Context) (Token, error) + +// RetrieveBearerToken calls the wrapped function, returning the Token or +// error. +func (fn TokenProviderFunc) RetrieveBearerToken(ctx context.Context) (Token, error) { + return fn(ctx) +} + +// StaticTokenProvider provides a utility for wrapping a static bearer token +// value within an implementation of a token provider. +type StaticTokenProvider struct { + Token Token +} + +// RetrieveBearerToken returns the static token specified. +func (s StaticTokenProvider) RetrieveBearerToken(context.Context) (Token, error) { + return s.Token, nil +} diff --git a/auth/bearer/token_cache.go b/auth/bearer/token_cache.go new file mode 100644 index 000000000..223ddf52b --- /dev/null +++ b/auth/bearer/token_cache.go @@ -0,0 +1,208 @@ +package bearer + +import ( + "context" + "fmt" + "sync/atomic" + "time" + + smithycontext "github.com/aws/smithy-go/context" + "github.com/aws/smithy-go/internal/sync/singleflight" +) + +// package variable that can be override in unit tests. +var timeNow = time.Now + +// TokenCacheOptions provides a set of optional configuration options for the +// TokenCache TokenProvider. +type TokenCacheOptions struct { + // The duration before the token will expire when the credentials will be + // refreshed. If DisableAsyncRefresh is true, the RetrieveBearerToken calls + // will be blocking. + // + // Asynchronous refreshes are deduplicated, and only one will be in-flight + // at a time. If the token expires while an asynchronous refresh is in + // flight, the next call to RetrieveBearerToken will block on that refresh + // to return. + RefreshBeforeExpires time.Duration + + // The timeout the underlying TokenProvider's RetrieveBearerToken call must + // return within, or will be canceled. Defaults to 0, no timeout. + // + // If 0 timeout, its possible for the underlying tokenProvider's + // RetrieveBearerToken call to block forever. Preventing subsequent + // TokenCache attempts to refresh the token. + // + // If this timeout is reached all pending deduplicated calls to + // TokenCache RetrieveBearerToken will fail with an error. + RetrieveBearerTokenTimeout time.Duration + + // The minimum duration between asynchronous refresh attempts. If the next + // asynchronous recent refresh attempt was within the minimum delay + // duration, the call to retrieve will return the current cached token, if + // not expired. + // + // The asynchronous retrieve is deduplicated across multiple calls when + // RetrieveBearerToken is called. The asynchronous retrieve is not a + // periodic task. It is only performed when the token has not yet expired, + // and the current item is within the RefreshBeforeExpires window, and the + // TokenCache's RetrieveBearerToken method is called. + // + // If 0, (default) there will be no minimum delay between asynchronous + // refresh attempts. + // + // If DisableAsyncRefresh is true, this option is ignored. + AsyncRefreshMinimumDelay time.Duration + + // Sets if the TokenCache will attempt to refresh the token in the + // background asynchronously instead of blocking for credentials to be + // refreshed. If disabled token refresh will be blocking. + // + // The first call to RetrieveBearerToken will always be blocking, because + // there is no cached token. + DisableAsyncRefresh bool +} + +// TokenCache provides an utility to cache Bearer Authentication tokens from a +// wrapped TokenProvider. The TokenCache can be has options to configure the +// cache's early and asynchronous refresh of the token. +type TokenCache struct { + options TokenCacheOptions + provider TokenProvider + + cachedToken atomic.Value + lastRefreshAttemptTime atomic.Value + sfGroup singleflight.Group +} + +// NewTokenCache returns a initialized TokenCache that implements the +// TokenProvider interface. Wrapping the provider passed in. Also taking a set +// of optional functional option parameters to configure the token cache. +func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache { + var options TokenCacheOptions + for _, fn := range optFns { + fn(&options) + } + + return &TokenCache{ + options: options, + provider: provider, + } +} + +// RetrieveBearerToken returns the token if it could be obtained, or error if a +// valid token could not be retrieved. +// +// The passed in Context's cancel/deadline/timeout will impacting only this +// individual retrieve call and not any other already queued up calls. This +// means underlying provider's RetrieveBearerToken calls could block for ever, +// and not be canceled with the Context. Set RetrieveBearerTokenTimeout to +// provide a timeout, preventing the underlying TokenProvider blocking forever. +// +// By default, if the passed in Context is canceled, all of its values will be +// considered expired. The wrapped TokenProvider will not be able to lookup the +// values from the Context once it is expired. This is done to protect against +// expired values no longer being valid. To disable this behavior, use +// smithy-go's context.WithPreserveExpiredValues to add a value to the Context +// before calling RetrieveBearerToken to enable support for expired values. +// +// Without RetrieveBearerTokenTimeout there is the potential for a underlying +// Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent +// attempts at refreshing the token. +func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) { + cachedToken, ok := p.getCachedToken() + if !ok || cachedToken.Expired(timeNow()) { + return p.refreshBearerToken(ctx) + } + + // Check if the token should be refreshed before it expires. + refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires)) + if !refreshToken { + return cachedToken, nil + } + + if p.options.DisableAsyncRefresh { + return p.refreshBearerToken(ctx) + } + + p.tryAsyncRefresh(ctx) + + return cachedToken, nil +} + +// tryAsyncRefresh attempts to asynchronously refresh the token returning the +// already cached token. If it AsyncRefreshMinimumDelay option is not zero, and +// the duration since the last refresh is less than that value, nothing will be +// done. +func (p *TokenCache) tryAsyncRefresh(ctx context.Context) { + if p.options.AsyncRefreshMinimumDelay != 0 { + var lastRefreshAttempt time.Time + if v := p.lastRefreshAttemptTime.Load(); v != nil { + lastRefreshAttempt = v.(time.Time) + } + + if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) { + return + } + } + + // Ignore the returned channel so this won't be blocking, and limit the + // number of additional goroutines created. + p.sfGroup.DoChan("async-refresh", func() (interface{}, error) { + res, err := p.refreshBearerToken(ctx) + if p.options.AsyncRefreshMinimumDelay != 0 { + var refreshAttempt time.Time + if err != nil { + refreshAttempt = timeNow() + } + p.lastRefreshAttemptTime.Store(refreshAttempt) + } + + return res, err + }) +} + +func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) { + resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) { + ctx := smithycontext.WithSuppressCancel(ctx) + if v := p.options.RetrieveBearerTokenTimeout; v != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, v) + defer cancel() + } + return p.singleRetrieve(ctx) + }) + + select { + case res := <-resCh: + return res.Val.(Token), res.Err + case <-ctx.Done(): + return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err()) + } +} + +func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) { + token, err := p.provider.RetrieveBearerToken(ctx) + if err != nil { + return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err) + } + + p.cachedToken.Store(&token) + return token, nil +} + +// getCachedToken returns the currently cached token and true if found. Returns +// false if no token is cached. +func (p *TokenCache) getCachedToken() (Token, bool) { + v := p.cachedToken.Load() + if v == nil { + return Token{}, false + } + + t := v.(*Token) + if t == nil || t.Value == "" { + return Token{}, false + } + + return *t, true +} diff --git a/auth/bearer/token_cache_test.go b/auth/bearer/token_cache_test.go new file mode 100644 index 000000000..3d56f7ee6 --- /dev/null +++ b/auth/bearer/token_cache_test.go @@ -0,0 +1,512 @@ +package bearer + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +var _ TokenProvider = (*TokenCache)(nil) + +func TestTokenCache_cache(t *testing.T) { + expectToken := Token{ + Value: "abc123", + } + + var retrieveCalled bool + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + if retrieveCalled { + t.Fatalf("expect wrapped provider to be called once") + } + retrieveCalled = true + return expectToken, nil + })) + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + for i := 0; i < 100; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + } +} + +func TestTokenCache_cacheConcurrent(t *testing.T) { + expectToken := Token{ + Value: "abc123", + } + + var retrieveCalled bool + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + if retrieveCalled { + t.Fatalf("expect wrapped provider to be called once") + } + retrieveCalled = true + return expectToken, nil + })) + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + for i := 0; i < 100; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + }) + } +} + +func TestTokenCache_expired(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + if atomic.AddInt32(retrievedCount, 1) > 1 { + return refreshedToken, nil + } + return expectToken, nil + })) + + for i := 0; i < 10; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + } + if e, a := 1, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Errorf("expect %v provider calls, got %v", e, a) + } + + // Offset time for refresh + timeNow = func() time.Time { + return (time.Time{}).Add(10 * time.Minute) + } + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } + if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Errorf("expect %v provider calls, got %v", e, a) + } +} + +func TestTokenCache_cancelled(t *testing.T) { + providerRunning := make(chan struct{}) + providerDone := make(chan struct{}) + var onceClose sync.Once + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + onceClose.Do(func() { close(providerRunning) }) + + // Provider running never receives context cancel so that if the first + // retrieve call is canceled all subsequent retrieve callers won't get + // canceled as well. + select { + case <-providerDone: + return Token{Value: "abc123"}, nil + case <-ctx.Done(): + return Token{}, fmt.Errorf("unexpected context canceled, %w", ctx.Err()) + } + })) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Retrieve that will have its context canceled, should return error, but + // underlying provider retrieve will continue to block in the background. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + _, err := provider.RetrieveBearerToken(ctx) + if err == nil { + t.Errorf("expect error, got none") + + } else if e, a := "unexpected context canceled", err.Error(); strings.Contains(a, e) { + t.Errorf("unexpected context canceled received, %v", err) + + } else if e, a := "context canceled", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %v error in, %v", e, a) + } + }() + + <-providerRunning + + // Retrieve that will be added to existing single flight group, (or create + // a new group). Returning valid token. + wg.Add(1) + go func() { + defer wg.Done() + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Errorf("expect no error, got %v", err) + } else { + if diff := cmp.Diff(Token{Value: "abc123"}, token); diff != "" { + t.Errorf("expect token retrieve match\n%s", diff) + } + } + }() + close(providerDone) + + wg.Wait() +} + +func TestTokenCache_cancelledWithTimeout(t *testing.T) { + providerReady := make(chan struct{}) + var providerReadCloseOnce sync.Once + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + providerReadCloseOnce.Do(func() { close(providerReady) }) + + <-ctx.Done() + return Token{}, fmt.Errorf("token retrieve timeout, %w", ctx.Err()) + }), func(o *TokenCacheOptions) { + o.RetrieveBearerTokenTimeout = time.Millisecond + }) + + var wg sync.WaitGroup + + // Spin up additional retrieves that will be deduplicated and block on the + // original retrieve call. + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-providerReady + + _, err := provider.RetrieveBearerToken(context.Background()) + if err == nil { + t.Errorf("expect error, got none") + + } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %v error in, %v", e, a) + } + }() + } + + _, err := provider.RetrieveBearerToken(context.Background()) + if err == nil { + t.Errorf("expect error, got none") + + } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %v error in, %v", e, a) + } + + wg.Wait() +} + +func TestTokenCache_asyncRefresh(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + c := atomic.AddInt32(retrievedCount, 1) + switch { + case c == 1: + return expectToken, nil + case c > 1 && c < 5: + return Token{}, fmt.Errorf("some error") + case c == 5: + return refreshedToken, nil + default: + return Token{}, fmt.Errorf("unexpected error") + } + }), func(o *TokenCacheOptions) { + o.RefreshBeforeExpires = 5 * time.Minute + }) + + // 1: Initial retrieve to cache token + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous + // refreshes. + timeNow = func() time.Time { + return (time.Time{}).Add(6 * time.Minute) + } + + for i := 0; i < 4; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + } + // Wait for all async refreshes to complete + testWaitAsyncRefreshDone(provider) + + if c := int(atomic.LoadInt32(retrievedCount)); c < 2 || c > 5 { + t.Fatalf("expect async refresh to be called [2,5) times, got, %v", c) + } + + // Ensure enough retrieves have been done to trigger refresh. + if c := atomic.LoadInt32(retrievedCount); c != 5 { + atomic.StoreInt32(retrievedCount, 4) + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + testWaitAsyncRefreshDone(provider) + } + + // Last async refresh will succeed and update cached token, expect the next + // call to get refreshed token. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } +} + +func TestTokenCache_asyncRefreshWithMinDelay(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + c := atomic.AddInt32(retrievedCount, 1) + switch { + case c == 1: + return expectToken, nil + case c > 1 && c < 5: + return Token{}, fmt.Errorf("some error") + case c == 5: + return refreshedToken, nil + default: + return Token{}, fmt.Errorf("unexpected error") + } + }), func(o *TokenCacheOptions) { + o.RefreshBeforeExpires = 5 * time.Minute + o.AsyncRefreshMinimumDelay = 30 * time.Second + }) + + // 1: Initial retrieve to cache token + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous + // refreshes. + timeNow = func() time.Time { + return (time.Time{}).Add(6 * time.Minute) + } + + for i := 0; i < 4; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + // Wait for all async refreshes to complete ensure not deduped + testWaitAsyncRefreshDone(provider) + } + + // Only a single refresh attempt is expected. + if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Fatalf("expect %v min async refresh, got %v", e, a) + } + + // Move time forward to ensure another async refresh is triggered. + timeNow = func() time.Time { return (time.Time{}).Add(7 * time.Minute) } + // Make sure the next attempt refreshes the token + atomic.StoreInt32(retrievedCount, 4) + + // Do async retrieve that will succeed refreshing in background. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + // Wait for all async refreshes to complete ensure not deduped + testWaitAsyncRefreshDone(provider) + + // Last async refresh will succeed and update cached token, expect the next + // call to get refreshed token. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } +} + +func TestTokenCache_disableAsyncRefresh(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + c := atomic.AddInt32(retrievedCount, 1) + switch { + case c == 1: + return expectToken, nil + case c > 1 && c < 5: + return Token{}, fmt.Errorf("some error") + case c == 5: + return refreshedToken, nil + default: + return Token{}, fmt.Errorf("unexpected error") + } + }), func(o *TokenCacheOptions) { + o.RefreshBeforeExpires = 5 * time.Minute + o.DisableAsyncRefresh = true + }) + + // 1: Initial retrieve to cache token + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + // Update time into refresh window before token expires + timeNow = func() time.Time { + return (time.Time{}).Add(6 * time.Minute) + } + + for i := 0; i < 3; i++ { + _, err = provider.RetrieveBearerToken(context.Background()) + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := "some error", err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %v error in %v", e, a) + } + if e, a := i+2, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Fatalf("expect %v retrieveCount, got %v", e, a) + } + } + if e, a := 4, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Fatalf("expect %v retrieveCount, got %v", e, a) + } + + // Last refresh will succeed and update cached token, expect the next + // call to get refreshed token. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } +} + +func testWaitAsyncRefreshDone(provider *TokenCache) { + asyncResCh := provider.sfGroup.DoChan("async-refresh", func() (interface{}, error) { + return nil, nil + }) + <-asyncResCh +} diff --git a/codegen/go.mod b/codegen/go.mod new file mode 100644 index 000000000..54d7cb2a9 --- /dev/null +++ b/codegen/go.mod @@ -0,0 +1,3 @@ +module github.com/aws/smithy-go/codegen + +go 1.15 diff --git a/codegen/gradle.properties b/codegen/gradle.properties index eb24dc2cb..7ab719891 100644 --- a/codegen/gradle.properties +++ b/codegen/gradle.properties @@ -1 +1,2 @@ -smithyVersion=1.17.0 +smithyVersion=1.25.0 +smithyGradleVersion=0.6.0 diff --git a/codegen/settings.gradle.kts b/codegen/settings.gradle.kts index 3ba88bda2..d62e52310 100644 --- a/codegen/settings.gradle.kts +++ b/codegen/settings.gradle.kts @@ -20,6 +20,7 @@ include(":smithy-go-codegen-test") pluginManagement { repositories { mavenLocal() + mavenCentral() gradlePluginPortal() } } diff --git a/codegen/smithy-go-codegen-test/build.gradle.kts b/codegen/smithy-go-codegen-test/build.gradle.kts index 74c5528e9..5b8a4f29b 100644 --- a/codegen/smithy-go-codegen-test/build.gradle.kts +++ b/codegen/smithy-go-codegen-test/build.gradle.kts @@ -23,6 +23,7 @@ tasks["jar"].enabled = false buildscript { val smithyVersion: String by project repositories { + mavenLocal() mavenCentral() } dependencies { @@ -31,7 +32,8 @@ buildscript { } plugins { - id("software.amazon.smithy").version("0.5.3") + val smithyGradleVersion: String by project + id("software.amazon.smithy") version smithyGradleVersion } repositories { diff --git a/codegen/smithy-go-codegen-test/model/main.smithy b/codegen/smithy-go-codegen-test/model/main.smithy index d9cce83c8..d41e76f8b 100644 --- a/codegen/smithy-go-codegen-test/model/main.smithy +++ b/codegen/smithy-go-codegen-test/model/main.smithy @@ -1,4 +1,4 @@ -$version: "1.0" +$version: "2.0" namespace example.weather use smithy.test#httpRequestTests @@ -6,6 +6,7 @@ use smithy.test#httpResponseTests use smithy.waiters#waitable /// Provides weather forecasts. +@httpBearerAuth @fakeProtocol @paginated(inputToken: "nextToken", outputToken: "nextToken", pageSize: "pageSize") service Weather { @@ -189,7 +190,7 @@ structure GetCityOutput { // This structure is nested within GetCityOutput. structure CityCoordinates { @required - latitude: PrimitiveFloat, + latitude: Float, @required longitude: Float, @@ -309,15 +310,24 @@ structure ListCitiesInput { pageSize: Integer } -structure ListCitiesOutput { - nextToken: String, +intEnum SimpleOneZero { + ONE = 1 + ZERO = 0 +} +@mixin +structure ListCitiesMixin { someEnum: SimpleYesNo, aString: String, defaultBool: DefaultBool, boxedBool: Boolean, defaultNumber: DefaultInteger, boxedNumber: Integer, + someIntegerEnum: SimpleOneZero +} + +structure ListCitiesOutput with [ListCitiesMixin] { + nextToken: String, @required items: CitySummaries, @@ -380,8 +390,8 @@ structure GetForecastOutput { } union Precipitation { - rain: PrimitiveBoolean, - sleet: PrimitiveBoolean, + rain: Boolean, + sleet: Boolean, hail: StringMap, snow: SimpleYesNo, mixed: TypedYesNo, @@ -393,11 +403,15 @@ union Precipitation { structure OtherStructure {} -@enum([{value: "YES"}, {value: "NO"}]) -string SimpleYesNo +enum SimpleYesNo { + YES + NO +} -@enum([{value: "YES", name: "YES"}, {value: "NO", name: "NO"}]) -string TypedYesNo +enum TypedYesNo { + YES = "YES" + NO = "NO" +} map StringMap { key: String, @@ -435,7 +449,7 @@ structure PNGImage { structure GetCityImageOutput { @httpPayload - image: CityImageData, + image: CityImageData = "", } @streaming diff --git a/codegen/smithy-go-codegen-test/model/more-nesting.smithy b/codegen/smithy-go-codegen-test/model/more-nesting.smithy index cbdaac895..c54a996c1 100644 --- a/codegen/smithy-go-codegen-test/model/more-nesting.smithy +++ b/codegen/smithy-go-codegen-test/model/more-nesting.smithy @@ -1,4 +1,4 @@ -$version: "1" +$version: "2" namespace example.weather.nested.more diff --git a/codegen/smithy-go-codegen-test/model/nested.smithy b/codegen/smithy-go-codegen-test/model/nested.smithy index b50b30a05..c584dacb6 100644 --- a/codegen/smithy-go-codegen-test/model/nested.smithy +++ b/codegen/smithy-go-codegen-test/model/nested.smithy @@ -1,4 +1,4 @@ -$version: "1" +$version: "2" namespace example.weather.nested diff --git a/codegen/smithy-go-codegen-test/smithy-build.json b/codegen/smithy-go-codegen-test/smithy-build.json index e8f2eddef..1330bacdc 100644 --- a/codegen/smithy-go-codegen-test/smithy-build.json +++ b/codegen/smithy-go-codegen-test/smithy-build.json @@ -3,7 +3,7 @@ "plugins": { "go-codegen": { "service": "example.weather#Weather", - "module": "weather", + "module": "github.com/aws/smithy-go/internal/tests/service/weather", "moduleVersion": "0.0.1" } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenVisitor.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenVisitor.java index 6650443ce..d0316db10 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenVisitor.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenVisitor.java @@ -39,6 +39,7 @@ import software.amazon.smithy.model.knowledge.ServiceIndex; import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.neighbor.Walker; +import software.amazon.smithy.model.shapes.IntEnumShape; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.Shape; @@ -90,6 +91,12 @@ final class CodegenVisitor extends ShapeVisitor.Default { var modelTransformer = ModelTransformer.create(); + /* + smithy 1.23.0 added support for mixins. This transform flattens and applies the mixins + and remove them from the model + */ + resolvedModel = modelTransformer.flattenAndRemoveMixins(resolvedModel); + // Add unique operation input/output shapes resolvedModel = AddOperationShapes.execute(resolvedModel, settings.getService()); @@ -136,7 +143,7 @@ final class CodegenVisitor extends ShapeVisitor.Default { ? ApplicationProtocol.createDefaultHttpApplicationProtocol() : protocolGenerator.getApplicationProtocol(); - writers = new GoDelegator(settings, model, fileManifest, symbolProvider); + writers = new GoDelegator(fileManifest, symbolProvider); protocolDocumentGenerator = new ProtocolDocumentGenerator(settings, model, writers); @@ -149,7 +156,7 @@ private static ProtocolGenerator resolveProtocolGenerator( ServiceShape service, GoSettings settings ) { - // Collect all of the supported protocol generators. + // Collect all the supported protocol generators. Map generators = new HashMap<>(); for (GoIntegration integration : integrations) { for (ProtocolGenerator generator : integration.getProtocolGenerators()) { @@ -157,7 +164,7 @@ private static ProtocolGenerator resolveProtocolGenerator( } } - ServiceIndex serviceIndex = model.getKnowledge(ServiceIndex.class); + ServiceIndex serviceIndex = ServiceIndex.of(model); ShapeId protocolTrait; try { @@ -321,4 +328,10 @@ public Void serviceShape(ServiceShape shape) { }); return null; } + + @Override + public Void intEnumShape(IntEnumShape shape) { + writers.useShapeWriter(shape, writer -> new IntEnumGenerator(symbolProvider, writer, shape).run()); + return null; + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDelegator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDelegator.java index 10d0c6051..10494d411 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDelegator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDelegator.java @@ -15,56 +15,24 @@ package software.amazon.smithy.go.codegen; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import java.util.function.Consumer; import software.amazon.smithy.build.FileManifest; import software.amazon.smithy.codegen.core.Symbol; -import software.amazon.smithy.codegen.core.SymbolDependency; import software.amazon.smithy.codegen.core.SymbolProvider; import software.amazon.smithy.codegen.core.SymbolReference; -import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.Shape; /** - * Manages writers for Go files. + * Manages writers for Go files.Based off of GoWriterDelegator adding support + * for getting shape specific GoWriters. */ -public final class GoDelegator { - - private final GoSettings settings; - private final Model model; - private final FileManifest fileManifest; +public final class GoDelegator extends GoWriterDelegator { private final SymbolProvider symbolProvider; - private final Map writers = new HashMap<>(); - GoDelegator(GoSettings settings, Model model, FileManifest fileManifest, SymbolProvider symbolProvider) { - this.settings = settings; - this.model = model; - this.fileManifest = fileManifest; - this.symbolProvider = symbolProvider; - } + GoDelegator(FileManifest fileManifest, SymbolProvider symbolProvider) { + super(fileManifest); - /** - * Writes all pending writers to disk and then clears them out. - */ - void flushWriters() { - writers.forEach((filename, writer) -> fileManifest.writeFile(filename, writer.toString())); - writers.clear(); - } - - /** - * Gets all of the dependencies that have been registered in writers owned by the - * delegator. - * - * @return Returns all the dependencies. - */ - List getDependencies() { - List resolved = new ArrayList<>(); - writers.values().forEach(s -> resolved.addAll(s.getDependencies())); - return resolved; + this.symbolProvider = symbolProvider; } /** @@ -139,28 +107,4 @@ private void useShapeWriter(Symbol symbol, Consumer writerConsumer) { writerConsumer.accept(writer); writer.popState(); } - - /** - * Gets a previously created writer or creates a new one if needed - * and adds a new line if the writer already exists. - * - * @param filename Name of the file to create. - * @param writerConsumer Consumer that accepts and works with the file. - */ - void useFileWriter(String filename, String namespace, Consumer writerConsumer) { - writerConsumer.accept(checkoutWriter(filename, namespace)); - } - - private GoWriter checkoutWriter(String filename, String namespace) { - String formattedFilename = Paths.get(filename).normalize().toString(); - boolean needsNewline = writers.containsKey(formattedFilename); - - GoWriter writer = writers.computeIfAbsent(formattedFilename, f -> new GoWriter(namespace)); - - if (needsNewline) { - writer.write("\n"); - } - - return writer; - } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoEventStreamIndex.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoEventStreamIndex.java index 42cc4eab0..03b6ad4e7 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoEventStreamIndex.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoEventStreamIndex.java @@ -58,7 +58,7 @@ public GoEventStreamIndex(Model model) { eventStreamIndex.getOutputInfo(operationShape).ifPresent(eventStreamInfo -> { ShapeId eventStreamTargetId = eventStreamInfo.getEventStreamTarget().getId(); if (serviceOutputStreams.containsKey(eventStreamTargetId)) { - serviceInputStreams.get(eventStreamTargetId).add(eventStreamInfo); + serviceOutputStreams.get(eventStreamTargetId).add(eventStreamInfo); } else { TreeSet infos = new TreeSet<>( Comparator.comparing(EventStreamInfo::getOperation)); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoValueAccessUtils.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoValueAccessUtils.java index 987139e33..5f48eade4 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoValueAccessUtils.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoValueAccessUtils.java @@ -72,6 +72,8 @@ public static void writeIfNonZeroValue( } else if (container instanceof CollectionShape || container.getType() == ShapeType.MAP) { if (!ignoreEmptyString && targetShape.getType() == ShapeType.STRING) { check = String.format("if len(%s) > 0 {", operand); + } else if (!ignoreEmptyString && targetShape.getType() == ShapeType.ENUM) { + check = String.format("if len(%s) > 0 {", operand); } } else if (targetShape.hasTrait(EnumTrait.class)) { check = String.format("if len(%s) > 0 {", operand); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java index 65f784d5b..db941d2c7 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java @@ -19,10 +19,14 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.StringJoiner; import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; import java.util.logging.Logger; +import java.util.regex.Pattern; import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolContainer; @@ -31,6 +35,7 @@ import software.amazon.smithy.codegen.core.SymbolReference; import software.amazon.smithy.go.codegen.knowledge.GoUsageIndex; import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.loader.Prelude; import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.traits.DeprecatedTrait; @@ -39,7 +44,7 @@ import software.amazon.smithy.model.traits.MediaTypeTrait; import software.amazon.smithy.model.traits.RequiredTrait; import software.amazon.smithy.model.traits.StringTrait; -import software.amazon.smithy.utils.CodeWriter; +import software.amazon.smithy.utils.AbstractCodeWriter; import software.amazon.smithy.utils.StringUtils; /** @@ -49,31 +54,428 @@ * *

Use the {@code $P} formatter to refer to {@link Symbol}s using pointers where appropriate. */ -public final class GoWriter extends CodeWriter { +public final class GoWriter extends AbstractCodeWriter { private static final Logger LOGGER = Logger.getLogger(GoWriter.class.getName()); private static final int DEFAULT_DOC_WRAP_LENGTH = 80; - + private static final Pattern ARGUMENT_NAME_PATTERN = Pattern.compile("\\$([a-z][a-zA-Z_0-9]+)(:\\w)?"); private final String fullPackageName; private final ImportDeclarations imports = new ImportDeclarations(); private final List dependencies = new ArrayList<>(); + private final boolean innerWriter; private int docWrapLength = DEFAULT_DOC_WRAP_LENGTH; + private AbstractCodeWriter packageDocs; - private CodeWriter packageDocs; - + /** + * Initializes the GoWriter for the package and filename to be written to. + * + * @param fullPackageName package and filename to be written to. + */ public GoWriter(String fullPackageName) { this.fullPackageName = fullPackageName; + this.innerWriter = false; + init(); + } + + private GoWriter(String fullPackageName, boolean innerWriter) { + this.fullPackageName = fullPackageName; + this.innerWriter = innerWriter; + init(); + } + + private void init() { trimBlankLines(); trimTrailingSpaces(); setIndentText("\t"); putFormatter('T', new GoSymbolFormatter()); putFormatter('P', new PointableGoSymbolFormatter()); + putFormatter('W', new GoWritableInjector()); + + if (!innerWriter) { + packageDocs = new GoWriter(this.fullPackageName, true); + } + } + + // TODO figure out better way to annotate where the failure occurs, check templates and args + // TODO to try to find programming bugs. + + /** + * Returns a Writable for the string and args to be composed inline to another writer's contents. + * + * @param contents string to write. + * @param args Arguments to use when evaluating the contents string. + * @return Writable to be evaluated. + */ + @SafeVarargs + public static Writable goTemplate(String contents, Map... args) { + validateTemplateArgsNotNull(args); + return (GoWriter w) -> { + w.writeGoTemplate(contents, args); + }; + } + + /** + * Returns a Writable that can later be invoked to write the contents as template + * as a code block instead of single content of text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param fn closure to write + */ + public static Writable goBlockTemplate( + String beforeNewLine, + String afterNewLine, + Consumer fn + ) { + return goBlockTemplate(beforeNewLine, afterNewLine, new Map[0], fn); + } + + /** + * Returns a Writable that can later be invoked to write the contents as template + * as a code block instead of single content of text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param args1 template arguments + * @param fn closure to write + */ + public static Writable goBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map args1, + Consumer fn + ) { + return goBlockTemplate(beforeNewLine, afterNewLine, new Map[]{args1}, fn); + } + + /** + * Returns a Writable that can later be invoked to write the contents as template + * as a code block instead of single content of text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param args1 template arguments + * @param args2 template arguments + * @param fn closure to write + */ + public static Writable goBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map args1, + Map args2, + Consumer fn + ) { + return goBlockTemplate(beforeNewLine, afterNewLine, new Map[]{args1, args2}, fn); + } + + /** + * Returns a Writable that can later be invoked to write the contents as template + * as a code block instead of single content of text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param args1 template arguments + * @param args2 template arguments + * @param args3 template arguments + * @param fn closure to write + */ + public static Writable goBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map args1, + Map args2, + Map args3, + Consumer fn + ) { + return goBlockTemplate(beforeNewLine, afterNewLine, new Map[]{args1, args2, args3}, fn); + } + + /** + * Returns a Writable that can later be invoked to write the contents as template + * as a code block instead of single content of text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param args1 template arguments + * @param args2 template arguments + * @param args3 template arguments + * @param args4 template arguments + * @param fn closure to write + */ + public static Writable goBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map args1, + Map args2, + Map args3, + Map args4, + Consumer fn + ) { + return goBlockTemplate(beforeNewLine, afterNewLine, new Map[]{args1, args2, args3, args4}, fn); + } + + /** + * Returns a Writable that can later be invoked to write the contents as template + * as a code block instead of single content of text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param args1 template arguments + * @param args2 template arguments + * @param args3 template arguments + * @param args4 template arguments + * @param args5 template arguments + * @param fn closure to write + */ + public static Writable goBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map args1, + Map args2, + Map args3, + Map args4, + Map args5, + Consumer fn + ) { + return goBlockTemplate(beforeNewLine, afterNewLine, new Map[]{args1, args2, args3, args4, args5}, fn); + } + + /** + * Returns a Writable that can later be invoked to write the contents as template + * as a code block instead of single content of text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param args template arguments + * @param fn closure to write + */ + public static Writable goBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map[] args, + Consumer fn + ) { + validateTemplateArgsNotNull(args); + return (GoWriter w) -> { + w.writeGoBlockTemplate(beforeNewLine, afterNewLine, args, fn); + }; + } + + /** + * Returns a Writable that does nothing. + * + * @return Writable that does nothing + */ + public static Writable emptyGoTemplate() { + return (GoWriter w) -> { + }; + } + + /** + * Writes the contents and arguments as a template to the writer. + * + * @param contents string to write + * @param args Arguments to use when evaluating the contents string. + */ + @SafeVarargs + public final void writeGoTemplate(String contents, Map... args) { + withTemplate(contents, args, (template) -> { + try { + write(contents); + } catch (Exception e) { + throw new CodegenException("Failed to render template\n" + contents + "\nReason: " + e.getMessage(), e); + } + }); + } - packageDocs = new CodeWriter(); - packageDocs.trimBlankLines(); - packageDocs.trimTrailingSpaces(); - packageDocs.setIndentText("\t"); + /** + * Writes the contents as template as a code block instead of single content fo text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param fn closure to write + */ + public void writeGoBlockTemplate( + String beforeNewLine, + String afterNewLine, + Consumer fn + ) { + writeGoBlockTemplate(beforeNewLine, afterNewLine, new Map[0], fn); + } + + /** + * Writes the contents as template as a code block instead of single content fo text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param arg1 first map argument + * @param fn closure to write + */ + public void writeGoBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map arg1, + Consumer fn + ) { + writeGoBlockTemplate(beforeNewLine, afterNewLine, new Map[]{arg1}, fn); + } + + /** + * Writes the contents as template as a code block instead of single content fo text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param arg1 first map argument + * @param arg2 second map argument + * @param fn closure to write + */ + public void writeGoBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map arg1, + Map arg2, + Consumer fn + ) { + writeGoBlockTemplate(beforeNewLine, afterNewLine, new Map[]{arg1, arg2}, fn); + } + + /** + * Writes the contents as template as a code block instead of single content fo text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param arg1 first map argument + * @param arg2 second map argument + * @param arg3 third map argument + * @param fn closure to write + */ + public void writeGoBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map arg1, + Map arg2, + Map arg3, + Consumer fn + ) { + writeGoBlockTemplate(beforeNewLine, afterNewLine, new Map[]{arg1, arg2, arg3}, fn); + } + + /** + * Writes the contents as template as a code block instead of single content fo text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param arg1 first map argument + * @param arg2 second map argument + * @param arg3 third map argument + * @param arg4 forth map argument + * @param fn closure to write + */ + public void writeGoBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map arg1, + Map arg2, + Map arg3, + Map arg4, + Consumer fn + ) { + writeGoBlockTemplate(beforeNewLine, afterNewLine, new Map[]{arg1, arg2, arg3, arg4}, fn); + } + + /** + * Writes the contents as template as a code block instead of single content fo text. + * + * @param beforeNewLine text before new line + * @param afterNewLine text after new line + * @param arg1 first map argument + * @param arg2 second map argument + * @param arg3 third map argument + * @param arg4 forth map argument + * @param arg5 forth map argument + * @param fn closure to write + */ + public void writeGoBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map arg1, + Map arg2, + Map arg3, + Map arg4, + Map arg5, + Consumer fn + ) { + writeGoBlockTemplate(beforeNewLine, afterNewLine, new Map[]{arg1, arg2, arg3, arg4, arg5}, fn); + } + + public void writeGoBlockTemplate( + String beforeNewLine, + String afterNewLine, + Map[] args, + Consumer fn + ) { + withTemplate(beforeNewLine, args, (header) -> { + conditionalBlock(header, afterNewLine, true, new Object[0], fn); + }); + } + + private void withTemplate( + String template, + Map[] argMaps, + Consumer fn + ) { + pushState(); + for (var args : argMaps) { + putContext(args); + } + validateContext(template); + fn.accept(template); + popState(); + } + + private GoWriter conditionalBlock( + String beforeNewLine, + String afterNewLine, + boolean conditional, + Object[] args, + Consumer fn + ) { + if (conditional) { + openBlock(beforeNewLine.trim(), args); + } + + fn.accept(this); + + if (conditional) { + closeBlock(afterNewLine.trim()); + } + + return this; + } + + private static void validateTemplateArgsNotNull(Map[] argMaps) { + for (var args : argMaps) { + args.forEach((k, v) -> { + if (v == null) { + throw new CodegenException("Template argument " + k + " cannot be null"); + } + }); + } + } + + private void validateContext(String template) { + var matcher = ARGUMENT_NAME_PATTERN.matcher(template); + + while (matcher.find()) { + var keyName = matcher.group(1); + var value = getContext(keyName); + if (value == null) { + throw new CodegenException( + "Go template expected " + keyName + " but was not present in context scope." + + " Template: \n" + template); + } + } } /** @@ -168,6 +570,11 @@ public GoWriter addImport(Symbol symbol, String packageAlias, SymbolReference.Co return this; } + private GoWriter addImports(GoWriter other) { + this.imports.addImports(other.imports); + return this; + } + private boolean isExternalNamespace(String namespace) { return !StringUtils.isBlank(namespace) && !namespace.equals(fullPackageName); } @@ -209,6 +616,11 @@ public GoWriter addDependency(SymbolDependencyContainer dependencies) { return this; } + private GoWriter addDependencies(GoWriter other) { + this.dependencies.addAll(other.getDependencies()); + return this; + } + Collection getDependencies() { return dependencies; } @@ -219,7 +631,7 @@ Collection getDependencies() { * @param runnable Runnable that handles actually writing docs with the writer. * @return Returns the writer. */ - private void writeDocs(CodeWriter writer, Runnable runnable) { + private void writeDocs(AbstractCodeWriter writer, Runnable runnable) { writer.pushState("docs"); writer.setNewlinePrefix("// "); runnable.run(); @@ -227,7 +639,7 @@ private void writeDocs(CodeWriter writer, Runnable runnable) { writer.popState(); } - private void writeDocs(CodeWriter writer, int docWrapLength, String docs) { + private void writeDocs(AbstractCodeWriter writer, int docWrapLength, String docs) { String wrappedDoc = StringUtils.wrap(DocumentationConverter.convert(docs), docWrapLength); writeDocs(writer, () -> writer.write(wrappedDoc.replace("$", "$$"))); } @@ -245,6 +657,24 @@ public GoWriter writeDocs(String docs) { return this; } + /** + * Writes documentation from an arbitrary Writable. + * + * @param writable Contents to be written. + * @return Returns the writer. + */ + public GoWriter writeRenderedDocs(Writable writable) { + writeRenderedDocs(this, docWrapLength, writable); + return this; + } + + private void writeRenderedDocs(AbstractCodeWriter writer, int docWrapLength, Writable writable) { + var innerWriter = new GoWriter(fullPackageName, true); + writable.accept(innerWriter); + var wrappedDocs = StringUtils.wrap(innerWriter.toString().trim(), docWrapLength); + writeDocs(writer, () -> writer.write(wrappedDocs.replace("$", "$$"))); + } + /** * Writes the doc to the Go package docs that are written prior to the go package statement. * @@ -348,25 +778,41 @@ boolean writeMemberDocs(Model model, MemberShape member) { } Optional deprecatedTrait = member.getMemberTrait(model, DeprecatedTrait.class); - if (deprecatedTrait.isPresent()) { + if (member.getTrait(DeprecatedTrait.class).isPresent() || isTargetDeprecated(model, member)) { if (hasDocs) { writeDocs(""); } final String defaultMessage = "This member has been deprecated."; - writeDocs("Deprecated: " + deprecatedTrait.get().getMessage().map(s -> { - if (s.length() == 0) { - return defaultMessage; - } - return s; - }).orElse(defaultMessage)); + String message = defaultMessage; + if (deprecatedTrait.isPresent()) { + message = deprecatedTrait.get().getMessage().map(s -> { + if (s.length() == 0) { + return defaultMessage; + } + return s; + }).orElse(defaultMessage); + } + writeDocs("Deprecated: " + message); } return hasDocs; } + private boolean isTargetDeprecated(Model model, MemberShape member) { + return model.expectShape(member.getTarget()).getTrait(DeprecatedTrait.class).isPresent() + // don't consider deprecated prelude shapes (like PrimitiveBoolean) + && !Prelude.isPreludeShape(member.getTarget()); + } + @Override public String toString() { String contents = super.toString(); + + if (innerWriter) { + return contents; + } + + String[] packageParts = fullPackageName.split("/"); String header = String.format("// Code generated by smithy-go-codegen DO NOT EDIT.%n%n"); @@ -474,4 +920,63 @@ private boolean isPointer(Object type) { } } } + + class GoWritableInjector extends GoSymbolFormatter { + @Override + public String apply(Object type, String indent) { + if (!(type instanceof Writable)) { + throw new CodegenException( + "expect Writable for GoWriter W injector, but got " + type); + } + var innerWriter = new GoWriter(fullPackageName, true); + ((Writable) type).accept(innerWriter); + addImports(innerWriter); + addDependencies(innerWriter); + return innerWriter.toString().trim(); + } + } + + public interface Writable extends Consumer { + } + + /** + * Chains together multiple Writables that can be composed into one Writable. + */ + public static final class ChainWritable { + private final List writables; + + public ChainWritable() { + writables = new ArrayList<>(); + } + + public ChainWritable add(GoWriter.Writable writable) { + writables.add(writable); + return this; + } + + public ChainWritable add(Optional value, Function fn) { + value.ifPresent(t -> writables.add(fn.apply(t))); + return this; + } + + public ChainWritable add(boolean include, GoWriter.Writable writable) { + if (!include) { + writables.add(writable); + } + return this; + } + + public GoWriter.Writable compose() { + return (GoWriter writer) -> { + var hasPrevious = false; + for (GoWriter.Writable writable : writables) { + if (hasPrevious) { + writer.write(""); + } + hasPrevious = true; + writer.write("$W", writable); + } + }; + } + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriterDelegator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriterDelegator.java new file mode 100644 index 000000000..0c82b3a16 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriterDelegator.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen; + +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import software.amazon.smithy.build.FileManifest; +import software.amazon.smithy.codegen.core.SymbolDependency; + +public class GoWriterDelegator { + private final FileManifest fileManifest; + private final Map writers = new HashMap<>(); + + public GoWriterDelegator(FileManifest fileManifest) { + this.fileManifest = fileManifest; + } + + /** + * Writes all pending writers to disk and then clears them out. + */ + public void flushWriters() { + writers.forEach((filename, writer) -> fileManifest.writeFile(filename, writer.toString())); + writers.clear(); + } + + /** + * Gets all the dependencies that have been registered in writers owned by the + * delegator. + * + * @return Returns all the dependencies. + */ + public List getDependencies() { + List resolved = new ArrayList<>(); + writers.values().forEach(s -> resolved.addAll(s.getDependencies())); + return resolved; + } + + /** + * Gets a previously created writer or creates a new one if needed + * and adds a new line if the writer already exists. + * + * @param filename Name of the file to create. + * @param namespace Namespace of the file's content. + * @param writerConsumer Consumer that accepts and works with the file. + */ + public void useFileWriter(String filename, String namespace, Consumer writerConsumer) { + writerConsumer.accept(checkoutWriter(filename, namespace)); + } + + GoWriter checkoutWriter(String filename, String namespace) { + String formattedFilename = Paths.get(filename).normalize().toString(); + boolean needsNewline = writers.containsKey(formattedFilename); + + GoWriter writer = writers.computeIfAbsent(formattedFilename, f -> new GoWriter(namespace)); + + if (needsNewline) { + writer.write("\n"); + } + + return writer; + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ImportDeclarations.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ImportDeclarations.java index c9c1d3e86..e72a6a05a 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ImportDeclarations.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ImportDeclarations.java @@ -27,16 +27,28 @@ final class ImportDeclarations { private final Map imports = new TreeMap<>(); - ImportDeclarations addImport(String packageName, String alias) { - String importedName = CodegenUtils.getDefaultPackageImportName(packageName); + ImportDeclarations addImport(String importPath, String alias) { + String importAlias = CodegenUtils.getDefaultPackageImportName(importPath); if (!StringUtils.isBlank(alias)) { if (alias.equals(".")) { // Global imports are generally a bad practice. - throw new CodegenException("Globally importing packages is forbidden: " + packageName); + throw new CodegenException("Globally importing packages is forbidden: " + importPath); } - importedName = alias; + importAlias = alias; } - imports.putIfAbsent(importedName, packageName); + // Ensure that multiple packages cannot be imported with the same name. + if (imports.containsKey(importAlias) && !imports.get(importAlias).equals(importPath)) { + throw new CodegenException("Import name collision: " + importAlias + + ". Previous: " + imports.get(importAlias) + "New: " + importPath); + } + imports.putIfAbsent(importAlias, importPath); + return this; + } + + ImportDeclarations addImports(ImportDeclarations other) { + other.imports.forEach((importAlias, importPath) -> { + addImport(importPath, importAlias); + }); return this; } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/IntEnumGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/IntEnumGenerator.java new file mode 100644 index 000000000..7b101c492 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/IntEnumGenerator.java @@ -0,0 +1,98 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen; + +import java.util.LinkedHashSet; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.model.shapes.IntEnumShape; +import software.amazon.smithy.model.shapes.MemberShape; +import software.amazon.smithy.model.traits.DocumentationTrait; +import software.amazon.smithy.model.traits.EnumValueTrait; +import software.amazon.smithy.utils.StringUtils; + +/** + * Renders intEnums and their constants. + */ +final class IntEnumGenerator implements Runnable { + private static final Logger LOGGER = Logger.getLogger(IntEnumGenerator.class.getName()); + + private final SymbolProvider symbolProvider; + private final GoWriter writer; + private final IntEnumShape shape; + + IntEnumGenerator(SymbolProvider symbolProvider, GoWriter writer, IntEnumShape shape) { + this.symbolProvider = symbolProvider; + this.writer = writer; + this.shape = shape; + } + + @Override + public void run() { + Symbol symbol = symbolProvider.toSymbol(shape); + + writer.write("type $L int32", symbol.getName()).write(""); + + writer.writeDocs(String.format("Enum values for %s", symbol.getName())); + Set constants = new LinkedHashSet<>(); + writer.openBlock("const (", ")", () -> { + for (Map.Entry entry : shape.getAllMembers().entrySet()) { + StringBuilder labelBuilder = new StringBuilder(symbol.getName()); + String name = entry.getKey(); + + for (String part : name.split("(?U)[\\W_]")) { + if (part.matches(".*[a-z].*") && part.matches(".*[A-Z].*")) { + // Mixed case names should not be changed other than first letter capitalized. + labelBuilder.append(StringUtils.capitalize(part)); + } else { + // For all non-mixed case parts title case first letter, followed by all other lower cased. + labelBuilder.append(StringUtils.capitalize(part.toLowerCase(Locale.US))); + } + } + String label = labelBuilder.toString(); + + // If camel-casing would cause a conflict, don't camel-case this enum value. + if (constants.contains(label)) { + LOGGER.warning(String.format( + "Multiple enums resolved to the same name, `%s`, using unaltered value for: %s", + label, name)); + label = name; + } + constants.add(label); + + entry.getValue().getTrait(DocumentationTrait.class) + .ifPresent(trait -> writer.writeDocs(trait.getValue())); + writer.write("$L $L = $L", label, symbol.getName(), + entry.getValue().expectTrait(EnumValueTrait.class).expectIntValue()); + } + }).write(""); + + writer.writeDocs(String.format("Values returns all known values for %s. Note that this can be expanded in the " + + "future, and so it is only as up to date as the client.%n%nThe ordering of this slice is not " + + "guaranteed to be stable across updates.", symbol.getName())); + writer.openBlock("func ($L) Values() []$L {", "}", symbol.getName(), symbol.getName(), () -> { + writer.openBlock("return []$L{", "}", symbol.getName(), () -> { + for (Map.Entry entry : shape.getEnumValues().entrySet()) { + writer.write("$L,", entry.getValue()); + } + }); + }); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ManifestWriter.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ManifestWriter.java index 6a9c33256..a33713ee2 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ManifestWriter.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ManifestWriter.java @@ -37,18 +37,32 @@ import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.node.StringNode; import software.amazon.smithy.model.traits.UnstableTrait; +import software.amazon.smithy.utils.BuilderRef; +import software.amazon.smithy.utils.SmithyBuilder; /** - * Generates a manifest description of the generated code, minimum go version, and minimum dependencies required. + * Generates a manifest description of the generated code, minimum go version, + * and minimum dependencies required. */ public final class ManifestWriter { private static final String GENERATED_JSON = "generated.json"; - private ManifestWriter() { + private final String moduleName; + private final FileManifest fileManifest; + private final List dependencies; + private final Optional minimumGoVersion; + private final boolean isUnstable; + + private ManifestWriter(Builder builder) { + moduleName = SmithyBuilder.requiredState("moduleName", builder.moduleName); + fileManifest = SmithyBuilder.requiredState("fileManifest", builder.fileManifest); + dependencies = builder.dependencies.copy(); + minimumGoVersion = builder.minimumGoVersion; + isUnstable = builder.isUnstable; } /** - * Write the manifest description of the generated code. + * Write the manifest description of the Smithy model based generated source code. * * @param settings the go settings * @param model the smithy model @@ -61,6 +75,20 @@ public static void writeManifest( FileManifest fileManifest, List dependencies ) { + builder() + .moduleName(settings.getModuleName()) + .fileManifest(fileManifest) + .dependencies(dependencies) + .isUnstable(settings.getService(model).getTrait(UnstableTrait.class).isPresent()) + .build() + .writeManifest(); + } + + + /** + * Write the manifest description of the generated code. + */ + public void writeManifest() { Path manifestFile = fileManifest.getBaseDir().resolve(GENERATED_JSON); if (Files.exists(manifestFile)) { @@ -72,31 +100,28 @@ public static void writeManifest( } fileManifest.addFile(manifestFile); - Node generatedJson = buildManifestFile(settings, model, fileManifest, dependencies); + Node generatedJson = buildManifestFile(); fileManifest.writeFile(manifestFile.toString(), Node.prettyPrintJson(generatedJson) + "\n"); - } - private static Node buildManifestFile( - GoSettings settings, - Model model, - FileManifest fileManifest, - List dependencies - ) { + } + private Node buildManifestFile() { List nonStdLib = new ArrayList<>(); - Optional minStandard = Optional.empty(); + Optional minimumGoVersion = this.minimumGoVersion; for (SymbolDependency dependency : dependencies) { if (!dependency.getDependencyType().equals(GoDependency.Type.STANDARD_LIBRARY.toString())) { nonStdLib.add(dependency); - } else { - if (minStandard.isPresent()) { - if (minStandard.get().getVersion().compareTo(dependency.getVersion()) < 0) { - minStandard = Optional.of(dependency); - } - } else { - minStandard = Optional.of(dependency); + continue; + } + + var otherVersion = dependency.getVersion(); + if (minimumGoVersion.isPresent()) { + if (minimumGoVersion.get().compareTo(otherVersion) < 0) { + minimumGoVersion = Optional.of(otherVersion); } + } else { + minimumGoVersion = Optional.of(otherVersion); } } @@ -106,8 +131,7 @@ private static Node buildManifestFile( Map dependencyNodes = new HashMap<>(); for (Map.Entry entry : minimumDependencies.entrySet()) { - dependencyNodes.put(StringNode.from(entry.getKey()), - StringNode.from(entry.getValue())); + dependencyNodes.put(StringNode.from(entry.getKey()), StringNode.from(entry.getValue())); } Collection generatedFiles = new ArrayList<>(); @@ -117,13 +141,12 @@ private static Node buildManifestFile( } generatedFiles = generatedFiles.stream().sorted().collect(Collectors.toList()); - manifestNodes.put(StringNode.from("module"), StringNode.from(settings.getModuleName())); - minStandard.ifPresent(symbolDependency -> - manifestNodes.put(StringNode.from("go"), StringNode.from(symbolDependency.getVersion()))); + manifestNodes.put(StringNode.from("module"), StringNode.from(moduleName)); + minimumGoVersion.ifPresent(version -> manifestNodes.put(StringNode.from("go"), + StringNode.from(version))); manifestNodes.put(StringNode.from("dependencies"), ObjectNode.objectNode(dependencyNodes)); manifestNodes.put(StringNode.from("files"), ArrayNode.fromStrings(generatedFiles)); - manifestNodes.put(StringNode.from("unstable"), - BooleanNode.from(settings.getService(model).getTrait(UnstableTrait.class).isPresent())); + manifestNodes.put(StringNode.from("unstable"), BooleanNode.from(isUnstable)); return ObjectNode.objectNode(manifestNodes).withDeepSortedKeys(); } @@ -131,11 +154,52 @@ private static Node buildManifestFile( private static Map gatherMinimumDependencies( Stream symbolStream ) { - return SymbolDependency.gatherDependencies(symbolStream, GoDependency::mergeByMinimumVersionSelection) - .entrySet().stream() - .flatMap(entry -> entry.getValue().entrySet().stream()) - .collect(Collectors.toMap( - Map.Entry::getKey, entry -> entry.getValue().getVersion(), (a, b) -> b, TreeMap::new)); + return SymbolDependency.gatherDependencies(symbolStream, + GoDependency::mergeByMinimumVersionSelection).entrySet().stream().flatMap( + entry -> entry.getValue().entrySet().stream()).collect( + Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getVersion(), (a, b) -> b, TreeMap::new)); + } + + public static Builder builder() { + return new Builder(); } + public static class Builder implements SmithyBuilder { + private String moduleName; + private FileManifest fileManifest; + private final BuilderRef> dependencies = BuilderRef.forList(); + private Optional minimumGoVersion = Optional.empty(); + private boolean isUnstable; + + public Builder moduleName(String moduleName) { + this.moduleName = moduleName; + return this; + } + + public Builder fileManifest(FileManifest fileManifest) { + this.fileManifest = fileManifest; + return this; + } + + public Builder dependencies(List dependencies) { + this.dependencies.clear(); + this.dependencies.get().addAll(dependencies); + return this; + } + + public Builder minimumGoVersion(String minimumGoVersion) { + this.minimumGoVersion = Optional.of(minimumGoVersion); + return this; + } + + public Builder isUnstable(boolean isUnstable) { + this.isUnstable = isUnstable; + return this; + } + + @Override + public ManifestWriter build() { + return new ManifestWriter(this); + } + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java index 5df9ba2ba..661554719 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java @@ -289,6 +289,10 @@ protected void writeScalarPointerInline(GoWriter writer, MemberShape member, Nod funcName = "String"; break; + case ENUM: + funcName = target.getId().getName(); + break; + case TIMESTAMP: funcName = "Time"; break; @@ -360,6 +364,12 @@ protected void writeScalarValueInline(GoWriter writer, MemberShape member, Node closing = ")"; } break; + case ENUM: + // Enum are not pointers, but string alias values + Symbol enumSymbol = symbolProvider.toSymbol(target); + writer.writeInline("$T(", enumSymbol); + closing = ")"; + break; default: break; @@ -694,6 +704,10 @@ public Void stringNode(StringNode node) { writer.writeInline("$S", node.getValue()); break; + case ENUM: + writer.writeInline("$S", node.getValue()); + break; + case BIG_INTEGER: writeInlineBigIntegerInit(writer, node.getValue()); break; diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java index b33f73e3c..e4de6db55 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java @@ -60,6 +60,7 @@ public final class SmithyGoDependency { public static final GoDependency SMITHY_DOCUMENT = smithy("document", "smithydocument"); public static final GoDependency SMITHY_DOCUMENT_JSON = smithy("document/json", "smithydocumentjson"); public static final GoDependency SMITHY_SYNC = smithy("sync", "smithysync"); + public static final GoDependency SMITHY_AUTH_BEARER = smithy("auth/bearer"); public static final GoDependency GO_CMP = goCmp("cmp"); public static final GoDependency GO_CMP_OPTIONS = goCmp("cmp/cmpopts"); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java index 8f821d56a..04d47bc14 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java @@ -41,6 +41,7 @@ import software.amazon.smithy.model.shapes.DocumentShape; import software.amazon.smithy.model.shapes.DoubleShape; import software.amazon.smithy.model.shapes.FloatShape; +import software.amazon.smithy.model.shapes.IntEnumShape; import software.amazon.smithy.model.shapes.IntegerShape; import software.amazon.smithy.model.shapes.ListShape; import software.amazon.smithy.model.shapes.LongShape; @@ -117,6 +118,7 @@ final class SymbolVisitor implements SymbolProvider, ShapeVisitor { // Reserved words that only apply to error members. ReservedWords reservedErrorMembers = new ReservedWordsBuilder() .put("ErrorCode", "ErrorCode_") + .put("ErrorMessage", "ErrorMessage_") .put("ErrorFault", "ErrorFault_") .put("Unwrap", "Unwrap_") .put("Error", "Error_") @@ -503,4 +505,12 @@ public Symbol memberShape(MemberShape member) { public Symbol timestampShape(TimestampShape shape) { return symbolBuilderFor(shape, "Time", SmithyGoDependency.TIME).build(); } + + @Override + public Symbol intEnumShape(IntEnumShape shape) { + String name = getDefaultShapeName(shape); + return symbolBuilderFor(shape, name, typesPackageName) + .definitionFile("./types/enums.go") + .build(); + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java index dfeb669a8..b04e1442d 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java @@ -476,11 +476,20 @@ protected abstract void writeMiddlewareDocumentSerializerDelegator( * @param payloadShape the payload shape. */ protected void writeSetPayloadShapeHeader(GoWriter writer, Shape payloadShape) { + writer.pushState(); + + writer.putContext("withIsDefaultContentType", SymbolUtils.createValueSymbolBuilder( + "SetIsContentTypeDefaultValue", SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build()); + writer.putContext("payloadMediaType", getPayloadShapeMediaType(payloadShape)); + writer.write(""" if !restEncoder.HasHeader("Content-Type") { - restEncoder.SetHeader("Content-Type").String($S) + ctx = $withIsDefaultContentType:T(ctx, true) + restEncoder.SetHeader("Content-Type").String($payloadMediaType:S) } - """, getPayloadShapeMediaType(payloadShape)); + """); + + writer.popState(); } /** @@ -511,24 +520,24 @@ protected void writeMiddlewarePayloadSerializerDelegator( Shape payloadShape = model.expectShape(memberShape.getTarget()); if (payloadShape.hasTrait(StreamingTrait.class)) { + writeSetPayloadShapeHeader(writer, payloadShape); GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, memberShape, "input", (s) -> { - writeSetPayloadShapeHeader(writer, payloadShape); writer.write("payload := $L", s); writeSetStream(writer, "payload"); }); } else if (payloadShape.isBlobShape()) { + writeSetPayloadShapeHeader(writer, payloadShape); GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, memberShape, "input", (s) -> { - writeSetPayloadShapeHeader(writer, payloadShape); writer.addUseImports(SmithyGoDependency.BYTES); writer.write("payload := bytes.NewReader($L)", s); writeSetStream(writer, "payload"); }); } else if (payloadShape.isStringShape()) { + writeSetPayloadShapeHeader(writer, payloadShape); GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, memberShape, "input", (s) -> { - writeSetPayloadShapeHeader(writer, payloadShape); writer.addUseImports(SmithyGoDependency.STRINGS); if (payloadShape.hasTrait(EnumTrait.class)) { writer.write("payload := strings.NewReader(string($L))", s); @@ -744,6 +753,10 @@ private void writeHttpBindingSetter( operand = targetShape.hasTrait(EnumTrait.class) ? "string(" + operand + ")" : operand; locationEncoder.accept(writer, "String(" + operand + ")"); break; + case ENUM: + operand = "string(" + operand + ")"; + locationEncoder.accept(writer, "String(" + operand + ")"); + break; case TIMESTAMP: generateHttpBindingTimestampSerializer(model, writer, memberShape, location, operand, locationEncoder); break; @@ -1152,6 +1165,9 @@ private String generateHttpHeaderValue( return "string(b)"; } return operand; + case ENUM: + value = String.format("types.%s(%s)", targetShape.getId().getName(), operand); + return value; case BOOLEAN: writer.addUseImports(SmithyGoDependency.STRCONV); writer.write("vv, err := strconv.ParseBool($L)", operand); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java index b25147f6c..21c9ec6bf 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java @@ -184,12 +184,12 @@ private void generateOperationSerializer(GenerationContext context, OperationSha writer.openBlock("if err != nil {", "}", () -> { writer.write("return out, metadata, &smithy.SerializationError{Err: err}"); }); - writeRequestHeaders(context, operation, writer); - writer.write(""); - Optional inputInfo = EventStreamIndex.of(model).getInputInfo(operation); + Optional optionalEventStreamInfo = EventStreamIndex.of(model).getInputInfo(operation); // Skip and Handle Input Event Stream Setup Separately - if (inputInfo.isEmpty()) { + if (optionalEventStreamInfo.isEmpty()) { + writeRequestHeaders(context, operation, writer); + writer.write(""); // delegate the setup and usage of the document serializer function for the protocol serializeInputDocument(context, operation); // Skipping calling serializer method for the input shape is responsibility of the @@ -198,7 +198,8 @@ private void generateOperationSerializer(GenerationContext context, OperationSha serializingDocumentShapes.add(ProtocolUtils.expectInput(model, operation)); } } else { - writeOperationSerializerMiddlewareEventStreamSetup(context, inputInfo.get()); + writeDefaultHeaders(context, operation, writer); + writeOperationSerializerMiddlewareEventStreamSetup(context, optionalEventStreamInfo.get()); } writer.write(""); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java index b068a731b..4f35befde 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java @@ -201,7 +201,7 @@ private void writePaginator( """); pageSizeMember.ifPresent(memberShape -> { - if (pointableIndex.isPointable(model.expectShape(memberShape.getTarget()))) { + if (pointableIndex.isPointable(memberShape)) { writer.write(""" var limit $P if p.options.Limit > 0 { diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/auth/HttpBearerAuth.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/auth/HttpBearerAuth.java new file mode 100644 index 000000000..5ec697581 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/auth/HttpBearerAuth.java @@ -0,0 +1,194 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.integration.auth; + +import java.util.List; +import java.util.Map; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoDelegator; +import software.amazon.smithy.go.codegen.GoSettings; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.go.codegen.integration.ConfigField; +import software.amazon.smithy.go.codegen.integration.ConfigFieldResolver; +import software.amazon.smithy.go.codegen.integration.GoIntegration; +import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; +import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.ServiceIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.traits.HttpBearerAuthTrait; +import software.amazon.smithy.model.traits.OptionalAuthTrait; +import software.amazon.smithy.model.traits.Trait; +import software.amazon.smithy.utils.ListUtils; + +/** + * Integration to add support for httpBearerAuth authentication scheme to an API client. + */ +public class HttpBearerAuth implements GoIntegration { + + public static final String TOKEN_PROVIDER_OPTION_NAME = "BearerAuthTokenProvider"; + private static final String SIGNER_OPTION_NAME = "BearerAuthSigner"; + private static final String NEW_DEFAULT_SIGNER_NAME = "newDefault" + SIGNER_OPTION_NAME; + private static final String SIGNER_RESOLVER_NAME = "resolve" + SIGNER_OPTION_NAME; + private static final String REGISTER_MIDDLEWARE_NAME = "add" + SIGNER_OPTION_NAME + "Middleware"; + + @Override + public void writeAdditionalFiles( + GoSettings settings, + Model model, + SymbolProvider symbolProvider, + GoDelegator goDelegator + ) { + var service = settings.getService(model); + if (!isSupportedAuthentication(model, service)) { + return; + } + + goDelegator.useShapeWriter(service, (writer) -> { + writeMiddlewareRegister(writer); + writeSignerConfigFieldResolver(writer); + writeNewSignerFunc(writer); + }); + } + + private void writeMiddlewareRegister(GoWriter writer) { + writer.pushState(); + + writer.putContext("funcName", REGISTER_MIDDLEWARE_NAME); + writer.putContext("stack", SymbolUtils.createValueSymbolBuilder("Stack", + SmithyGoDependency.SMITHY_MIDDLEWARE).build()); + writer.putContext("addMiddleware", SymbolUtils.createValueSymbolBuilder("AddAuthenticationMiddleware", + SmithyGoDependency.SMITHY_AUTH_BEARER).build()); + writer.putContext("signerOption", SIGNER_OPTION_NAME); + writer.putContext("providerOption", TOKEN_PROVIDER_OPTION_NAME); + + writer.write(""" + func $funcName:L(stack *$stack:T, o Options) error { + return $addMiddleware:T(stack, o.$signerOption:L, o.$providerOption:L) + } + """); + + writer.popState(); + } + + private void writeSignerConfigFieldResolver(GoWriter writer) { + writer.pushState(); + + writer.putContext("funcName", SIGNER_RESOLVER_NAME); + writer.putContext("signerOption", SIGNER_OPTION_NAME); + writer.putContext("newDefaultSigner", NEW_DEFAULT_SIGNER_NAME); + + writer.write(""" + func $funcName:L(o *Options) { + if o.$signerOption:L != nil { + return + } + o.$signerOption:L = $newDefaultSigner:L(*o) + } + """); + + writer.popState(); + } + + private void writeNewSignerFunc(GoWriter writer) { + writer.pushState(); + + writer.putContext("funcName", NEW_DEFAULT_SIGNER_NAME); + writer.putContext("signerInterface", SymbolUtils.createValueSymbolBuilder("Signer", + SmithyGoDependency.SMITHY_AUTH_BEARER).build()); + + // TODO this is HTTP specific, should be based on protocol/transport of API. + writer.putContext("newDefaultSigner", SymbolUtils.createValueSymbolBuilder("NewSignHTTPSMessage", + SmithyGoDependency.SMITHY_AUTH_BEARER).build()); + + writer.write(""" + func $funcName:L(o Options) $signerInterface:T { + return $newDefaultSigner:T() + } + """); + + writer.popState(); + } + + @Override + public List getClientPlugins() { + return ListUtils.of( + RuntimeClientPlugin.builder() + .servicePredicate(HttpBearerAuth::isSupportedAuthentication) + .addConfigField(ConfigField.builder() + .name(TOKEN_PROVIDER_OPTION_NAME) + .type(SymbolUtils.createValueSymbolBuilder("TokenProvider", + SmithyGoDependency.SMITHY_AUTH_BEARER).build()) + .documentation("Bearer token value provider") + .build()) + .build(), + RuntimeClientPlugin.builder() + .servicePredicate(HttpBearerAuth::isSupportedAuthentication) + .addConfigField(ConfigField.builder() + .name(SIGNER_OPTION_NAME) + .type(SymbolUtils.createValueSymbolBuilder("Signer", + SmithyGoDependency.SMITHY_AUTH_BEARER).build()) + .documentation("Signer for authenticating requests with bearer auth") + .build()) + .addConfigFieldResolver(ConfigFieldResolver.builder() + .location(ConfigFieldResolver.Location.CLIENT) + .target(ConfigFieldResolver.Target.INITIALIZATION) + .resolver(SymbolUtils.createValueSymbolBuilder(SIGNER_RESOLVER_NAME).build()) + .build()) + .build(), + + // TODO this is incorrect for an API client/operation that supports multiple auth schemes. + RuntimeClientPlugin.builder() + .operationPredicate(HttpBearerAuth::hasBearerAuthScheme) + .registerMiddleware(MiddlewareRegistrar.builder() + .resolvedFunction(SymbolUtils.createValueSymbolBuilder( + REGISTER_MIDDLEWARE_NAME).build()) + .useClientOptions() + .build()) + .build() + ); + } + + /** + * Returns if the service has the httpBearerAuth trait. + * + * @param model model definition + * @param service service shape for the API + * @return if the httpBearerAuth trait is used by the service + */ + public static boolean isSupportedAuthentication(Model model, ServiceShape service) { + return ServiceIndex.of(model).getAuthSchemes(service).values().stream().anyMatch(trait -> trait.getClass() + .equals(HttpBearerAuthTrait.class)); + + } + + /** + * Returns if the service and operation support the httpBearerAuthTrait. + * + * @param model model definition + * @param service service shape for the API + * @param operation operation shape + * @return if the service and operation support the httpBearerAuthTrait + */ + public static boolean hasBearerAuthScheme(Model model, ServiceShape service, OperationShape operation) { + Map auth = ServiceIndex.of(model).getEffectiveAuthSchemes(service.getId(), operation.getId()); + return auth.containsKey(HttpBearerAuthTrait.ID) && !operation.hasTrait(OptionalAuthTrait.class); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/knowledge/GoPointableIndex.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/knowledge/GoPointableIndex.java index 83e1fcf6f..d934b89ff 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/knowledge/GoPointableIndex.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/knowledge/GoPointableIndex.java @@ -140,7 +140,7 @@ public static GoPointableIndex of(Model model) { } private boolean isMemberDereferencable(MemberShape member, Shape targetShape) { - return isShapeDereferencable(targetShape) && isMemberPointable(member, targetShape); + return !INHERENTLY_NONDEREFERENCABLE.contains(targetShape.getType()) && isMemberPointable(member, targetShape); } private boolean isMemberNillable(MemberShape member, Shape targetShape) { @@ -148,7 +148,17 @@ private boolean isMemberNillable(MemberShape member, Shape targetShape) { } private boolean isMemberPointable(MemberShape member, Shape targetShape) { - return isShapePointable(targetShape) && nullableIndex.isNullable(member); + + // Streamed blob shapes are never pointers because they are interfaces + if (isBlobStream(targetShape)) { + return false; + } + + if (INHERENTLY_VALUE.contains(targetShape.getType()) || isShapeEnum(targetShape)) { + return false; + } + + return nullableIndex.isMemberNullable(member, NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1_NO_INPUT); } private boolean isShapeDereferencable(Shape shape) { @@ -191,7 +201,9 @@ private boolean isShapePointable(Shape shape) { } private boolean isShapeEnum(Shape shape) { - return shape.getType() == ShapeType.STRING && shape.hasTrait(EnumTrait.class); + return shape.getType() == ShapeType.STRING && shape.hasTrait(EnumTrait.class) + || shape.getType() == ShapeType.ENUM + || shape.getType() == ShapeType.INT_ENUM; } private boolean isBlobStream(Shape shape) { diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/testutils/ExecuteCommand.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/testutils/ExecuteCommand.java new file mode 100644 index 000000000..f04760abe --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/testutils/ExecuteCommand.java @@ -0,0 +1,145 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.testutils; + +import java.io.BufferedReader; +import java.io.File; +import java.io.InputStreamReader; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Logger; +import software.amazon.smithy.utils.SmithyBuilder; + +/** + * Utility for invoking command line utilities. + */ +public final class ExecuteCommand { + private static final Logger LOGGER = Logger.getLogger(ExecuteCommand.class.getName()); + + private final List command; + private final File workingDir; + + private ExecuteCommand(Builder builder) { + command = SmithyBuilder.requiredState("command", builder.command); + workingDir = builder.workingDir; + } + + /** + * Invokes the command in the filepath directory provided. + * @param workingDir Directory to execute the command in. + * @param command Command to be executed. + * @throws Exception if the command fails. + */ + public static void execute(File workingDir, String... command) throws Exception { + ExecuteCommand.builder() + .addCommand(command) + .workingDir(workingDir) + .build() + .execute(); + } + + /** + * Invokes the command returning the exception if there was any. + * @throws Exception if the command fails. + */ + public void execute() throws Exception { + int exitCode; + Process child; + try { + var cmdArray = new String[command.size()]; + command.toArray(cmdArray); + + child = Runtime.getRuntime().exec(cmdArray, null, workingDir); + exitCode = child.waitFor(); + + BufferedReader stdOut = new BufferedReader(new + InputStreamReader(child.getInputStream(), Charset.defaultCharset())); + + BufferedReader stdErr = new BufferedReader(new + InputStreamReader(child.getErrorStream(), Charset.defaultCharset())); + + String s; + while ((s = stdOut.readLine()) != null) { + LOGGER.info(s); + } + stdOut.close(); + while ((s = stdErr.readLine()) != null) { + LOGGER.warning(s); + } + stdErr.close(); + } catch (Exception e) { + throw new Exception("Unable to execute command, " + command, e); + } + + if (exitCode != 0) { + throw new Exception("Command existed with non-zero code, " + command + + ", status code: " + exitCode); + } + } + + /** + * Returns the builder for ExecuteCommand. + * @return ExecuteCommand builder. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * ExecuteCommand builder. + */ + public static final class Builder implements SmithyBuilder { + private List command; + private File workingDir; + + private Builder() { + } + + /** + * Adds command arguments to the set of arguments the command will be executed with. + * @param command command and arguments to be executed. + * @return builder + */ + public Builder addCommand(String... command) { + if (this.command == null) { + this.command = new ArrayList<>(); + } + + this.command.addAll(List.of(command)); + return this; + } + + /** + * Sets the working directory for the command. + * @param workingDir working directory. + * @return builder + */ + public Builder workingDir(File workingDir) { + this.workingDir = workingDir; + return this; + } + + /** + * Builds the ExecuteCommand. + * @return Execute command + */ + @Override + public ExecuteCommand build() { + return new ExecuteCommand(this); + } + } +} diff --git a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration index c9c7fbc6c..f493b3283 100644 --- a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration +++ b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration @@ -1,3 +1,4 @@ +software.amazon.smithy.go.codegen.integration.auth.HttpBearerAuth software.amazon.smithy.go.codegen.integration.ValidationGenerator software.amazon.smithy.go.codegen.integration.IdempotencyTokenMiddlewareGenerator software.amazon.smithy.go.codegen.integration.AddChecksumRequiredMiddleware diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/enum-shape-test/enum-shape-test.smithy b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/enum-shape-test/enum-shape-test.smithy new file mode 100644 index 000000000..47d5bf51c --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/enum-shape-test/enum-shape-test.smithy @@ -0,0 +1,26 @@ +$version: "2.0" + +namespace smithy.example + +service Example { + version: "1.0.0" + operations: [ + ChangeCard + ] +} + +operation ChangeCard { + input: Card + output: Card +} + +structure Card { + suit: Suit +} + +enum Suit { + DIAMOND + CLUB + HEART + SPADE +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/enum-shape-test/expected/types/enums.go b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/enum-shape-test/expected/types/enums.go new file mode 100644 index 000000000..a0d715126 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/enum-shape-test/expected/types/enums.go @@ -0,0 +1,26 @@ +// Code generated by smithy-go-codegen DO NOT EDIT. + + +package types + +type Suit string + +// Enum values for Suit +const ( + SuitDiamond Suit = "DIAMOND" + SuitClub Suit = "CLUB" + SuitHeart Suit = "HEART" + SuitSpade Suit = "SPADE" +) + +// Values returns all known values for Suit. Note that this can be expanded in the +// future, and so it is only as up to date as the client. The ordering of this +// slice is not guaranteed to be stable across updates. +func (Suit) Values() []Suit { + return []Suit{ + "DIAMOND", + "CLUB", + "HEART", + "SPADE", + } +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/changeCardInput.go.struct b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/changeCardInput.go.struct new file mode 100644 index 000000000..7c2b51eab --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/changeCardInput.go.struct @@ -0,0 +1,8 @@ +type ChangeCardInput struct { + + Number types.Number + + Suit types.Suit + + noSmithyDocumentSerde +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/changeCardOutput.go.struct b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/changeCardOutput.go.struct new file mode 100644 index 000000000..de2f82bc1 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/changeCardOutput.go.struct @@ -0,0 +1,11 @@ +type ChangeCardOutput struct { + + Number types.Number + + Suit types.Suit + + // Metadata pertaining to the operation's result. + ResultMetadata middleware.Metadata + + noSmithyDocumentSerde +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/types/enums.go b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/types/enums.go new file mode 100644 index 000000000..038302c35 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/expected/types/enums.go @@ -0,0 +1,66 @@ +// Code generated by smithy-go-codegen DO NOT EDIT. + + +package types + +type Number int32 + +// Enum values for Number +const ( + NumberAce Number = 1 + NumberTwo Number = 2 + NumberThree Number = 3 + NumberFour Number = 4 + NumberFive Number = 5 + NumberSix Number = 6 + NumberSeven Number = 7 + NumberEight Number = 8 + NumberNine Number = 9 + NumberTen Number = 10 + NumberJack Number = 11 + NumberQueen Number = 12 + NumberKing Number = 13 +) + +// Values returns all known values for Number. Note that this can be expanded in +// the future, and so it is only as up to date as the client. The ordering of this +// slice is not guaranteed to be stable across updates. +func (Number) Values() []Number { + return []Number{ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + } +} + +type Suit string + +// Enum values for Suit +const ( + SuitDiamond Suit = "DIAMOND" + SuitClub Suit = "CLUB" + SuitHeart Suit = "HEART" + SuitSpade Suit = "SPADE" +) + +// Values returns all known values for Suit. Note that this can be expanded in the +// future, and so it is only as up to date as the client. The ordering of this +// slice is not guaranteed to be stable across updates. +func (Suit) Values() []Suit { + return []Suit{ + "DIAMOND", + "CLUB", + "HEART", + "SPADE", + } +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/int-enum-shape-test.smithy b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/int-enum-shape-test.smithy new file mode 100644 index 000000000..591923205 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/int-enum-shape-test/int-enum-shape-test.smithy @@ -0,0 +1,43 @@ +$version: "2.0" + +namespace smithy.example + +service Example { + version: "1.0.0" + operations: [ + ChangeCard + ] +} + +operation ChangeCard { + input: Card + output: Card +} + +structure Card { + suit: Suit + number: Number +} + +enum Suit { + DIAMOND + CLUB + HEART + SPADE +} + +intEnum Number { + ACE = 1 + TWO = 2 + THREE = 3 + FOUR = 4 + FIVE = 5 + SIX = 6 + SEVEN = 7 + EIGHT = 8 + NINE = 9 + TEN = 10 + JACK = 11 + QUEEN = 12 + KING = 13 +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/expected/changeCardInput.go.mixin b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/expected/changeCardInput.go.mixin new file mode 100644 index 000000000..f4b675333 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/expected/changeCardInput.go.mixin @@ -0,0 +1,8 @@ +type ChangeCardInput struct { + + Number *int32 + + Suit *string + + noSmithyDocumentSerde +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/expected/changeCardOutput.go.mixin b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/expected/changeCardOutput.go.mixin new file mode 100644 index 000000000..4943789c7 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/expected/changeCardOutput.go.mixin @@ -0,0 +1,11 @@ +type ChangeCardOutput struct { + + Number *int32 + + Suit *string + + // Metadata pertaining to the operation's result. + ResultMetadata middleware.Metadata + + noSmithyDocumentSerde +} diff --git a/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/mixin-test.smithy b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/mixin-test.smithy new file mode 100644 index 000000000..000e75a2c --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/resources/software/amazon/smithy/go/codegen/smithy-tests/mixin-test/mixin-test.smithy @@ -0,0 +1,23 @@ +$version: "2.0" + +namespace smithy.example + +service Example { + version: "1.0.0" + operations: [ + ChangeCard + ] +} + +operation ChangeCard { + input: Card + output: Card +} + +@mixin +structure CardValuesMixin { + suit: String + number: Integer +} + +structure Card with [CardValuesMixin] {} diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/EnumShapeGeneratorTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/EnumShapeGeneratorTest.java new file mode 100644 index 000000000..cb1ddb3cb --- /dev/null +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/EnumShapeGeneratorTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen; + +import java.util.logging.Logger; +import org.junit.jupiter.api.Test; + +import software.amazon.smithy.build.MockManifest; +import software.amazon.smithy.build.PluginContext; +import software.amazon.smithy.go.codegen.GoCodegenPlugin; +import software.amazon.smithy.model.Model; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +import static software.amazon.smithy.go.codegen.TestUtils.buildMockPluginContext; +import static software.amazon.smithy.go.codegen.TestUtils.loadSmithyModelFromResource; +import static software.amazon.smithy.go.codegen.TestUtils.loadExpectedFileStringFromResource; + + +public class EnumShapeGeneratorTest { + private static final Logger LOGGER = Logger.getLogger(EnumShapeGeneratorTest.class.getName()); + + @Test + public void testEnumShapeTest() { + + // Arrange + Model model = + loadSmithyModelFromResource("enum-shape-test"); + MockManifest manifest = + new MockManifest(); + PluginContext context = + buildMockPluginContext(model, manifest, "smithy.example#Example"); + + // Act + (new GoCodegenPlugin()).execute(context); + + // Assert + String actualSuitEnumShapeCode = + manifest.getFileString("types/enums.go").get(); + String expectedSuitEnumShapeCode = + loadExpectedFileStringFromResource("enum-shape-test", "types/enums.go"); + assertThat("enum shape actual generated code is equal to the expected generated code", + actualSuitEnumShapeCode, + is(expectedSuitEnumShapeCode)); + + } + +} diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GenerateStandaloneGoModuleTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GenerateStandaloneGoModuleTest.java new file mode 100644 index 000000000..0cacc390c --- /dev/null +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GenerateStandaloneGoModuleTest.java @@ -0,0 +1,116 @@ +package software.amazon.smithy.go.codegen; + +import static software.amazon.smithy.go.codegen.GoWriter.goBlockTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.TestUtils.hasGoInstalled; +import static software.amazon.smithy.go.codegen.TestUtils.makeGoModule; +import static software.amazon.smithy.go.codegen.TestUtils.testGoModule; + +import java.nio.file.Path; +import java.util.Map; +import java.util.logging.Logger; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.build.FileManifest; +import software.amazon.smithy.utils.MapUtils; + +public class GenerateStandaloneGoModuleTest { + private static final Logger LOGGER = Logger.getLogger(GenerateStandaloneGoModuleTest.class.getName()); + + @Test + public void testGenerateGoModule() throws Exception { + if (!hasGoInstalled()) { + LOGGER.warning("Skipping testGenerateGoModule, go command cannot be executed."); + return; + } + + var testPath = getTestOutputDir(); + LOGGER.warning("generating test suites into " + testPath); + + var fileManifest = FileManifest.create(testPath); + var writers = new GoWriterDelegator(fileManifest); + + writers.useFileWriter("test-directory/package-name/gofile.go", + "github.com/aws/smithy-go/internal/testmodule/packagename", + (w) -> { + w.writeGoTemplate(""" + type $name:L struct { + Bar $barType:T + Baz *$bazType:T + } + + $somethingElse:W + """, + MapUtils.of( + "name", "Foo", + "barType", SymbolUtils.createValueSymbolBuilder("string").build(), + "bazType", SymbolUtils.createValueSymbolBuilder("Request", + SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build(), + "somethingElse", generateSomethingElse() + )); + }); + + writers.useFileWriter("test-directory/package-name/gofile_test.go", + "github.com/aws/smithy-go/internal/testmodule/packagename", + (w) -> { + Map commonArgs = MapUtils.of( + "testingT", SymbolUtils.createValueSymbolBuilder("T", SmithyGoDependency.TESTING).build() + ); + + w.writeGoTemplate(""" + func Test$name:L(t *$testingT:T) { + v := $name:L{} + v.Baz = nil + } + """, + commonArgs, + MapUtils.of( + "name", "Foo" + )); + w.writeGoBlockTemplate("func TestBar(t *$testingT:T) {", "}", + commonArgs, + (ww) -> { + ww.write("t.Skip(\"not relevant\")"); + + }); + }); + + var dependencies = writers.getDependencies(); + writers.flushWriters(); + + ManifestWriter.builder() + .moduleName("github.com/aws/smithy-go/internal/testmodule") + .fileManifest(fileManifest) + .dependencies(dependencies) + .build() + .writeManifest(); + + makeGoModule(testPath); + testGoModule(testPath); + } + + private GoWriter.Writable generateSomethingElse() { + return goBlockTemplate("func (s *$name:L) $funcName:L(i int) string {", "}", + MapUtils.of("funcName", "SomethingElse"), + MapUtils.of( + "name", "Foo" + ), + (w) -> { + w.write("return \"hello!\""); + }); + } + + private static Path getTestOutputDir() { + var testWorkspace = System.getenv("SMITHY_GO_TEST_WORKSPACE"); + if (testWorkspace != null) { + return Path.of(testWorkspace).toAbsolutePath(); + } + + return Path.of(System.getProperty("user.dir")) + .resolve("build") + .resolve("test-generated") + .resolve("go") + .resolve("internal") + .resolve("testmodule") + .toAbsolutePath(); + } +} diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/IntEnumShapeGeneratorTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/IntEnumShapeGeneratorTest.java new file mode 100644 index 000000000..89e79e4d6 --- /dev/null +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/IntEnumShapeGeneratorTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen; + +import java.util.logging.Logger; +import org.junit.jupiter.api.Test; + +import software.amazon.smithy.build.MockManifest; +import software.amazon.smithy.build.PluginContext; +import software.amazon.smithy.go.codegen.GoCodegenPlugin; +import software.amazon.smithy.model.Model; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +import static software.amazon.smithy.go.codegen.TestUtils.buildMockPluginContext; +import static software.amazon.smithy.go.codegen.TestUtils.loadSmithyModelFromResource; +import static software.amazon.smithy.go.codegen.TestUtils.loadExpectedFileStringFromResource; + + +public class IntEnumShapeGeneratorTest { + private static final Logger LOGGER = Logger.getLogger(IntEnumShapeGeneratorTest.class.getName()); + + @Test + public void testIntEnumShapeTest() { + + // Arrange + Model model = + loadSmithyModelFromResource("int-enum-shape-test"); + MockManifest manifest = + new MockManifest(); + PluginContext context = + buildMockPluginContext(model, manifest, "smithy.example#Example"); + + // Act + (new GoCodegenPlugin()).execute(context); + + // Assert + String actualEnumShapeCode = + manifest.getFileString("types/enums.go").get(); + String expectedEnumShapeCode = + loadExpectedFileStringFromResource("int-enum-shape-test", "types/enums.go"); + assertThat("intEnum shape actual generated code is equal to the expected generated code", + actualEnumShapeCode, + is(expectedEnumShapeCode)); + String actualChangeCardOperationCode = + manifest.getFileString("api_op_ChangeCard.go").get(); + String expectedChangeCardInputCode = + loadExpectedFileStringFromResource("int-enum-shape-test", "changeCardInput.go.struct"); + assertThat("intEnum shape properly referenced in generated input structure code", + actualChangeCardOperationCode, + containsString(expectedChangeCardInputCode)); + String expectedChangeCardOutputCode = + loadExpectedFileStringFromResource("int-enum-shape-test", "changeCardOutput.go.struct"); + assertThat("intEnum shape properly referenced in generated output structure code", + actualChangeCardOperationCode, + containsString(expectedChangeCardOutputCode)); + + } + +} diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/MixinCodegenTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/MixinCodegenTest.java new file mode 100644 index 000000000..9e2867cd8 --- /dev/null +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/MixinCodegenTest.java @@ -0,0 +1,67 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen; + +import java.util.logging.Logger; +import org.junit.jupiter.api.Test; + +import software.amazon.smithy.build.MockManifest; +import software.amazon.smithy.build.PluginContext; +import software.amazon.smithy.go.codegen.GoCodegenPlugin; +import software.amazon.smithy.model.Model; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.MatcherAssert.assertThat; + +import static software.amazon.smithy.go.codegen.TestUtils.buildMockPluginContext; +import static software.amazon.smithy.go.codegen.TestUtils.loadSmithyModelFromResource; +import static software.amazon.smithy.go.codegen.TestUtils.loadExpectedFileStringFromResource; + + +public class MixinCodegenTest { + private static final Logger LOGGER = Logger.getLogger(MixinCodegenTest.class.getName()); + + @Test + public void testMixinCodegen() { + + // Arrange + Model model = + loadSmithyModelFromResource("mixin-test"); + MockManifest manifest = + new MockManifest(); + PluginContext context = + buildMockPluginContext(model, manifest, "smithy.example#Example"); + + // Act + (new GoCodegenPlugin()).execute(context); + + // Assert + String actualChangeCardOperationCode = + manifest.getFileString("api_op_ChangeCard.go").get(); + String expectedInputMixinCode = + loadExpectedFileStringFromResource("mixin-test", "changeCardInput.go.mixin"); + assertThat("mixins are properly applied in the input structure", + actualChangeCardOperationCode, + containsString(expectedInputMixinCode)); + String expectedOutputMixinCode = + loadExpectedFileStringFromResource("mixin-test", "changeCardOutput.go.mixin"); + assertThat("mixins are properly applied in the output structure", + actualChangeCardOperationCode, + containsString(expectedOutputMixinCode)); + + } + +} diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/TestUtils.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/TestUtils.java new file mode 100644 index 000000000..dcf521121 --- /dev/null +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/TestUtils.java @@ -0,0 +1,160 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen; + +import static software.amazon.smithy.go.codegen.testutils.ExecuteCommand.execute; + +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import software.amazon.smithy.build.FileManifest; +import software.amazon.smithy.build.PluginContext; +import software.amazon.smithy.go.codegen.testutils.ExecuteCommand; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.node.ObjectNode; + + +public class TestUtils { + + public static final String SMITHY_TESTS_PREFIX = "smithy-tests"; + public static final String SMITHY_TESTS_EXPECTED_PREFIX = "expected"; + + public static Model loadSmithyModelFromResource(String testPath) { + String resourcePath = + SMITHY_TESTS_PREFIX + "/" + testPath + "/" + testPath + ".smithy"; + return Model.assembler() + .addImport(TestUtils.class.getResource(resourcePath)) + .discoverModels() + .assemble() + .unwrap(); + } + + public static String loadExpectedFileStringFromResource(String testPath, String filePath) { + String resourcePath = + SMITHY_TESTS_PREFIX + "/" + testPath + "/" + SMITHY_TESTS_EXPECTED_PREFIX + "/" + filePath; + return getResourceAsString(resourcePath); + } + + public static String getResourceAsString(String resourcePath) { + try { + return Files.readString( + Paths.get(TestUtils.class.getResource(resourcePath).toURI()), + Charset.forName("utf-8")); + } catch (Exception e) { + return null; + } + } + + public static PluginContext buildMockPluginContext( + Model model, + FileManifest manifest, + String serviceShapeId + ) { + return buildPluginContext( + model, + manifest, + serviceShapeId, + "example", + "0.0.1"); + } + + public static PluginContext buildPluginContext( + Model model, + FileManifest manifest, + String serviceShapeId, + String moduleName, + String moduleVersion + ) { + return PluginContext.builder() + .model(model) + .fileManifest(manifest) + .settings(getSettingsNode( + serviceShapeId, + moduleName, + moduleVersion, + "Example")) + .build(); + } + + public static ObjectNode getSettingsNode( + String serviceShapeId, + String moduleName, + String moduleVersion, + String sdkId + ) { + return Node.objectNodeBuilder() + .withMember("service", Node.from(serviceShapeId)) + .withMember("module", Node.from(moduleName)) + .withMember("moduleVersion", Node.from(moduleVersion)) + .withMember("homepage", Node.from("https://docs.amplify.aws/")) + .withMember("sdkId", Node.from(sdkId)) + .withMember("author", Node.from("Amazon Web Services")) + .withMember("gitRepo", Node.from("https://github.com/aws-amplify/amplify-codegen.git")) + .withMember("swiftVersion", Node.from("5.5.0")) + .build(); + } + + /** + * Returns the path for the repository's root relative to the smithy-go-codegen module. + * @return repo root. + */ + public static Path getRepoRootDir() { + return Path.of(System.getProperty("user.dir")) + .resolve("..") + .resolve("..") + .toAbsolutePath(); + } + + /** + * Returns true if the go command can be executed. + * @return go command can be executed. + */ + public static boolean hasGoInstalled() { + try{ + ExecuteCommand.builder() + .addCommand("go", "version") + .build() + .execute(); + } catch (Exception e) { + return false; + } + + return true; + } + + public static void makeGoModule(Path path) throws Exception { + var repoRoot = getRepoRootDir(); + execute(repoRoot.toFile(), + "go", "run", + "github.com/awslabs/aws-go-multi-module-repository-tools/cmd/gomodgen@latest", + "--build", path.resolve("..").toAbsolutePath().toString(), + "--plugin-dir", path.getFileName().toString(), + "--copy-artifact=false", + "--prepare-target-dir=false" + ); + + execute(path.toFile(), "go", "mod", "edit", "--replace", "github.com/aws/smithy-go="+repoRoot); + execute(path.toFile(), "go", "mod", "tidy"); + execute(path.toFile(), "gofmt", "-w", "-s", "."); + } + + public static void testGoModule(Path path) throws Exception { + execute(path.toFile(), "go", "test", "-v", "./..."); + } +} \ No newline at end of file diff --git a/context/suppress_expired.go b/context/suppress_expired.go new file mode 100644 index 000000000..a39b84a27 --- /dev/null +++ b/context/suppress_expired.go @@ -0,0 +1,81 @@ +package context + +import "context" + +// valueOnlyContext provides a utility to preserve only the values of a +// Context. Suppressing any cancellation or deadline on that context being +// propagated downstream of this value. +// +// If preserveExpiredValues is false (default), and the valueCtx is canceled, +// calls to lookup values with the Values method, will always return nil. Setting +// preserveExpiredValues to true, will allow the valueOnlyContext to lookup +// values in valueCtx even if valueCtx is canceled. +// +// Based on the Go standard libraries net/lookup.go onlyValuesCtx utility. +// https://github.com/golang/go/blob/da2773fe3e2f6106634673a38dc3a6eb875fe7d8/src/net/lookup.go +type valueOnlyContext struct { + context.Context + + preserveExpiredValues bool + valuesCtx context.Context +} + +var _ context.Context = (*valueOnlyContext)(nil) + +// Value looks up the key, returning its value. If configured to not preserve +// values of expired context, and the wrapping context is canceled, nil will be +// returned. +func (v *valueOnlyContext) Value(key interface{}) interface{} { + if !v.preserveExpiredValues { + select { + case <-v.valuesCtx.Done(): + return nil + default: + } + } + + return v.valuesCtx.Value(key) +} + +// WithSuppressCancel wraps the Context value, suppressing its deadline and +// cancellation events being propagated downstream to consumer of the returned +// context. +// +// By default the wrapped Context's Values are available downstream until the +// wrapped Context is canceled. Once the wrapped Context is canceled, Values +// method called on the context return will no longer lookup any key. As they +// are now considered expired. +// +// To override this behavior, use WithPreserveExpiredValues on the Context +// before it is wrapped by WithSuppressCancel. This will make the Context +// returned by WithSuppressCancel allow lookup of expired values. +func WithSuppressCancel(ctx context.Context) context.Context { + return &valueOnlyContext{ + Context: context.Background(), + valuesCtx: ctx, + + preserveExpiredValues: GetPreserveExpiredValues(ctx), + } +} + +type preserveExpiredValuesKey struct{} + +// WithPreserveExpiredValues adds a Value to the Context if expired values +// should be preserved, and looked up by a Context wrapped by +// WithSuppressCancel. +// +// WithPreserveExpiredValues must be added as a value to a Context, before that +// Context is wrapped by WithSuppressCancel +func WithPreserveExpiredValues(ctx context.Context, enable bool) context.Context { + return context.WithValue(ctx, preserveExpiredValuesKey{}, enable) +} + +// GetPreserveExpiredValues looks up, and returns the PreserveExpressValues +// value in the context. Returning true if enabled, false otherwise. +func GetPreserveExpiredValues(ctx context.Context) bool { + v := ctx.Value(preserveExpiredValuesKey{}) + if v != nil { + return v.(bool) + } + return false +} diff --git a/encoding/httpbinding/uri.go b/encoding/httpbinding/uri.go index 64e40121e..f04e11984 100644 --- a/encoding/httpbinding/uri.go +++ b/encoding/httpbinding/uri.go @@ -20,6 +20,9 @@ func newURIValue(path *[]byte, rawPath *[]byte, buffer *[]byte, key string) URIV func (u URIValue) modifyURI(value string) (err error) { *u.path, *u.buffer, err = replacePathElement(*u.path, *u.buffer, u.key, value, false) + if err != nil { + return err + } *u.rawPath, *u.buffer, err = replacePathElement(*u.rawPath, *u.buffer, u.key, value, true) return err } diff --git a/encoding/json/escape_test.go b/encoding/json/escape_test.go new file mode 100644 index 000000000..c3a07a126 --- /dev/null +++ b/encoding/json/escape_test.go @@ -0,0 +1,49 @@ +package json + +import ( + "bytes" + "testing" +) + +func TestEscapeStringBytes(t *testing.T) { + cases := map[string]struct { + expected string + input []byte + }{ + "safeSet only": { + expected: `"mountainPotato"`, + input: []byte("mountainPotato"), + }, + "parenthesis": { + expected: `"foo\""`, + input: []byte(`foo"`), + }, + "double escape": { + expected: `"hello\\\\world"`, + input: []byte(`hello\\world`), + }, + "new line": { + expected: `"foo\nbar"`, + input: []byte("foo\nbar"), + }, + "carriage return": { + expected: `"foo\rbar"`, + input: []byte("foo\rbar"), + }, + "tab": { + expected: `"foo\tbar"`, + input: []byte("foo\tbar"), + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var buffer bytes.Buffer + escapeStringBytes(&buffer, c.input) + expected := c.expected + actual := buffer.String() + if expected != actual { + t.Errorf("\nexpected %v \nactual %v", expected, actual) + } + }) + } +} diff --git a/encoding/json/object.go b/encoding/json/object.go index 15fb6478c..722346d03 100644 --- a/encoding/json/object.go +++ b/encoding/json/object.go @@ -17,9 +17,7 @@ func newObject(w *bytes.Buffer, scratch *[]byte) *Object { } func (o *Object) writeKey(key string) { - o.w.WriteRune(quote) - o.w.Write([]byte(key)) - o.w.WriteRune(quote) + escapeStringBytes(o.w, []byte(key)) o.w.WriteRune(colon) } diff --git a/encoding/json/object_test.go b/encoding/json/object_test.go index 5f09fb1ae..1e0830e05 100644 --- a/encoding/json/object_test.go +++ b/encoding/json/object_test.go @@ -7,9 +7,9 @@ import ( func TestObject(t *testing.T) { buffer := bytes.NewBuffer(nil) - scatch := make([]byte, 64) + scratch := make([]byte, 64) - object := newObject(buffer, &scatch) + object := newObject(buffer, &scratch) object.Key("foo").String("bar") object.Key("faz").String("baz") object.Close() @@ -19,3 +19,16 @@ func TestObject(t *testing.T) { t.Errorf("expected %+q, but got %+q", e, a) } } + +func TestObjectKey_escaped(t *testing.T) { + jsonEncoder := NewEncoder() + object := jsonEncoder.Object() + object.Key("foo\"").String("bar") + object.Key("faz").String("baz") + object.Close() + + e := []byte(`{"foo\"":"bar","faz":"baz"}`) + if a := object.w.Bytes(); bytes.Compare(e, a) != 0 { + t.Errorf("expected %+q, but got %+q", e, a) + } +} diff --git a/go.mod b/go.mod index 67d6704b1..d163d76ea 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/aws/smithy-go go 1.15 require ( - github.com/google/go-cmp v0.5.7 + github.com/google/go-cmp v0.5.8 github.com/jmespath/go-jmespath v0.4.0 ) diff --git a/go.sum b/go.sum index d51e68e78..b03fbb7a3 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= -github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -9,8 +9,6 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= diff --git a/go_module_metadata.go b/go_module_metadata.go index 8bf584b66..08db245f8 100644 --- a/go_module_metadata.go +++ b/go_module_metadata.go @@ -3,4 +3,4 @@ package smithy // goModuleVersion is the tagged release for this module -const goModuleVersion = "1.11.0" +const goModuleVersion = "1.13.3" diff --git a/internal/sync/singleflight/LICENSE b/internal/sync/singleflight/LICENSE new file mode 100644 index 000000000..fe6a62006 --- /dev/null +++ b/internal/sync/singleflight/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/internal/sync/singleflight/docs.go b/internal/sync/singleflight/docs.go new file mode 100644 index 000000000..9c9d02b94 --- /dev/null +++ b/internal/sync/singleflight/docs.go @@ -0,0 +1,8 @@ +// Package singleflight provides a duplicate function call suppression +// mechanism. This package is a fork of the Go golang.org/x/sync/singleflight +// package. The package is forked, because the package a part of the unstable +// and unversioned golang.org/x/sync module. +// +// https://github.com/golang/sync/tree/67f06af15bc961c363a7260195bcd53487529a21/singleflight + +package singleflight diff --git a/internal/sync/singleflight/singleflight.go b/internal/sync/singleflight/singleflight.go new file mode 100644 index 000000000..e8a1b17d5 --- /dev/null +++ b/internal/sync/singleflight/singleflight.go @@ -0,0 +1,210 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value interface{} + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v interface{}) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val interface{} + err error + + // forgotten indicates whether Forget was called with this call's key + // while the call was still in flight. + forgotten bool + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result struct { + Val interface{} + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err, true + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +// +// The returned channel will not be closed. +func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result { + ch := make(chan Result, 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call{chans: []chan<- Result{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. +func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + c.wg.Done() + g.mu.Lock() + defer g.mu.Unlock() + if !c.forgotten { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + // In order to prevent the waiting channels from being blocked forever, + // needs to ensure that this panic cannot be recovered. + if len(c.chans) > 0 { + go panic(e) + select {} // Keep this goroutine around so that it will appear in the crash dump. + } else { + panic(e) + } + } else if c.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + for _, ch := range c.chans { + ch <- Result{c.val, c.err, c.dups > 0} + } + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. +func (g *Group) Forget(key string) { + g.mu.Lock() + if c, ok := g.m[key]; ok { + c.forgotten = true + } + delete(g.m, key) + g.mu.Unlock() +} diff --git a/internal/sync/singleflight/singleflight_test.go b/internal/sync/singleflight/singleflight_test.go new file mode 100644 index 000000000..3e51203bd --- /dev/null +++ b/internal/sync/singleflight/singleflight_test.go @@ -0,0 +1,320 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "runtime/debug" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var g Group + v, err, _ := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + var g Group + someErr := errors.New("Some error") + v, err, _ := g.Do("key", func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + var g Group + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (interface{}, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err, _ := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if s, _ := v.(string); s != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +// Test that singleflight behaves correctly after Forget called. +// See https://github.com/golang/go/issues/31420 +func TestForget(t *testing.T) { + var g Group + + var ( + firstStarted = make(chan struct{}) + unblockFirst = make(chan struct{}) + firstFinished = make(chan struct{}) + ) + + go func() { + g.Do("key", func() (i interface{}, e error) { + close(firstStarted) + <-unblockFirst + close(firstFinished) + return + }) + }() + <-firstStarted + g.Forget("key") + + unblockSecond := make(chan struct{}) + secondResult := g.DoChan("key", func() (i interface{}, e error) { + <-unblockSecond + return 2, nil + }) + + close(unblockFirst) + <-firstFinished + + thirdResult := g.DoChan("key", func() (i interface{}, e error) { + return 3, nil + }) + + close(unblockSecond) + <-secondResult + r := <-thirdResult + if r.Val != 2 { + t.Errorf("We should receive result produced by second call, expected: 2, got %d", r.Val) + } +} + +func TestDoChan(t *testing.T) { + var g Group + ch := g.DoChan("key", func() (interface{}, error) { + return "bar", nil + }) + + res := <-ch + v := res.Val + err := res.Err + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +// Test singleflight behaves correctly after Do panic. +// See https://github.com/golang/go/issues/41133 +func TestPanicDo(t *testing.T) { + var g Group + fn := func() (interface{}, error) { + panic("invalid memory address or nil pointer dereference") + } + + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + t.Logf("Got panic: %v\n%s", err, debug.Stack()) + atomic.AddInt32(&panicCount, 1) + } + + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + + g.Do("key", fn) + }() + } + + select { + case <-done: + if panicCount != n { + t.Errorf("Expect %d panic, but got %d", n, panicCount) + } + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestGoexitDo(t *testing.T) { + var g Group + fn := func() (interface{}, error) { + runtime.Goexit() + return nil, nil + } + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, err, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestPanicDoChan(t *testing.T) { + if runtime.GOOS == "js" { + t.Skipf("js does not support exec") + } + + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + defer func() { + recover() + }() + + g := new(Group) + ch := g.DoChan("", func() (interface{}, error) { + panic("Panicking in DoChan") + }) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in DoChan")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in DoChan") + } +} + +func TestPanicDoSharedByDoChan(t *testing.T) { + if runtime.GOOS == "js" { + t.Skipf("js does not support exec") + } + + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + blocked := make(chan struct{}) + unblock := make(chan struct{}) + + g := new(Group) + go func() { + defer func() { + recover() + }() + g.Do("", func() (interface{}, error) { + close(blocked) + <-unblock + panic("Panicking in Do") + }) + }() + + <-blocked + ch := g.DoChan("", func() (interface{}, error) { + panic("DoChan unexpectedly executed callback") + }) + close(unblock) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in Do shared by DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") + } +} diff --git a/modman.toml b/modman.toml new file mode 100644 index 000000000..20295cdd2 --- /dev/null +++ b/modman.toml @@ -0,0 +1,11 @@ +[dependencies] + "github.com/google/go-cmp" = "v0.5.8" + "github.com/jmespath/go-jmespath" = "v0.4.0" + +[modules] + + [modules.codegen] + no_tag = true + + [modules."codegen/smithy-go-codegen/build/test-generated/go/internal/testmodule"] + no_tag = true diff --git a/transport/http/checksum_middleware.go b/transport/http/checksum_middleware.go index 2ec7cbaee..bc4ad6e79 100644 --- a/transport/http/checksum_middleware.go +++ b/transport/http/checksum_middleware.go @@ -45,6 +45,11 @@ func (m *contentMD5Checksum) HandleBuild( stream := req.GetStream() // compute checksum if payload is explicit if stream != nil { + if !req.IsStreamSeekable() { + return out, metadata, fmt.Errorf( + "unseekable stream is not supported for computing md5 checksum") + } + v, err := computeMD5Checksum(stream) if err != nil { return out, metadata, fmt.Errorf("error computing md5 checksum, %w", err) diff --git a/transport/http/checksum_middleware_test.go b/transport/http/checksum_middleware_test.go index 6a6dfc89b..a911d68d9 100644 --- a/transport/http/checksum_middleware_test.go +++ b/transport/http/checksum_middleware_test.go @@ -35,7 +35,7 @@ func TestChecksumMiddleware(t *testing.T) { "nil body": {}, "unseekable payload": { payload: bytes.NewBuffer([]byte(`xyz`)), - expectError: "error rewinding request stream", + expectError: "unseekable stream is not supported", }, } @@ -61,6 +61,7 @@ func TestChecksumMiddleware(t *testing.T) { if e, a := c.expectError, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect error to contain %q, got %v", e, a) } + return } else if err != nil { t.Fatalf("expect no error, got %v", err) } diff --git a/transport/http/middleware_content_length.go b/transport/http/middleware_content_length.go index fa2c82755..9969389bb 100644 --- a/transport/http/middleware_content_length.go +++ b/transport/http/middleware_content_length.go @@ -44,12 +44,6 @@ func (m *ComputeContentLength) HandleBuild( "failed getting length of request stream, %w", err) } else if ok { req.ContentLength = n - if n == 0 { - // If the content length could be determined, and the body is empty - // the stream must be cleared to prevent unexpected chunk encoding. - req, _ = req.SetStream(nil) - in.Request = req - } } return next.HandleBuild(ctx, in) diff --git a/transport/http/middleware_content_length_test.go b/transport/http/middleware_content_length_test.go index cd27849d1..16a1f265c 100644 --- a/transport/http/middleware_content_length_test.go +++ b/transport/http/middleware_content_length_test.go @@ -13,38 +13,51 @@ import ( func TestContentLengthMiddleware(t *testing.T) { cases := map[string]struct { - Stream io.Reader - ExpectLen int64 - ExpectErr string + Stream io.Reader + ExpectNilStream bool + ExpectLen int64 + ExpectErr string }{ // Cases "bytes.Reader": { - Stream: bytes.NewReader(make([]byte, 10)), - ExpectLen: 10, + Stream: bytes.NewReader(make([]byte, 10)), + ExpectLen: 10, + ExpectNilStream: false, }, "bytes.Buffer": { - Stream: bytes.NewBuffer(make([]byte, 10)), - ExpectLen: 10, + Stream: bytes.NewBuffer(make([]byte, 10)), + ExpectLen: 10, + ExpectNilStream: false, }, "strings.Reader": { - Stream: strings.NewReader("hello"), - ExpectLen: 5, + Stream: strings.NewReader("hello"), + ExpectLen: 5, + ExpectNilStream: false, }, "empty stream": { - Stream: strings.NewReader(""), - ExpectLen: 0, + Stream: strings.NewReader(""), + ExpectLen: 0, + ExpectNilStream: false, + }, + "empty stream bytes": { + Stream: bytes.NewReader([]byte{}), + ExpectLen: 0, + ExpectNilStream: false, }, "nil stream": { - ExpectLen: 0, + ExpectLen: 0, + ExpectNilStream: true, }, "un-seekable and no length": { - Stream: &basicReader{buf: make([]byte, 10)}, - ExpectLen: -1, + Stream: &basicReader{buf: make([]byte, 10)}, + ExpectLen: -1, + ExpectNilStream: false, }, "with error": { - Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, - ExpectErr: "seek failed", - ExpectLen: -1, + Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, + ExpectErr: "seek failed", + ExpectLen: -1, + ExpectNilStream: false, }, } @@ -57,10 +70,15 @@ func TestContentLengthMiddleware(t *testing.T) { t.Fatalf("expect to set stream, %v", err) } + var updatedRequest *Request var m ComputeContentLength _, _, err = m.HandleBuild(context.Background(), middleware.BuildInput{Request: req}, - nopBuildHandler, + middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error) { + updatedRequest = input.Request.(*Request) + return out, metadata, nil + }), ) if len(c.ExpectErr) != 0 { if err == nil { @@ -69,13 +87,18 @@ func TestContentLengthMiddleware(t *testing.T) { if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect error to contain %q, got %v", e, a) } + return } else if err != nil { t.Fatalf("expect no error, got %v", err) } - if e, a := c.ExpectLen, req.ContentLength; e != a { + if e, a := c.ExpectLen, updatedRequest.ContentLength; e != a { t.Errorf("expect %v content-length, got %v", e, a) } + + if e, a := c.ExpectNilStream, updatedRequest.stream == nil; e != a { + t.Errorf("expect %v nil stream, got %v", e, a) + } }) } } diff --git a/transport/http/middleware_headers.go b/transport/http/middleware_headers.go index 49884e6af..eac32b4ba 100644 --- a/transport/http/middleware_headers.go +++ b/transport/http/middleware_headers.go @@ -7,6 +7,85 @@ import ( "github.com/aws/smithy-go/middleware" ) +type isContentTypeAutoSet struct{} + +// SetIsContentTypeDefaultValue returns a Context specifying if the request's +// content-type header was set to a default value. +func SetIsContentTypeDefaultValue(ctx context.Context, isDefault bool) context.Context { + return context.WithValue(ctx, isContentTypeAutoSet{}, isDefault) +} + +// GetIsContentTypeDefaultValue returns if the content-type HTTP header on the +// request is a default value that was auto assigned by an operation +// serializer. Allows middleware post serialization to know if the content-type +// was auto set to a default value or not. +// +// Also returns false if the Context value was never updated to include if +// content-type was set to a default value. +func GetIsContentTypeDefaultValue(ctx context.Context) bool { + v, _ := ctx.Value(isContentTypeAutoSet{}).(bool) + return v +} + +// AddNoPayloadDefaultContentTypeRemover Adds the DefaultContentTypeRemover +// middleware to the stack after the operation serializer. This middleware will +// remove the content-type header from the request if it was set as a default +// value, and no request payload is present. +// +// Returns error if unable to add the middleware. +func AddNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) { + err = stack.Serialize.Insert(removeDefaultContentType{}, + "OperationSerializer", middleware.After) + if err != nil { + return fmt.Errorf("failed to add %s serialize middleware, %w", + removeDefaultContentType{}.ID(), err) + } + + return nil +} + +// RemoveNoPayloadDefaultContentTypeRemover removes the +// DefaultContentTypeRemover middleware from the stack. Returns an error if +// unable to remove the middleware. +func RemoveNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) { + _, err = stack.Serialize.Remove(removeDefaultContentType{}.ID()) + if err != nil { + return fmt.Errorf("failed to remove %s serialize middleware, %w", + removeDefaultContentType{}.ID(), err) + + } + return nil +} + +// removeDefaultContentType provides after serialization middleware that will +// remove the content-type header from an HTTP request if the header was set as +// a default value by the operation serializer, and there is no request payload. +type removeDefaultContentType struct{} + +// ID returns the middleware ID +func (removeDefaultContentType) ID() string { return "RemoveDefaultContentType" } + +// HandleSerialize implements the serialization middleware. +func (removeDefaultContentType) HandleSerialize( + ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler, +) ( + out middleware.SerializeOutput, meta middleware.Metadata, err error, +) { + req, ok := input.Request.(*Request) + if !ok { + return out, meta, fmt.Errorf( + "unexpected request type %T for removeDefaultContentType middleware", + input.Request) + } + + if GetIsContentTypeDefaultValue(ctx) && req.GetStream() == nil { + req.Header.Del("Content-Type") + input.Request = req + } + + return next.HandleSerialize(ctx, input) +} + type headerValue struct { header string value string diff --git a/transport/http/request.go b/transport/http/request.go index 5796a689c..7177d6f95 100644 --- a/transport/http/request.go +++ b/transport/http/request.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net/http" "net/url" + "strings" iointernal "github.com/aws/smithy-go/transport/http/internal/io" ) @@ -33,6 +34,14 @@ func NewStackRequest() interface{} { } } +// IsHTTPS returns if the request is HTTPS. Returns false if no endpoint URL is set. +func (r *Request) IsHTTPS() bool { + if r.URL == nil { + return false + } + return strings.EqualFold(r.URL.Scheme, "https") +} + // Clone returns a deep copy of the Request for the new context. A reference to // the Stream is copied, but the underlying stream is not copied. func (r *Request) Clone() *Request { @@ -45,19 +54,23 @@ func (r *Request) Clone() *Request { // to the request and ok set. If the length cannot be determined, an error will // be returned. func (r *Request) StreamLength() (size int64, ok bool, err error) { - if r.stream == nil { + return streamLength(r.stream, r.isStreamSeekable, r.streamStartPos) +} + +func streamLength(stream io.Reader, seekable bool, startPos int64) (size int64, ok bool, err error) { + if stream == nil { return 0, true, nil } - if l, ok := r.stream.(interface{ Len() int }); ok { + if l, ok := stream.(interface{ Len() int }); ok { return int64(l.Len()), true, nil } - if !r.isStreamSeekable { + if !seekable { return 0, false, nil } - s := r.stream.(io.Seeker) + s := stream.(io.Seeker) endOffset, err := s.Seek(0, io.SeekEnd) if err != nil { return 0, false, err @@ -69,12 +82,12 @@ func (r *Request) StreamLength() (size int64, ok bool, err error) { // file, and wants to skip the first N bytes uploading the rest. The // application would move the file's offset N bytes, then hand it off to // the SDK to send the remaining. The SDK should respect that initial offset. - _, err = s.Seek(r.streamStartPos, io.SeekStart) + _, err = s.Seek(startPos, io.SeekStart) if err != nil { return 0, false, err } - return endOffset - r.streamStartPos, true, nil + return endOffset - startPos, true, nil } // RewindStream will rewind the io.Reader to the relative start position if it @@ -103,23 +116,41 @@ func (r *Request) IsStreamSeekable() bool { return r.isStreamSeekable } -// SetStream returns a clone of the request with the stream set to the provided reader. -// May return an error if the provided reader is seekable but returns an error. +// SetStream returns a clone of the request with the stream set to the provided +// reader. May return an error if the provided reader is seekable but returns +// an error. func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) { rc = r.Clone() + if reader == http.NoBody { + reader = nil + } + + var isStreamSeekable bool + var streamStartPos int64 switch v := reader.(type) { case io.Seeker: n, err := v.Seek(0, io.SeekCurrent) if err != nil { return r, err } - rc.isStreamSeekable = true - rc.streamStartPos = n + isStreamSeekable = true + streamStartPos = n default: - rc.isStreamSeekable = false + // If the stream length can be determined, and is determined to be empty, + // use a nil stream to prevent confusion between empty vs not-empty + // streams. + length, ok, err := streamLength(reader, false, 0) + if err != nil { + return nil, err + } else if ok && length == 0 { + reader = nil + } } + rc.stream = reader + rc.isStreamSeekable = isStreamSeekable + rc.streamStartPos = streamStartPos return rc, err } @@ -139,7 +170,11 @@ func (r *Request) Build(ctx context.Context) *http.Request { req.Body = ioutil.NopCloser(stream) req.ContentLength = -1 default: - if r.stream != nil { + // HTTP Client Request must only have a non-nil body if the + // ContentLength is explicitly unknown (-1) or non-zero. The HTTP + // Client will interpret a non-nil body and ContentLength 0 as + // "unknown". This is unwanted behavior. + if req.ContentLength != 0 && r.stream != nil { req.Body = iointernal.NewSafeReadCloser(ioutil.NopCloser(stream)) } } diff --git a/transport/http/request_test.go b/transport/http/request_test.go index 685f710e8..af88fbc36 100644 --- a/transport/http/request_test.go +++ b/transport/http/request_test.go @@ -4,8 +4,9 @@ import ( "bytes" "context" "io" + "io/ioutil" "net/http" - "net/url" + "os" "strconv" "strings" "testing" @@ -19,8 +20,12 @@ func TestRequestRewindable(t *testing.T) { "rewindable": { Stream: bytes.NewReader([]byte{}), }, - "not rewindable": { - Stream: bytes.NewBuffer([]byte{}), + "empty not rewindable": { + Stream: bytes.NewBuffer([]byte{}), + // ExpectErr: "stream is not seekable", + }, + "not empty not rewindable": { + Stream: bytes.NewBuffer([]byte("abc123")), ExpectErr: "stream is not seekable", }, "nil stream": {}, @@ -28,12 +33,7 @@ func TestRequestRewindable(t *testing.T) { for name, c := range cases { t.Run(name, func(t *testing.T) { - req := &Request{ - Request: &http.Request{ - URL: &url.URL{}, - Header: http.Header{}, - }, - } + req := NewStackRequest().(*Request) req, err := req.SetStream(c.Stream) if err != nil { @@ -108,3 +108,114 @@ func TestRequestBuild_contentLength(t *testing.T) { }) } } + +func TestRequestSetStream(t *testing.T) { + cases := map[string]struct { + reader io.Reader + expectSeekable bool + expectStreamStartPos int64 + expectContentLength int64 + expectNilStream bool + expectNilBody bool + expectReqContentLength int64 + }{ + "nil stream": { + expectNilStream: true, + expectNilBody: true, + }, + "empty unseekable stream": { + reader: bytes.NewBuffer([]byte{}), + expectNilStream: true, + expectNilBody: true, + }, + "empty seekable stream": { + reader: bytes.NewReader([]byte{}), + expectContentLength: 0, + expectSeekable: true, + expectNilStream: false, + expectNilBody: true, + }, + "unseekable no len stream": { + reader: ioutil.NopCloser(bytes.NewBuffer([]byte("abc123"))), + expectContentLength: -1, + expectNilStream: false, + expectNilBody: false, + expectReqContentLength: -1, + }, + "unseekable stream": { + reader: bytes.NewBuffer([]byte("abc123")), + expectContentLength: 6, + expectNilStream: false, + expectNilBody: false, + expectReqContentLength: 6, + }, + "seekable stream": { + reader: bytes.NewReader([]byte("abc123")), + expectContentLength: 6, + expectNilStream: false, + expectSeekable: true, + expectNilBody: false, + expectReqContentLength: 6, + }, + "offset seekable stream": { + reader: func() io.Reader { + r := bytes.NewReader([]byte("abc123")) + _, _ = r.Seek(1, os.SEEK_SET) + return r + }(), + expectStreamStartPos: 1, + expectContentLength: 5, + expectSeekable: true, + expectNilStream: false, + expectNilBody: false, + expectReqContentLength: 5, + }, + "NoBody stream": { + reader: http.NoBody, + expectNilStream: true, + expectNilBody: true, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := NewStackRequest().(*Request) + req, err = req.SetStream(c.reader) + if err != nil { + t.Fatalf("expect not error, got %v", err) + } + + if e, a := c.expectSeekable, req.IsStreamSeekable(); e != a { + t.Errorf("expect %v seekable, got %v", e, a) + } + if e, a := c.expectStreamStartPos, req.streamStartPos; e != a { + t.Errorf("expect %v seek start position, got %v", e, a) + } + if e, a := c.expectNilStream, req.stream == nil; e != a { + t.Errorf("expect %v nil stream, got %v", e, a) + } + + if l, ok, err := req.StreamLength(); err != nil { + t.Fatalf("expect no stream length error, got %v", err) + } else if ok { + req.ContentLength = l + } + + if e, a := c.expectContentLength, req.ContentLength; e != a { + t.Errorf("expect %v content-length, got %v", e, a) + } + if e, a := c.expectStreamStartPos, req.streamStartPos; e != a { + t.Errorf("expect %v streamStartPos, got %v", e, a) + } + + r := req.Build(context.Background()) + if e, a := c.expectNilBody, r.Body == nil; e != a { + t.Errorf("expect %v request nil body, got %v", e, a) + } + if e, a := c.expectContentLength, req.ContentLength; e != a { + t.Errorf("expect %v request content-length, got %v", e, a) + } + }) + } +}