From eb5c8682071600cf0d7386755d5cd0c30249f88d Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Tue, 28 May 2024 15:24:44 -0700 Subject: [PATCH 1/2] fix(cba): Update credsNewAuth to support oauth2 over mTLS --- internal/creds.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/internal/creds.go b/internal/creds.go index b6dbace4c9..954f310454 100644 --- a/internal/creds.go +++ b/internal/creds.go @@ -80,13 +80,31 @@ func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credenti aud = settings.DefaultAudience } + // Determine configurations for the OAuth2 transport, which is separate from the API transport. + // The OAuth2 transport and endpoint will be configured for mTLS if applicable. + clientCertSource, err := getClientCertificateSource(settings) + if err != nil { + return nil, err + } + tokenURL := oAuth2Endpoint(clientCertSource) + var oauth2Client *http.Client + if clientCertSource != nil { + tlsConfig := &tls.Config{ + GetClientCertificate: clientCertSource, + } + oauth2Client = customHTTPClient(tlsConfig) + } else { + oauth2Client = oauth2.NewClient(ctx, nil) + } + creds, err := credentials.DetectDefault(&credentials.DetectOptions{ Scopes: scopes, Audience: aud, CredentialsFile: settings.CredentialsFile, CredentialsJSON: settings.CredentialsJSON, UseSelfSignedJWT: useSelfSignedJWT, - Client: oauth2.NewClient(ctx, nil), + TokenURL: tokenURL, + Client: oauth2Client, }) if err != nil { return nil, err From 7e3a91922301fa71fb3d4cfa8550c4fb65b95746 Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Thu, 30 May 2024 14:55:54 -0700 Subject: [PATCH 2/2] fix(cba): Update newAuth flow to support token exchange over mTLS --- internal/creds.go | 48 +++++++++++++++++++++--------------------- transport/grpc/dial.go | 8 ++++++- transport/http/dial.go | 7 +++++- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/internal/creds.go b/internal/creds.go index 954f310454..e6c4fe90d4 100644 --- a/internal/creds.go +++ b/internal/creds.go @@ -42,6 +42,26 @@ func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) { return creds, nil } +// GetOAuth2Configuration determines configurations for the OAuth2 transport, which is separate from the API transport. +// The OAuth2 transport and endpoint will be configured for mTLS if applicable. +func GetOAuth2Configuration(ctx context.Context, settings *DialSettings) (string, *http.Client, error) { + clientCertSource, err := getClientCertificateSource(settings) + if err != nil { + return "", nil, err + } + tokenURL := oAuth2Endpoint(clientCertSource) + var oauth2Client *http.Client + if clientCertSource != nil { + tlsConfig := &tls.Config{ + GetClientCertificate: clientCertSource, + } + oauth2Client = customHTTPClient(tlsConfig) + } else { + oauth2Client = oauth2.NewClient(ctx, nil) + } + return tokenURL, oauth2Client, nil +} + func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credentials, error) { // Preserve old options behavior if settings.InternalCredentials != nil { @@ -80,23 +100,10 @@ func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credenti aud = settings.DefaultAudience } - // Determine configurations for the OAuth2 transport, which is separate from the API transport. - // The OAuth2 transport and endpoint will be configured for mTLS if applicable. - clientCertSource, err := getClientCertificateSource(settings) + tokenURL, oauth2Client, err := GetOAuth2Configuration(ctx, settings) if err != nil { return nil, err } - tokenURL := oAuth2Endpoint(clientCertSource) - var oauth2Client *http.Client - if clientCertSource != nil { - tlsConfig := &tls.Config{ - GetClientCertificate: clientCertSource, - } - oauth2Client = customHTTPClient(tlsConfig) - } else { - oauth2Client = oauth2.NewClient(ctx, nil) - } - creds, err := credentials.DetectDefault(&credentials.DetectOptions{ Scopes: scopes, Audience: aud, @@ -165,19 +172,12 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g var params google.CredentialsParams params.Scopes = ds.GetScopes() - // Determine configurations for the OAuth2 transport, which is separate from the API transport. - // The OAuth2 transport and endpoint will be configured for mTLS if applicable. - clientCertSource, err := getClientCertificateSource(ds) + tokenURL, oauth2Client, err := GetOAuth2Configuration(ctx, ds) if err != nil { return nil, err } - params.TokenURL = oAuth2Endpoint(clientCertSource) - if clientCertSource != nil { - tlsConfig := &tls.Config{ - GetClientCertificate: clientCertSource, - } - ctx = context.WithValue(ctx, oauth2.HTTPClient, customHTTPClient(tlsConfig)) - } + params.TokenURL = tokenURL + ctx = context.WithValue(ctx, oauth2.HTTPClient, oauth2Client) // By default, a standard OAuth 2.0 token source is created cred, err := google.CredentialsFromJSONWithParams(ctx, data, params) diff --git a/transport/grpc/dial.go b/transport/grpc/dial.go index 2e66d02b37..2d4f90c9c1 100644 --- a/transport/grpc/dial.go +++ b/transport/grpc/dial.go @@ -218,6 +218,11 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna defaultEndpointTemplate = ds.DefaultEndpoint } + tokenURL, oauth2Client, err := internal.GetOAuth2Configuration(ctx, ds) + if err != nil { + return nil, err + } + pool, err := grpctransport.Dial(ctx, secure, &grpctransport.Options{ DisableTelemetry: ds.TelemetryDisabled, DisableAuthentication: ds.NoAuth, @@ -231,7 +236,8 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna Audience: aud, CredentialsFile: ds.CredentialsFile, CredentialsJSON: ds.CredentialsJSON, - Client: oauth2.NewClient(ctx, nil), + TokenURL: tokenURL, + Client: oauth2Client, }, InternalOptions: &grpctransport.InternalOptions{ EnableNonDefaultSAForDirectPath: ds.AllowNonDefaultServiceAccount, diff --git a/transport/http/dial.go b/transport/http/dial.go index d1cd83b62d..a36e24315b 100644 --- a/transport/http/dial.go +++ b/transport/http/dial.go @@ -107,6 +107,10 @@ func newClientNewAuth(ctx context.Context, base http.RoundTripper, ds *internal. if ds.RequestReason != "" { headers.Set("X-goog-request-reason", ds.RequestReason) } + tokenURL, oauth2Client, err := internal.GetOAuth2Configuration(ctx, ds) + if err != nil { + return nil, err + } client, err := httptransport.NewClient(&httptransport.Options{ DisableTelemetry: ds.TelemetryDisabled, DisableAuthentication: ds.NoAuth, @@ -121,7 +125,8 @@ func newClientNewAuth(ctx context.Context, base http.RoundTripper, ds *internal. Audience: aud, CredentialsFile: ds.CredentialsFile, CredentialsJSON: ds.CredentialsJSON, - Client: oauth2.NewClient(ctx, nil), + TokenURL: tokenURL, + Client: oauth2Client, }, InternalOptions: &httptransport.InternalOptions{ EnableJWTWithScope: ds.EnableJwtWithScope,