diff --git a/.gitignore b/.gitignore index ca0c0742c..a5c5f8414 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,6 @@ pytype_output/ .python-version .DS_Store cert_path -key_path \ No newline at end of file +key_path +env/ +.vscode/ \ No newline at end of file diff --git a/google/auth/transport/_custom_tls_signer.py b/google/auth/transport/_custom_tls_signer.py new file mode 100644 index 000000000..22a510daa --- /dev/null +++ b/google/auth/transport/_custom_tls_signer.py @@ -0,0 +1,235 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Code for configuring client side TLS to offload the signing operation to +signing libraries. +""" + +import ctypes +import json +import logging +import os +import sys + +import cffi # type: ignore +import six + +from google.auth import exceptions + +_LOGGER = logging.getLogger(__name__) + +# C++ offload lib requires google-auth lib to provide the following callback: +# using SignFunc = int (*)(unsigned char *sig, size_t *sig_len, +# const unsigned char *tbs, size_t tbs_len) +# The bytes to be signed and the length are provided via `tbs` and `tbs_len`, +# the callback computes the signature, and write the signature and its length +# into `sig` and `sig_len`. +# If the signing is successful, the callback returns 1, otherwise it returns 0. +SIGN_CALLBACK_CTYPE = ctypes.CFUNCTYPE( + ctypes.c_int, # return type + ctypes.POINTER(ctypes.c_ubyte), # sig + ctypes.POINTER(ctypes.c_size_t), # sig_len + ctypes.POINTER(ctypes.c_ubyte), # tbs + ctypes.c_size_t, # tbs_len +) + + +# Cast SSL_CTX* to void* +def _cast_ssl_ctx_to_void_p(ssl_ctx): + return ctypes.cast(int(cffi.FFI().cast("intptr_t", ssl_ctx)), ctypes.c_void_p) + + +# Load offload library and set up the function types. +def load_offload_lib(offload_lib_path): + _LOGGER.debug("loading offload library from %s", offload_lib_path) + + # winmode parameter is only available for python 3.8+. + lib = ( + ctypes.CDLL(offload_lib_path, winmode=0) + if sys.version_info >= (3, 8) and os.name == "nt" + else ctypes.CDLL(offload_lib_path) + ) + + # Set up types for: + # int ConfigureSslContext(SignFunc sign_func, const char *cert, SSL_CTX *ctx) + lib.ConfigureSslContext.argtypes = [ + SIGN_CALLBACK_CTYPE, + ctypes.c_char_p, + ctypes.c_void_p, + ] + lib.ConfigureSslContext.restype = ctypes.c_int + + return lib + + +# Load signer library and set up the function types. +# See: https://github.com/googleapis/enterprise-certificate-proxy/blob/main/cshared/main.go +def load_signer_lib(signer_lib_path): + _LOGGER.debug("loading signer library from %s", signer_lib_path) + + # winmode parameter is only available for python 3.8+. + lib = ( + ctypes.CDLL(signer_lib_path, winmode=0) + if sys.version_info >= (3, 8) and os.name == "nt" + else ctypes.CDLL(signer_lib_path) + ) + + # Set up types for: + # func GetCertPemForPython(configFilePath *C.char, certHolder *byte, certHolderLen int) + lib.GetCertPemForPython.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int] + # Returns: certLen + lib.GetCertPemForPython.restype = ctypes.c_int + + # Set up types for: + # func SignForPython(configFilePath *C.char, digest *byte, digestLen int, + # sigHolder *byte, sigHolderLen int) + lib.SignForPython.argtypes = [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_int, + ] + # Returns: the signature length + lib.SignForPython.restype = ctypes.c_int + + return lib + + +# Computes SHA256 hash. +def _compute_sha256_digest(to_be_signed, to_be_signed_len): + from cryptography.hazmat.primitives import hashes + + data = ctypes.string_at(to_be_signed, to_be_signed_len) + hash = hashes.Hash(hashes.SHA256()) + hash.update(data) + return hash.finalize() + + +# Create the signing callback. The actual signing work is done by the +# `SignForPython` method from the signer lib. +def get_sign_callback(signer_lib, config_file_path): + def sign_callback(sig, sig_len, tbs, tbs_len): + _LOGGER.debug("calling sign callback...") + + digest = _compute_sha256_digest(tbs, tbs_len) + digestArray = ctypes.c_char * len(digest) + + # reserve 2000 bytes for the signature, shoud be more then enough. + # RSA signature is 256 bytes, EC signature is 70~72. + sig_holder_len = 2000 + sig_holder = ctypes.create_string_buffer(sig_holder_len) + + signature_len = signer_lib.SignForPython( + config_file_path.encode(), # configFilePath + digestArray.from_buffer(bytearray(digest)), # digest + len(digest), # digestLen + sig_holder, # sigHolder + sig_holder_len, # sigHolderLen + ) + + if signature_len == 0: + # signing failed, return 0 + return 0 + + sig_len[0] = signature_len + bs = bytearray(sig_holder) + for i in range(signature_len): + sig[i] = bs[i] + + return 1 + + return SIGN_CALLBACK_CTYPE(sign_callback) + + +# Obtain the certificate bytes by calling the `GetCertPemForPython` method from +# the signer lib. The method is called twice, the first time is to compute the +# cert length, then we create a buffer to hold the cert, and call it again to +# fill the buffer. +def get_cert(signer_lib, config_file_path): + # First call to calculate the cert length + cert_len = signer_lib.GetCertPemForPython( + config_file_path.encode(), # configFilePath + None, # certHolder + 0, # certHolderLen + ) + if cert_len == 0: + raise exceptions.MutualTLSChannelError("failed to get certificate") + + # Then we create an array to hold the cert, and call again to fill the cert + cert_holder = ctypes.create_string_buffer(cert_len) + signer_lib.GetCertPemForPython( + config_file_path.encode(), # configFilePath + cert_holder, # certHolder + cert_len, # certHolderLen + ) + return bytes(cert_holder) + + +class CustomTlsSigner(object): + def __init__(self, enterprise_cert_file_path): + """ + This class loads the offload and signer library, and calls APIs from + these libraries to obtain the cert and a signing callback, and attach + them to SSL context. The cert and the signing callback will be used + for client authentication in TLS handshake. + + Args: + enterprise_cert_file_path (str): the path to a enterprise cert JSON + file. The file should contain the following field: + + { + "libs": { + "signer_library": "...", + "offload_library": "..." + } + } + """ + self._enterprise_cert_file_path = enterprise_cert_file_path + self._cert = None + self._sign_callback = None + + def load_libraries(self): + try: + with open(self._enterprise_cert_file_path, "r") as f: + enterprise_cert_json = json.load(f) + libs = enterprise_cert_json["libs"] + signer_library = libs["signer_library"] + offload_library = libs["offload_library"] + except (KeyError, ValueError) as caught_exc: + new_exc = exceptions.MutualTLSChannelError( + "enterprise cert file is invalid", caught_exc + ) + six.raise_from(new_exc, caught_exc) + self._offload_lib = load_offload_lib(offload_library) + self._signer_lib = load_signer_lib(signer_library) + + def set_up_custom_key(self): + # We need to keep a reference of the cert and sign callback so it won't + # be garbage collected, otherwise it will crash when used by signer lib. + self._cert = get_cert(self._signer_lib, self._enterprise_cert_file_path) + self._sign_callback = get_sign_callback( + self._signer_lib, self._enterprise_cert_file_path + ) + + def attach_to_ssl_context(self, ctx): + # In the TLS handshake, the signing operation will be done by the + # sign_callback. + if not self._offload_lib.ConfigureSslContext( + self._sign_callback, + ctypes.c_char_p(self._cert), + _cast_ssl_ctx_to_void_p(ctx._ctx._context), + ): + raise exceptions.MutualTLSChannelError("failed to configure SSL context") diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index f4a7c65d7..2c746f8d7 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -245,6 +245,65 @@ def proxy_manager_for(self, *args, **kwargs): return super(_MutualTlsAdapter, self).proxy_manager_for(*args, **kwargs) +class _MutualTlsOffloadAdapter(requests.adapters.HTTPAdapter): + """ + A TransportAdapter that enables mutual TLS and offloads the client side + signing operation to the signing library. + + Args: + enterprise_cert_file_path (str): the path to a enterprise cert JSON + file. The file should contain the following field: + + { + "libs": { + "signer_library": "...", + "offload_library": "..." + } + } + + Raises: + ImportError: if certifi or pyOpenSSL is not installed + google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel + creation failed for any reason. + """ + + def __init__(self, enterprise_cert_file_path): + import certifi + import urllib3.contrib.pyopenssl + + from google.auth.transport import _custom_tls_signer + + # Call inject_into_urllib3 to activate certificate checking. See the + # following links for more info: + # (1) doc: https://github.com/urllib3/urllib3/blob/cb9ebf8aac5d75f64c8551820d760b72b619beff/src/urllib3/contrib/pyopenssl.py#L31-L32 + # (2) mTLS example: https://github.com/urllib3/urllib3/issues/474#issuecomment-253168415 + urllib3.contrib.pyopenssl.inject_into_urllib3() + + self.signer = _custom_tls_signer.CustomTlsSigner(enterprise_cert_file_path) + self.signer.load_libraries() + self.signer.set_up_custom_key() + + poolmanager = create_urllib3_context() + poolmanager.load_verify_locations(cafile=certifi.where()) + self.signer.attach_to_ssl_context(poolmanager) + self._ctx_poolmanager = poolmanager + + proxymanager = create_urllib3_context() + proxymanager.load_verify_locations(cafile=certifi.where()) + self.signer.attach_to_ssl_context(proxymanager) + self._ctx_proxymanager = proxymanager + + super(_MutualTlsOffloadAdapter, self).__init__() + + def init_poolmanager(self, *args, **kwargs): + kwargs["ssl_context"] = self._ctx_poolmanager + super(_MutualTlsOffloadAdapter, self).init_poolmanager(*args, **kwargs) + + def proxy_manager_for(self, *args, **kwargs): + kwargs["ssl_context"] = self._ctx_proxymanager + return super(_MutualTlsOffloadAdapter, self).proxy_manager_for(*args, **kwargs) + + class AuthorizedSession(requests.Session): """A Requests Session class with credentials. diff --git a/noxfile.py b/noxfile.py index 937d35d69..7221315b0 100644 --- a/noxfile.py +++ b/noxfile.py @@ -118,6 +118,7 @@ def unit_prev_versions(session): "--cov=google.oauth2", "--cov=tests", "tests", + "--ignore=tests/transport/test__custom_tls_signer.py", # enterprise cert is for python 3.6+ ) diff --git a/setup.py b/setup.py index 22f627b99..1fce9743e 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,9 @@ ], "pyopenssl": "pyopenssl>=20.0.0", "reauth": "pyu2f>=0.1.5", + # Enterprise cert only works for OpenSSL 1.1.1. Newer versions of these + # dependencies are built with OpenSSL 3.0 so we need to fix the version. + "enterprise_cert": ["cryptography==36.0.2", "pyopenssl==22.0.0"], } with io.open("README.rst", "r") as fh: diff --git a/tests/data/enterprise_cert_invalid.json b/tests/data/enterprise_cert_invalid.json new file mode 100644 index 000000000..4715a590a --- /dev/null +++ b/tests/data/enterprise_cert_invalid.json @@ -0,0 +1,3 @@ +{ + "libs": {} +} \ No newline at end of file diff --git a/tests/data/enterprise_cert_valid.json b/tests/data/enterprise_cert_valid.json new file mode 100644 index 000000000..de4b2d009 --- /dev/null +++ b/tests/data/enterprise_cert_valid.json @@ -0,0 +1,6 @@ +{ + "libs": { + "signer_library": "/path/to/signer/lib", + "offload_library": "/path/to/offload/lib" + } +} \ No newline at end of file diff --git a/tests/transport/test__custom_tls_signer.py b/tests/transport/test__custom_tls_signer.py new file mode 100644 index 000000000..5836b325a --- /dev/null +++ b/tests/transport/test__custom_tls_signer.py @@ -0,0 +1,234 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import ctypes +import os + +import mock +import pytest # type: ignore +from requests.packages.urllib3.util.ssl_ import create_urllib3_context # type: ignore +import urllib3.contrib.pyopenssl # type: ignore + +from google.auth import exceptions +from google.auth.transport import _custom_tls_signer + +urllib3.contrib.pyopenssl.inject_into_urllib3() + +FAKE_ENTERPRISE_CERT_FILE_PATH = "/path/to/enterprise/cert/file" +ENTERPRISE_CERT_FILE = os.path.join( + os.path.dirname(__file__), "../data/enterprise_cert_valid.json" +) +INVALID_ENTERPRISE_CERT_FILE = os.path.join( + os.path.dirname(__file__), "../data/enterprise_cert_invalid.json" +) + + +def test_load_offload_lib(): + with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()): + lib = _custom_tls_signer.load_offload_lib("/path/to/offload/lib") + + assert lib.ConfigureSslContext.argtypes == [ + _custom_tls_signer.SIGN_CALLBACK_CTYPE, + ctypes.c_char_p, + ctypes.c_void_p, + ] + assert lib.ConfigureSslContext.restype == ctypes.c_int + + +def test_load_signer_lib(): + with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()): + lib = _custom_tls_signer.load_signer_lib("/path/to/signer/lib") + + assert lib.SignForPython.restype == ctypes.c_int + assert lib.SignForPython.argtypes == [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_int, + ] + + assert lib.GetCertPemForPython.restype == ctypes.c_int + assert lib.GetCertPemForPython.argtypes == [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_int, + ] + + +def test__compute_sha256_digest(): + to_be_signed = ctypes.create_string_buffer(b"foo") + sig = _custom_tls_signer._compute_sha256_digest(to_be_signed, 4) + + assert ( + base64.b64encode(sig).decode() == "RG5gyEH8CAAh3lxgbt2PLPAHPO8p6i9+cn5dqHfUUYM=" + ) + + +def test_get_sign_callback(): + # mock signer lib's SignForPython function + mock_sig_len = 10 + mock_signer_lib = mock.MagicMock() + mock_signer_lib.SignForPython.return_value = mock_sig_len + + # create a sign callback. The callback calls signer lib's SignForPython method + sign_callback = _custom_tls_signer.get_sign_callback( + mock_signer_lib, FAKE_ENTERPRISE_CERT_FILE_PATH + ) + + # mock the parameters used to call the sign callback + to_be_signed = ctypes.POINTER(ctypes.c_ubyte)() + to_be_signed_len = 4 + returned_sig_array = ctypes.c_ubyte() + mock_sig_array = ctypes.byref(returned_sig_array) + returned_sign_len = ctypes.c_ulong() + mock_sig_len_array = ctypes.byref(returned_sign_len) + + # call the callback, make sure the signature len is returned via mock_sig_len_array[0] + assert sign_callback( + mock_sig_array, mock_sig_len_array, to_be_signed, to_be_signed_len + ) + assert returned_sign_len.value == mock_sig_len + + +def test_get_sign_callback_failed_to_sign(): + # mock signer lib's SignForPython function. Set the sig len to be 0 to + # indicate the signing failed. + mock_sig_len = 0 + mock_signer_lib = mock.MagicMock() + mock_signer_lib.SignForPython.return_value = mock_sig_len + + # create a sign callback. The callback calls signer lib's SignForPython method + sign_callback = _custom_tls_signer.get_sign_callback( + mock_signer_lib, FAKE_ENTERPRISE_CERT_FILE_PATH + ) + + # mock the parameters used to call the sign callback + to_be_signed = ctypes.POINTER(ctypes.c_ubyte)() + to_be_signed_len = 4 + returned_sig_array = ctypes.c_ubyte() + mock_sig_array = ctypes.byref(returned_sig_array) + returned_sign_len = ctypes.c_ulong() + mock_sig_len_array = ctypes.byref(returned_sign_len) + sign_callback(mock_sig_array, mock_sig_len_array, to_be_signed, to_be_signed_len) + + # sign callback should return 0 + assert not sign_callback( + mock_sig_array, mock_sig_len_array, to_be_signed, to_be_signed_len + ) + + +def test_get_cert_no_cert(): + # mock signer lib's GetCertPemForPython function to return 0 to indicts + # the cert doesn't exit (cert len = 0) + mock_signer_lib = mock.MagicMock() + mock_signer_lib.GetCertPemForPython.return_value = 0 + + # call the get cert method + with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: + _custom_tls_signer.get_cert(mock_signer_lib, FAKE_ENTERPRISE_CERT_FILE_PATH) + + assert excinfo.match("failed to get certificate") + + +def test_get_cert(): + # mock signer lib's GetCertPemForPython function + mock_cert_len = 10 + mock_signer_lib = mock.MagicMock() + mock_signer_lib.GetCertPemForPython.return_value = mock_cert_len + + # call the get cert method + mock_cert = _custom_tls_signer.get_cert( + mock_signer_lib, FAKE_ENTERPRISE_CERT_FILE_PATH + ) + + # make sure the signer lib's GetCertPemForPython is called twice, and the + # mock_cert has length mock_cert_len + assert mock_signer_lib.GetCertPemForPython.call_count == 2 + assert len(mock_cert) == mock_cert_len + + +def test_custom_tls_signer(): + offload_lib = mock.MagicMock() + signer_lib = mock.MagicMock() + + # Test load_libraries method + with mock.patch( + "google.auth.transport._custom_tls_signer.load_signer_lib" + ) as load_signer_lib: + with mock.patch( + "google.auth.transport._custom_tls_signer.load_offload_lib" + ) as load_offload_lib: + load_offload_lib.return_value = offload_lib + load_signer_lib.return_value = signer_lib + signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE) + signer_object.load_libraries() + assert signer_object._cert is None + assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE + assert signer_object._offload_lib == offload_lib + assert signer_object._signer_lib == signer_lib + load_signer_lib.assert_called_with("/path/to/signer/lib") + load_offload_lib.assert_called_with("/path/to/offload/lib") + + # Test set_up_custom_key and set_up_ssl_context methods + with mock.patch("google.auth.transport._custom_tls_signer.get_cert") as get_cert: + with mock.patch( + "google.auth.transport._custom_tls_signer.get_sign_callback" + ) as get_sign_callback: + get_cert.return_value = b"mock_cert" + signer_object.set_up_custom_key() + signer_object.attach_to_ssl_context(create_urllib3_context()) + get_cert.assert_called_once() + get_sign_callback.assert_called_once() + offload_lib.ConfigureSslContext.assert_called_once() + + +def test_custom_tls_signer_failed_to_load_libraries(): + # Test load_libraries method + with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: + signer_object = _custom_tls_signer.CustomTlsSigner(INVALID_ENTERPRISE_CERT_FILE) + signer_object.load_libraries() + assert excinfo.match("enterprise cert file is invalid") + + +def test_custom_tls_signer_fail_to_offload(): + offload_lib = mock.MagicMock() + signer_lib = mock.MagicMock() + + with mock.patch( + "google.auth.transport._custom_tls_signer.load_signer_lib" + ) as load_signer_lib: + with mock.patch( + "google.auth.transport._custom_tls_signer.load_offload_lib" + ) as load_offload_lib: + load_offload_lib.return_value = offload_lib + load_signer_lib.return_value = signer_lib + signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE) + signer_object.load_libraries() + + # set the return value to be 0 which indicts offload fails + offload_lib.ConfigureSslContext.return_value = 0 + + with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: + with mock.patch( + "google.auth.transport._custom_tls_signer.get_cert" + ) as get_cert: + with mock.patch( + "google.auth.transport._custom_tls_signer.get_sign_callback" + ): + get_cert.return_value = b"mock_cert" + signer_object.set_up_custom_key() + signer_object.attach_to_ssl_context(create_urllib3_context()) + assert excinfo.match("failed to configure SSL context") diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index 7899419be..9532117b5 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -28,6 +28,7 @@ from google.auth import environment_vars from google.auth import exceptions import google.auth.credentials +import google.auth.transport._custom_tls_signer import google.auth.transport._mtls_helper import google.auth.transport.requests from google.oauth2 import service_account @@ -535,3 +536,40 @@ def test_close_w_passed_in_auth_request(self): ) authed_session.close() # no raise + + +class TestMutualTlsOffloadAdapter(object): + @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager") + @mock.patch.object(requests.adapters.HTTPAdapter, "proxy_manager_for") + @mock.patch.object( + google.auth.transport._custom_tls_signer.CustomTlsSigner, "load_libraries" + ) + @mock.patch.object( + google.auth.transport._custom_tls_signer.CustomTlsSigner, "set_up_custom_key" + ) + @mock.patch.object( + google.auth.transport._custom_tls_signer.CustomTlsSigner, + "attach_to_ssl_context", + ) + def test_success( + self, + mock_attach_to_ssl_context, + mock_set_up_custom_key, + mock_load_libraries, + mock_proxy_manager_for, + mock_init_poolmanager, + ): + enterprise_cert_file_path = "/path/to/enterprise/cert/json" + adapter = google.auth.transport.requests._MutualTlsOffloadAdapter( + enterprise_cert_file_path + ) + + mock_load_libraries.assert_called_once() + mock_set_up_custom_key.assert_called_once() + assert mock_attach_to_ssl_context.call_count == 2 + + adapter.init_poolmanager() + mock_init_poolmanager.assert_called_with(ssl_context=adapter._ctx_poolmanager) + + adapter.proxy_manager_for() + mock_proxy_manager_for.assert_called_with(ssl_context=adapter._ctx_proxymanager)