Skip to content

Commit

Permalink
Mock Credentials in AWSPrometheusRemoteWriteExporter Tests (open-tele…
Browse files Browse the repository at this point in the history
…metry#1755)

We are adding more tests to the AWSPrometheusRemoteWriteExporter in this PR. These tests include mocking the credentials and test for negative error cases. In this PR, we refactored `auth.go` so that we can provide mocked credentials to the SigV4 signer and added a test for no credentials.

**Testing:** 

* `auth_test.go` was changed to provide mock credentials
* `factory_test.go` was changed to add a comment to warn user's not to set their credentials with environment variables in production code
  • Loading branch information
Jason Liu committed Dec 10, 2020
1 parent a84baea commit 241c509
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 26 deletions.
25 changes: 18 additions & 7 deletions exporter/awsprometheusremotewriteexporter/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package awsprometheusremotewriteexporter

import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
)
Expand All @@ -30,7 +32,7 @@ import (
type signingRoundTripper struct {
transport http.RoundTripper
signer *v4.Signer
cfg *aws.Config
region string
service string
}

Expand All @@ -53,7 +55,7 @@ func (si *signingRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
req2 := cloneRequest(req)

// Sign the request
_, err = si.signer.Sign(req2, body, si.service, *si.cfg.Region, time.Now())
_, err = si.signer.Sign(req2, body, si.service, si.region, time.Now())
if err != nil {
return nil, err
}
Expand All @@ -68,10 +70,6 @@ func (si *signingRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
}

func newSigningRoundTripper(auth AuthConfig, next http.RoundTripper) (http.RoundTripper, error) {
if !isValidAuth(auth) {
return next, nil
}

sess, err := session.NewSession(&aws.Config{
Region: aws.String(auth.Region)},
)
Expand All @@ -85,12 +83,25 @@ func newSigningRoundTripper(auth AuthConfig, next http.RoundTripper) (http.Round

// Get Credentials, either from ./aws or from environmental variables
creds := sess.Config.Credentials

return createSigningRoundTripperWithCredentials(auth, creds, next)
}

func createSigningRoundTripperWithCredentials(auth AuthConfig, creds *credentials.Credentials, next http.RoundTripper) (http.RoundTripper, error) {
if !isValidAuth(auth) {
return next, nil
}

if creds == nil {
return nil, errors.New("no AWS credentials exist")
}

signer := v4.NewSigner(creds)

rt := signingRoundTripper{
transport: next,
signer: signer,
cfg: sess.Config,
region: auth.Region,
service: auth.Service,
}

Expand Down
49 changes: 32 additions & 17 deletions exporter/awsprometheusremotewriteexporter/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/collector/config/confighttp"
Expand All @@ -32,8 +32,8 @@ import (

func TestRequestSignature(t *testing.T) {
// Some form of AWS credentials must be set up for tests to succeed
os.Setenv("AWS_ACCESS_KEY", "string")
os.Setenv("AWS_SECRET_ACCESS_KEY", "string2")
awsCreds := fetchMockCredentials()
authConfig := AuthConfig{Region: "region", Service: "service"}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := v4.GetSignedRequestSignature(r)
Expand All @@ -52,8 +52,7 @@ func TestRequestSignature(t *testing.T) {
WriteBufferSize: 0,
Timeout: 0,
CustomRoundTripper: func(next http.RoundTripper) (http.RoundTripper, error) {
settings := AuthConfig{Region: "region", Service: "service"}
return newSigningRoundTripper(settings, next)
return createSigningRoundTripperWithCredentials(authConfig, awsCreds, next)
},
}
client, _ := setting.ToClient()
Expand All @@ -71,8 +70,7 @@ func (ert *ErrorRoundTripper) RoundTrip(r *http.Request) (*http.Response, error)

func TestRoundTrip(t *testing.T) {
// Some form of AWS credentials must be set up for tests to succeed
os.Setenv("AWS_ACCESS_KEY", "string")
os.Setenv("AWS_SECRET_ACCESS_KEY", "string2")
awsCreds := fetchMockCredentials()

defaultRoundTripper := (http.RoundTripper)(http.DefaultTransport.(*http.Transport).Clone())
errorRoundTripper := &ErrorRoundTripper{}
Expand Down Expand Up @@ -103,8 +101,8 @@ func TestRoundTrip(t *testing.T) {
}))
defer server.Close()
serverURL, _ := url.Parse(server.URL)
settings := AuthConfig{Region: "region", Service: "service"}
rt, err := newSigningRoundTripper(settings, tt.rt)
authConfig := AuthConfig{Region: "region", Service: "service"}
rt, err := createSigningRoundTripperWithCredentials(authConfig, awsCreds, tt.rt)
assert.NoError(t, err)
req, err := http.NewRequest("POST", serverURL.String(), strings.NewReader(""))
assert.NoError(t, err)
Expand All @@ -120,39 +118,50 @@ func TestRoundTrip(t *testing.T) {
}
}

func TestNewSigningRoundTripper(t *testing.T) {
func TestCreateSigningRoundTripperWithCredentials(t *testing.T) {

defaultRoundTripper := (http.RoundTripper)(http.DefaultTransport.(*http.Transport).Clone())

// Some form of AWS credentials must be set up for tests to succeed
os.Setenv("AWS_ACCESS_KEY", "string")
os.Setenv("AWS_SECRET_ACCESS_KEY", "string2")
awsCreds := fetchMockCredentials()

tests := []struct {
name string
creds *credentials.Credentials
roundTripper http.RoundTripper
settings AuthConfig
authConfig AuthConfig
authApplied bool
returnError bool
}{
{
"success_case",
awsCreds,
defaultRoundTripper,
AuthConfig{Region: "region", Service: "service"},
true,
false,
},
{
"success_case_no_auth_applied",
awsCreds,
defaultRoundTripper,
AuthConfig{Region: "", Service: ""},
false,
false,
},
{
"no_credentials_provided_error",
nil,
defaultRoundTripper,
AuthConfig{Region: "region", Service: "service"},
true,
true,
},
}
// run tests
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rtp, err := newSigningRoundTripper(tt.settings, tt.roundTripper)
rtp, err := createSigningRoundTripperWithCredentials(tt.authConfig, tt.creds, tt.roundTripper)
if tt.returnError {
assert.Error(t, err)
return
Expand All @@ -161,7 +170,7 @@ func TestNewSigningRoundTripper(t *testing.T) {
if tt.authApplied {
sRtp := rtp.(*signingRoundTripper)
assert.Equal(t, sRtp.transport, tt.roundTripper)
assert.Equal(t, tt.settings.Service, sRtp.service)
assert.Equal(t, tt.authConfig.Service, sRtp.service)
} else {
assert.Equal(t, rtp, tt.roundTripper)
}
Expand All @@ -170,10 +179,10 @@ func TestNewSigningRoundTripper(t *testing.T) {
}

func TestCloneRequest(t *testing.T) {
req1, err := http.NewRequest("GET", "http:https://example.com", nil)
req1, err := http.NewRequest("GET", "https:https://example.com", nil)
assert.NoError(t, err)

req2, err := http.NewRequest("GET", "http:https://example.com", nil)
req2, err := http.NewRequest("GET", "https:https://example.com", nil)
assert.NoError(t, err)
req2.Header.Add("Header1", "val1")

Expand Down Expand Up @@ -201,3 +210,9 @@ func TestCloneRequest(t *testing.T) {
})
}
}

func fetchMockCredentials() *credentials.Credentials {
return credentials.NewStaticCredentials("MOCK_AWS_ACCESS_KEY",
"MOCK_AWS_SECRET_ACCESS_KEY",
"MOCK_TOKEN")
}
6 changes: 4 additions & 2 deletions exporter/awsprometheusremotewriteexporter/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ func TestCreateMetricsExporter(t *testing.T) {
validConfigWithAuth.AuthConfig = AuthConfig{Region: "region", Service: "service"}

// Some form of AWS credentials chain required to test valid auth case
os.Setenv("AWS_ACCESS_KEY", "string")
os.Setenv("AWS_SECRET_ACCESS_KEY", "string2")
// This is a set of mock credentials strictly for testing purposes. Users
// should not set their credentials like this in production.
os.Setenv("AWS_ACCESS_KEY", "mock_value")
os.Setenv("AWS_SECRET_ACCESS_KEY", "mock_value2")

invalidConfigWithAuth := af.CreateDefaultConfig().(*Config)
invalidConfigWithAuth.AuthConfig = AuthConfig{Region: "", Service: "service"}
Expand Down

0 comments on commit 241c509

Please sign in to comment.