Skip to content

Commit

Permalink
feat: add experimental GDCH support (#1044)
Browse files Browse the repository at this point in the history
* feat: add experimental GDCH support

* use ec key

* update comment

* Update google/oauth2/gdch_credentials.py

* fix

* add project, update payload
  • Loading branch information
arithmetic1728 committed Jun 14, 2022
1 parent 87d41ae commit 94fb5e2
Show file tree
Hide file tree
Showing 9 changed files with 591 additions and 23 deletions.
30 changes: 29 additions & 1 deletion google/auth/_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
_SERVICE_ACCOUNT_TYPE = "service_account"
_EXTERNAL_ACCOUNT_TYPE = "external_account"
_IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account"
_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account"
_VALID_TYPES = (
_AUTHORIZED_USER_TYPE,
_SERVICE_ACCOUNT_TYPE,
_EXTERNAL_ACCOUNT_TYPE,
_IMPERSONATED_SERVICE_ACCOUNT_TYPE,
_GDCH_SERVICE_ACCOUNT_TYPE,
)

# Help message when no credentials can be found.
Expand Down Expand Up @@ -134,6 +136,8 @@ def load_credentials_from_file(
def _load_credentials_from_info(
filename, info, scopes, default_scopes, quota_project_id, request
):
from google.auth.credentials import CredentialsWithQuotaProject

credential_type = info.get("type")

if credential_type == _AUTHORIZED_USER_TYPE:
Expand All @@ -158,14 +162,17 @@ def _load_credentials_from_info(
credentials, project_id = _get_impersonated_service_account_credentials(
filename, info, scopes
)
elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE:
credentials, project_id = _get_gdch_service_account_credentials(filename, info)
else:
raise exceptions.DefaultCredentialsError(
"The file {file} does not have a valid type. "
"Type is {type}, expected one of {valid_types}.".format(
file=filename, type=credential_type, valid_types=_VALID_TYPES
)
)
credentials = _apply_quota_project_id(credentials, quota_project_id)
if isinstance(credentials, CredentialsWithQuotaProject):
credentials = _apply_quota_project_id(credentials, quota_project_id)
return credentials, project_id


Expand Down Expand Up @@ -421,6 +428,20 @@ def _get_impersonated_service_account_credentials(filename, info, scopes):
return credentials, None


def _get_gdch_service_account_credentials(filename, info):
from google.oauth2 import gdch_credentials

try:
credentials = gdch_credentials.ServiceAccountCredentials.from_service_account_info(
info
)
except ValueError as caught_exc:
msg = "Failed to load GDCH service account credentials from {}".format(filename)
new_exc = exceptions.DefaultCredentialsError(msg, caught_exc)
six.raise_from(new_exc, caught_exc)
return credentials, info.get("project")


def _apply_quota_project_id(credentials, quota_project_id):
if quota_project_id:
credentials = credentials.with_quota_project(quota_project_id)
Expand Down Expand Up @@ -456,6 +477,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
endpoint.
The project ID returned in this case is the one corresponding to the
underlying workload identity pool resource if determinable.
If the environment variable is set to the path of a valid GDCH service
account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH
credential will be returned. The project ID returned is the project
specified in the JSON file.
2. If the `Google Cloud SDK`_ is installed and has application default
credentials set they are loaded and returned.
Expand Down Expand Up @@ -490,6 +516,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
.. _Metadata Service: https://cloud.google.com/compute/docs\
/storing-retrieving-metadata
.. _Cloud Run: https://cloud.google.com/run
.. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\
/hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted
Example::
Expand Down
15 changes: 11 additions & 4 deletions google/auth/_service_account_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from google.auth import crypt


def from_dict(data, require=None):
def from_dict(data, require=None, use_rsa_signer=True):
"""Validates a dictionary containing Google service account data.
Creates and returns a :class:`google.auth.crypt.Signer` instance from the
Expand All @@ -32,6 +32,8 @@ def from_dict(data, require=None):
data (Mapping[str, str]): The service account data
require (Sequence[str]): List of keys required to be present in the
info.
use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer.
We use RSA signer by default.
Returns:
google.auth.crypt.Signer: A signer created from the private key in the
Expand All @@ -52,23 +54,28 @@ def from_dict(data, require=None):
)

# Create a signer.
signer = crypt.RSASigner.from_service_account_info(data)
if use_rsa_signer:
signer = crypt.RSASigner.from_service_account_info(data)
else:
signer = crypt.ES256Signer.from_service_account_info(data)

return signer


def from_filename(filename, require=None):
def from_filename(filename, require=None, use_rsa_signer=True):
"""Reads a Google service account JSON file and returns its parsed info.
Args:
filename (str): The path to the service account .json file.
require (Sequence[str]): List of keys required to be present in the
info.
use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer.
We use RSA signer by default.
Returns:
Tuple[ Mapping[str, str], google.auth.crypt.Signer ]: The verified
info and a signer instance.
"""
with io.open(filename, "r", encoding="utf-8") as json_file:
data = json.load(json_file)
return data, from_dict(data, require=require)
return data, from_dict(data, require=require, use_rsa_signer=use_rsa_signer)
58 changes: 41 additions & 17 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ def _handle_error_response(response_data):
"""Translates an error response into an exception.
Args:
response_data (Mapping): The decoded response data.
response_data (Mapping | str): The decoded response data.
Raises:
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
if isinstance(response_data, six.string_types):
raise exceptions.RefreshError(response_data)
try:
error_details = "{}: {}".format(
response_data["error"], response_data.get("error_description")
Expand Down Expand Up @@ -79,7 +81,7 @@ def _parse_expiry(response_data):


def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
request, token_uri, body, access_token=None, use_json=False, **kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Expand All @@ -93,6 +95,13 @@ def _token_endpoint_request_no_throw(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
kwargs: Additional arguments passed on to the request method. The
kwargs will be passed to `requests.request` method, see:
https://docs.python-requests.org/en/latest/api/#requests.request.
For example, you can use `cert=("cert_pem_path", "key_pem_path")`
to set up client side SSL certificate, and use
`verify="ca_bundle_path"` to set up the CA certificates for sever
side SSL certificate verification.
Returns:
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
Expand All @@ -112,32 +121,40 @@ def _token_endpoint_request_no_throw(
# retry to fetch token for maximum of two times if any internal failure
# occurs.
while True:
response = request(method="POST", url=token_uri, headers=headers, body=body)
response = request(
method="POST", url=token_uri, headers=headers, body=body, **kwargs
)
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
else response.data
)
response_data = json.loads(response_body)

if response.status == http_client.OK:
# response_body should be a JSON
response_data = json.loads(response_body)
break
else:
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data
# For a failed response, response_body could be a string
try:
response_data = json.loads(response_body)
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
except ValueError:
response_data = response_body
return False, response_data

return True, response_data


def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
request, token_uri, body, access_token=None, use_json=False, **kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
Expand All @@ -150,6 +167,13 @@ def _token_endpoint_request(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
kwargs: Additional arguments passed on to the request method. The
kwargs will be passed to `requests.request` method, see:
https://docs.python-requests.org/en/latest/api/#requests.request.
For example, you can use `cert=("cert_pem_path", "key_pem_path")`
to set up client side SSL certificate, and use
`verify="ca_bundle_path"` to set up the CA certificates for sever
side SSL certificate verification.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Expand All @@ -159,7 +183,7 @@ def _token_endpoint_request(
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
request, token_uri, body, access_token=access_token, use_json=use_json, **kwargs
)
if not response_status_ok:
_handle_error_response(response_data)
Expand Down
Loading

0 comments on commit 94fb5e2

Please sign in to comment.