From 75068f93453e6ea0b5c7be5561e7ba342c695e95 Mon Sep 17 00:00:00 2001 From: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> Date: Fri, 15 Dec 2023 13:56:05 -0800 Subject: [PATCH] fix: external account user cred universe domain support (#1437) * fix: external account user cred universe domain support * refactor --------- Co-authored-by: Jin --- google/auth/credentials.py | 21 +++++++++++-- google/auth/external_account.py | 10 +----- .../auth/external_account_authorized_user.py | 12 +++++++ google/oauth2/credentials.py | 9 +----- google/oauth2/service_account.py | 10 +----- .../test_external_account_authorized_user.py | 31 +++++++++++++++++++ 6 files changed, 65 insertions(+), 28 deletions(-) diff --git a/google/auth/credentials.py b/google/auth/credentials.py index 800781c40..6e62a4b4e 100644 --- a/google/auth/credentials.py +++ b/google/auth/credentials.py @@ -188,7 +188,7 @@ def with_quota_project(self, quota_project_id): billing purposes Returns: - google.oauth2.credentials.Credentials: A new credentials instance. + google.auth.credentials.Credentials: A new credentials instance. """ raise NotImplementedError("This credential does not support quota project.") @@ -209,11 +209,28 @@ def with_token_uri(self, token_uri): token_uri (str): The uri to use for fetching/exchanging tokens Returns: - google.oauth2.credentials.Credentials: A new credentials instance. + google.auth.credentials.Credentials: A new credentials instance. """ raise NotImplementedError("This credential does not use token uri.") +class CredentialsWithUniverseDomain(Credentials): + """Abstract base for credentials supporting ``with_universe_domain`` factory""" + + def with_universe_domain(self, universe_domain): + """Returns a copy of these credentials with a modified universe domain. + + Args: + universe_domain (str): The universe domain to use + + Returns: + google.auth.credentials.Credentials: A new credentials instance. + """ + raise NotImplementedError( + "This credential does not support with_universe_domain." + ) + + class AnonymousCredentials(Credentials): """Credentials that do not provide any authentication information. diff --git a/google/auth/external_account.py b/google/auth/external_account.py index e7fed8695..c314ea799 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -415,16 +415,8 @@ def with_token_uri(self, token_uri): new_cred._metrics_options = self._metrics_options return new_cred + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) def with_universe_domain(self, universe_domain): - """Create a copy of these credentials with the given universe domain. - - Args: - universe_domain (str): The universe domain value. - - Returns: - google.auth.external_account.Credentials: A new credentials - instance. - """ kwargs = self._constructor_args() kwargs.update(universe_domain=universe_domain) new_cred = self.__class__(**kwargs) diff --git a/google/auth/external_account_authorized_user.py b/google/auth/external_account_authorized_user.py index a2d4edf6f..55230103f 100644 --- a/google/auth/external_account_authorized_user.py +++ b/google/auth/external_account_authorized_user.py @@ -43,6 +43,7 @@ from google.oauth2 import sts from google.oauth2 import utils +_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" _EXTERNAL_ACCOUNT_AUTHORIZED_USER_JSON_TYPE = "external_account_authorized_user" @@ -75,6 +76,7 @@ def __init__( revoke_url=None, scopes=None, quota_project_id=None, + universe_domain=_DEFAULT_UNIVERSE_DOMAIN, ): """Instantiates a external account authorized user credentials object. @@ -98,6 +100,8 @@ def __init__( quota_project_id (str): The optional project ID used for quota and billing. This project may be different from the project used to create the credentials. + universe_domain (Optional[str]): The universe domain. The default value + is googleapis.com. Returns: google.auth.external_account_authorized_user.Credentials: The @@ -116,6 +120,7 @@ def __init__( self._revoke_url = revoke_url self._quota_project_id = quota_project_id self._scopes = scopes + self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN if not self.valid and not self.can_refresh: raise exceptions.InvalidOperation( @@ -162,6 +167,7 @@ def constructor_args(self): "revoke_url": self._revoke_url, "scopes": self._scopes, "quota_project_id": self._quota_project_id, + "universe_domain": self._universe_domain, } @property @@ -297,6 +303,12 @@ def with_token_uri(self, token_uri): kwargs.update(token_url=token_uri) return self.__class__(**kwargs) + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) + def with_universe_domain(self, universe_domain): + kwargs = self.constructor_args() + kwargs.update(universe_domain=universe_domain) + return self.__class__(**kwargs) + @classmethod def from_info(cls, info, **kwargs): """Creates a Credentials instance from parsed external account info. diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index a5c93ecc2..7d327c110 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -302,15 +302,8 @@ def with_token_uri(self, token_uri): universe_domain=self._universe_domain, ) + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) def with_universe_domain(self, universe_domain): - """Create a copy of the credential with the given universe domain. - - Args: - universe_domain (str): The universe domain value. - - Returns: - google.oauth2.credentials.Credentials: A new credentials instance. - """ return self.__class__( self.token, diff --git a/google/oauth2/service_account.py b/google/oauth2/service_account.py index 68db41af4..4502c6f68 100644 --- a/google/oauth2/service_account.py +++ b/google/oauth2/service_account.py @@ -325,16 +325,8 @@ def with_always_use_jwt_access(self, always_use_jwt_access): cred._always_use_jwt_access = always_use_jwt_access return cred + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) def with_universe_domain(self, universe_domain): - """Create a copy of these credentials with the given universe domain. - - Args: - universe_domain (str): The universe domain value. - - Returns: - google.auth.service_account.Credentials: A new credentials - instance. - """ cred = self._make_copy() cred._universe_domain = universe_domain if universe_domain != _DEFAULT_UNIVERSE_DOMAIN: diff --git a/tests/test_external_account_authorized_user.py b/tests/test_external_account_authorized_user.py index 7ffd5078c..7213a2348 100644 --- a/tests/test_external_account_authorized_user.py +++ b/tests/test_external_account_authorized_user.py @@ -44,6 +44,8 @@ BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" SCOPES = ["email", "profile"] NOW = datetime.datetime(1990, 8, 27, 6, 54, 30) +FAKE_UNIVERSE_DOMAIN = "fake-universe-domain" +DEFAULT_UNIVERSE_DOMAIN = external_account_authorized_user._DEFAULT_UNIVERSE_DOMAIN class TestCredentials(object): @@ -98,6 +100,7 @@ def test_default_state(self): assert creds.refresh_token == REFRESH_TOKEN assert creds.audience == AUDIENCE assert creds.token_url == TOKEN_URL + assert creds.universe_domain == DEFAULT_UNIVERSE_DOMAIN def test_basic_create(self): creds = external_account_authorized_user.Credentials( @@ -105,6 +108,7 @@ def test_basic_create(self): expiry=datetime.datetime.max, scopes=SCOPES, revoke_url=REVOKE_URL, + universe_domain=FAKE_UNIVERSE_DOMAIN, ) assert creds.expiry == datetime.datetime.max @@ -115,6 +119,7 @@ def test_basic_create(self): assert creds.scopes == SCOPES assert creds.is_user assert creds.revoke_url == REVOKE_URL + assert creds.universe_domain == FAKE_UNIVERSE_DOMAIN def test_stunted_create_no_refresh_token(self): with pytest.raises(ValueError) as excinfo: @@ -339,6 +344,7 @@ def test_info(self): assert info["token_info_url"] == TOKEN_INFO_URL assert info["client_id"] == CLIENT_ID assert info["client_secret"] == CLIENT_SECRET + assert info["universe_domain"] == DEFAULT_UNIVERSE_DOMAIN assert "token" not in info assert "expiry" not in info assert "revoke_url" not in info @@ -350,6 +356,7 @@ def test_info_full(self): expiry=NOW, revoke_url=REVOKE_URL, quota_project_id=QUOTA_PROJECT_ID, + universe_domain=FAKE_UNIVERSE_DOMAIN, ) info = creds.info @@ -363,6 +370,7 @@ def test_info_full(self): assert info["expiry"] == NOW.isoformat() + "Z" assert info["revoke_url"] == REVOKE_URL assert info["quota_project_id"] == QUOTA_PROJECT_ID + assert info["universe_domain"] == FAKE_UNIVERSE_DOMAIN def test_to_json(self): creds = self.make_credentials() @@ -375,6 +383,7 @@ def test_to_json(self): assert info["token_info_url"] == TOKEN_INFO_URL assert info["client_id"] == CLIENT_ID assert info["client_secret"] == CLIENT_SECRET + assert info["universe_domain"] == DEFAULT_UNIVERSE_DOMAIN assert "token" not in info assert "expiry" not in info assert "revoke_url" not in info @@ -386,6 +395,7 @@ def test_to_json_full(self): expiry=NOW, revoke_url=REVOKE_URL, quota_project_id=QUOTA_PROJECT_ID, + universe_domain=FAKE_UNIVERSE_DOMAIN, ) json_info = creds.to_json() info = json.loads(json_info) @@ -400,6 +410,7 @@ def test_to_json_full(self): assert info["expiry"] == NOW.isoformat() + "Z" assert info["revoke_url"] == REVOKE_URL assert info["quota_project_id"] == QUOTA_PROJECT_ID + assert info["universe_domain"] == FAKE_UNIVERSE_DOMAIN def test_to_json_full_with_strip(self): creds = self.make_credentials( @@ -467,6 +478,26 @@ def test_with_token_uri(self): assert new_creds._revoke_url == creds._revoke_url assert new_creds._quota_project_id == creds._quota_project_id + def test_with_universe_domain(self): + creds = self.make_credentials( + token=ACCESS_TOKEN, + expiry=NOW, + revoke_url=REVOKE_URL, + quota_project_id=QUOTA_PROJECT_ID, + ) + new_creds = creds.with_universe_domain(FAKE_UNIVERSE_DOMAIN) + assert new_creds._audience == creds._audience + assert new_creds._refresh_token == creds._refresh_token + assert new_creds._token_url == creds._token_url + assert new_creds._token_info_url == creds._token_info_url + assert new_creds._client_id == creds._client_id + assert new_creds._client_secret == creds._client_secret + assert new_creds.token == creds.token + assert new_creds.expiry == creds.expiry + assert new_creds._revoke_url == creds._revoke_url + assert new_creds._quota_project_id == QUOTA_PROJECT_ID + assert new_creds.universe_domain == FAKE_UNIVERSE_DOMAIN + def test_from_file_required_options_only(self, tmpdir): from_creds = self.make_credentials() config_file = tmpdir.join("config.json")