Skip to content

Commit

Permalink
fix: external account user cred universe domain support (#1437)
Browse files Browse the repository at this point in the history
* fix: external account user cred universe domain support

* refactor

---------

Co-authored-by: Jin <[email protected]>
  • Loading branch information
arithmetic1728 and BigTailWolf committed Dec 15, 2023
1 parent 0afc61a commit 75068f9
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 28 deletions.
21 changes: 19 additions & 2 deletions google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def with_quota_project(self, quota_project_id):
billing purposes
Returns:
google.oauth2.credentials.Credentials: A new credentials instance.
google.auth.credentials.Credentials: A new credentials instance.
"""
raise NotImplementedError("This credential does not support quota project.")

Expand All @@ -209,11 +209,28 @@ def with_token_uri(self, token_uri):
token_uri (str): The uri to use for fetching/exchanging tokens
Returns:
google.oauth2.credentials.Credentials: A new credentials instance.
google.auth.credentials.Credentials: A new credentials instance.
"""
raise NotImplementedError("This credential does not use token uri.")


class CredentialsWithUniverseDomain(Credentials):
"""Abstract base for credentials supporting ``with_universe_domain`` factory"""

def with_universe_domain(self, universe_domain):
"""Returns a copy of these credentials with a modified universe domain.
Args:
universe_domain (str): The universe domain to use
Returns:
google.auth.credentials.Credentials: A new credentials instance.
"""
raise NotImplementedError(
"This credential does not support with_universe_domain."
)


class AnonymousCredentials(Credentials):
"""Credentials that do not provide any authentication information.
Expand Down
10 changes: 1 addition & 9 deletions google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,16 +415,8 @@ def with_token_uri(self, token_uri):
new_cred._metrics_options = self._metrics_options
return new_cred

@_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain)
def with_universe_domain(self, universe_domain):
"""Create a copy of these credentials with the given universe domain.
Args:
universe_domain (str): The universe domain value.
Returns:
google.auth.external_account.Credentials: A new credentials
instance.
"""
kwargs = self._constructor_args()
kwargs.update(universe_domain=universe_domain)
new_cred = self.__class__(**kwargs)
Expand Down
12 changes: 12 additions & 0 deletions google/auth/external_account_authorized_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from google.oauth2 import sts
from google.oauth2 import utils

_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
_EXTERNAL_ACCOUNT_AUTHORIZED_USER_JSON_TYPE = "external_account_authorized_user"


Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
revoke_url=None,
scopes=None,
quota_project_id=None,
universe_domain=_DEFAULT_UNIVERSE_DOMAIN,
):
"""Instantiates a external account authorized user credentials object.
Expand All @@ -98,6 +100,8 @@ def __init__(
quota_project_id (str): The optional project ID used for quota and billing.
This project may be different from the project used to
create the credentials.
universe_domain (Optional[str]): The universe domain. The default value
is googleapis.com.
Returns:
google.auth.external_account_authorized_user.Credentials: The
Expand All @@ -116,6 +120,7 @@ def __init__(
self._revoke_url = revoke_url
self._quota_project_id = quota_project_id
self._scopes = scopes
self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN

if not self.valid and not self.can_refresh:
raise exceptions.InvalidOperation(
Expand Down Expand Up @@ -162,6 +167,7 @@ def constructor_args(self):
"revoke_url": self._revoke_url,
"scopes": self._scopes,
"quota_project_id": self._quota_project_id,
"universe_domain": self._universe_domain,
}

@property
Expand Down Expand Up @@ -297,6 +303,12 @@ def with_token_uri(self, token_uri):
kwargs.update(token_url=token_uri)
return self.__class__(**kwargs)

@_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain)
def with_universe_domain(self, universe_domain):
kwargs = self.constructor_args()
kwargs.update(universe_domain=universe_domain)
return self.__class__(**kwargs)

@classmethod
def from_info(cls, info, **kwargs):
"""Creates a Credentials instance from parsed external account info.
Expand Down
9 changes: 1 addition & 8 deletions google/oauth2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,8 @@ def with_token_uri(self, token_uri):
universe_domain=self._universe_domain,
)

@_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain)
def with_universe_domain(self, universe_domain):
"""Create a copy of the credential with the given universe domain.
Args:
universe_domain (str): The universe domain value.
Returns:
google.oauth2.credentials.Credentials: A new credentials instance.
"""

return self.__class__(
self.token,
Expand Down
10 changes: 1 addition & 9 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,8 @@ def with_always_use_jwt_access(self, always_use_jwt_access):
cred._always_use_jwt_access = always_use_jwt_access
return cred

@_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain)
def with_universe_domain(self, universe_domain):
"""Create a copy of these credentials with the given universe domain.
Args:
universe_domain (str): The universe domain value.
Returns:
google.auth.service_account.Credentials: A new credentials
instance.
"""
cred = self._make_copy()
cred._universe_domain = universe_domain
if universe_domain != _DEFAULT_UNIVERSE_DOMAIN:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_external_account_authorized_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ="
SCOPES = ["email", "profile"]
NOW = datetime.datetime(1990, 8, 27, 6, 54, 30)
FAKE_UNIVERSE_DOMAIN = "fake-universe-domain"
DEFAULT_UNIVERSE_DOMAIN = external_account_authorized_user._DEFAULT_UNIVERSE_DOMAIN


class TestCredentials(object):
Expand Down Expand Up @@ -98,13 +100,15 @@ def test_default_state(self):
assert creds.refresh_token == REFRESH_TOKEN
assert creds.audience == AUDIENCE
assert creds.token_url == TOKEN_URL
assert creds.universe_domain == DEFAULT_UNIVERSE_DOMAIN

def test_basic_create(self):
creds = external_account_authorized_user.Credentials(
token=ACCESS_TOKEN,
expiry=datetime.datetime.max,
scopes=SCOPES,
revoke_url=REVOKE_URL,
universe_domain=FAKE_UNIVERSE_DOMAIN,
)

assert creds.expiry == datetime.datetime.max
Expand All @@ -115,6 +119,7 @@ def test_basic_create(self):
assert creds.scopes == SCOPES
assert creds.is_user
assert creds.revoke_url == REVOKE_URL
assert creds.universe_domain == FAKE_UNIVERSE_DOMAIN

def test_stunted_create_no_refresh_token(self):
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -339,6 +344,7 @@ def test_info(self):
assert info["token_info_url"] == TOKEN_INFO_URL
assert info["client_id"] == CLIENT_ID
assert info["client_secret"] == CLIENT_SECRET
assert info["universe_domain"] == DEFAULT_UNIVERSE_DOMAIN
assert "token" not in info
assert "expiry" not in info
assert "revoke_url" not in info
Expand All @@ -350,6 +356,7 @@ def test_info_full(self):
expiry=NOW,
revoke_url=REVOKE_URL,
quota_project_id=QUOTA_PROJECT_ID,
universe_domain=FAKE_UNIVERSE_DOMAIN,
)
info = creds.info

Expand All @@ -363,6 +370,7 @@ def test_info_full(self):
assert info["expiry"] == NOW.isoformat() + "Z"
assert info["revoke_url"] == REVOKE_URL
assert info["quota_project_id"] == QUOTA_PROJECT_ID
assert info["universe_domain"] == FAKE_UNIVERSE_DOMAIN

def test_to_json(self):
creds = self.make_credentials()
Expand All @@ -375,6 +383,7 @@ def test_to_json(self):
assert info["token_info_url"] == TOKEN_INFO_URL
assert info["client_id"] == CLIENT_ID
assert info["client_secret"] == CLIENT_SECRET
assert info["universe_domain"] == DEFAULT_UNIVERSE_DOMAIN
assert "token" not in info
assert "expiry" not in info
assert "revoke_url" not in info
Expand All @@ -386,6 +395,7 @@ def test_to_json_full(self):
expiry=NOW,
revoke_url=REVOKE_URL,
quota_project_id=QUOTA_PROJECT_ID,
universe_domain=FAKE_UNIVERSE_DOMAIN,
)
json_info = creds.to_json()
info = json.loads(json_info)
Expand All @@ -400,6 +410,7 @@ def test_to_json_full(self):
assert info["expiry"] == NOW.isoformat() + "Z"
assert info["revoke_url"] == REVOKE_URL
assert info["quota_project_id"] == QUOTA_PROJECT_ID
assert info["universe_domain"] == FAKE_UNIVERSE_DOMAIN

def test_to_json_full_with_strip(self):
creds = self.make_credentials(
Expand Down Expand Up @@ -467,6 +478,26 @@ def test_with_token_uri(self):
assert new_creds._revoke_url == creds._revoke_url
assert new_creds._quota_project_id == creds._quota_project_id

def test_with_universe_domain(self):
creds = self.make_credentials(
token=ACCESS_TOKEN,
expiry=NOW,
revoke_url=REVOKE_URL,
quota_project_id=QUOTA_PROJECT_ID,
)
new_creds = creds.with_universe_domain(FAKE_UNIVERSE_DOMAIN)
assert new_creds._audience == creds._audience
assert new_creds._refresh_token == creds._refresh_token
assert new_creds._token_url == creds._token_url
assert new_creds._token_info_url == creds._token_info_url
assert new_creds._client_id == creds._client_id
assert new_creds._client_secret == creds._client_secret
assert new_creds.token == creds.token
assert new_creds.expiry == creds.expiry
assert new_creds._revoke_url == creds._revoke_url
assert new_creds._quota_project_id == QUOTA_PROJECT_ID
assert new_creds.universe_domain == FAKE_UNIVERSE_DOMAIN

def test_from_file_required_options_only(self, tmpdir):
from_creds = self.make_credentials()
config_file = tmpdir.join("config.json")
Expand Down

0 comments on commit 75068f9

Please sign in to comment.