diff --git a/dashboard/agent.py b/dashboard/agent.py index 2505d06a78a60..9b199f20a93a7 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -83,14 +83,14 @@ def __init__(self, assert self.ppid > 0 logger.info("Parent pid is %s", self.ppid) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = self.server.add_insecure_port( - f"[::]:{self.dashboard_agent_port}") + self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( + self.server, f"[::]:{self.dashboard_agent_port}") logger.info("Dashboard agent grpc address: %s:%s", self.ip, self.grpc_port) self.aioredis_client = None options = (("grpc.enable_http_proxy", 0), ) - self.aiogrpc_raylet_channel = aiogrpc.insecure_channel( - f"{self.ip}:{self.node_manager_port}", options=options) + self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel( + f"{self.ip}:{self.node_manager_port}", options, asynchronous=True) self.http_session = None ip, port = redis_address.split(":") self.gcs_client = connect_to_gcs(ip, int(port), redis_password) diff --git a/dashboard/head.py b/dashboard/head.py index 467d6ca6a76f8..fc7b12d3a6c27 100644 --- a/dashboard/head.py +++ b/dashboard/head.py @@ -12,6 +12,7 @@ from grpc.experimental import aio as aiogrpc import grpc +import ray._private.utils import ray._private.services import ray.dashboard.consts as dashboard_consts import ray.dashboard.utils as dashboard_utils @@ -56,7 +57,7 @@ async def get_gcs_address_with_retry(redis_client) -> str: class GCSHealthCheckThread(threading.Thread): def __init__(self, gcs_address: str): - self.grpc_gcs_channel = grpc.insecure_channel( + self.grpc_gcs_channel = ray._private.utils.init_grpc_channel( gcs_address, options=GRPC_CHANNEL_OPTIONS) self.gcs_heartbeat_info_stub = ( gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub( @@ -116,7 +117,8 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address, ip, port = redis_address.split(":") self.gcs_client = connect_to_gcs(ip, int(port), 
redis_password) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - self.grpc_port = self.server.add_insecure_port("[::]:0") + self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( + self.server, "[::]:0") logger.info("Dashboard head grpc address: %s:%s", self.ip, self.grpc_port) @@ -188,8 +190,8 @@ async def run(self): # Waiting for GCS is ready. gcs_address = await get_gcs_address_with_retry(self.aioredis_client) - self.aiogrpc_gcs_channel = aiogrpc.insecure_channel( - gcs_address, options=GRPC_CHANNEL_OPTIONS) + self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel( + gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True) self.health_check_thread = GCSHealthCheckThread(gcs_address) self.health_check_thread.start() diff --git a/dashboard/modules/actor/actor_head.py b/dashboard/modules/actor/actor_head.py index b5d6757f4f07b..c05f61ea55ace 100644 --- a/dashboard/modules/actor/actor_head.py +++ b/dashboard/modules/actor/actor_head.py @@ -51,7 +51,8 @@ async def _update_stubs(self, change): address = "{}:{}".format(node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])) options = (("grpc.enable_http_proxy", 0), ) - channel = aiogrpc.insecure_channel(address, options=options) + channel = ray._private.utils.init_grpc_channel( + address, options, asynchronous=True) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) self._stubs[node_id] = stub @@ -180,8 +181,8 @@ async def kill_actor(self, req) -> aiohttp.web.Response: return rest_response(success=False, message="Bad Request") try: options = (("grpc.enable_http_proxy", 0), ) - channel = aiogrpc.insecure_channel( - f"{ip_address}:{port}", options=options) + channel = ray._private.utils.init_grpc_channel( + f"{ip_address}:{port}", options=options, asynchronous=True) stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel) await stub.KillActor( diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py index 
1d99eafbe27c2..2740c19d70549 100644 --- a/dashboard/modules/event/event_agent.py +++ b/dashboard/modules/event/event_agent.py @@ -2,8 +2,8 @@ import asyncio import logging from typing import Union -from grpc.experimental import aio as aiogrpc +import ray._private.utils as utils import ray.dashboard.utils as dashboard_utils import ray.dashboard.consts as dashboard_consts from ray.ray_constants import env_bool @@ -46,8 +46,10 @@ async def _connect_to_dashboard(self): if dashboard_rpc_address: logger.info("Report events to %s", dashboard_rpc_address) options = (("grpc.enable_http_proxy", 0), ) - channel = aiogrpc.insecure_channel( - dashboard_rpc_address, options=options) + channel = utils.init_grpc_channel( + dashboard_rpc_address, + options=options, + asynchronous=True) return event_pb2_grpc.ReportEventServiceStub(channel) except Exception: logger.exception("Connect to dashboard failed.") diff --git a/dashboard/modules/job/job_head.py b/dashboard/modules/job/job_head.py index c573d5ec46357..a4f396431a1e4 100644 --- a/dashboard/modules/job/job_head.py +++ b/dashboard/modules/job/job_head.py @@ -4,8 +4,8 @@ import aiohttp.web from aioredis.pubsub import Receiver -from grpc.experimental import aio as aiogrpc +import ray._private.utils import ray._private.gcs_utils as gcs_utils import ray.dashboard.utils as dashboard_utils from ray.dashboard.modules.job import job_consts @@ -52,7 +52,9 @@ async def submit_job(self, req) -> aiohttp.web.Response: ip = DataSource.node_id_to_ip[node_id] address = f"{ip}:{ports[1]}" options = (("grpc.enable_http_proxy", 0), ) - channel = aiogrpc.insecure_channel(address, options=options) + channel = ray._private.utils.init_grpc_channel( + address, options, asynchronous=True) + stub = job_agent_pb2_grpc.JobAgentServiceStub(channel) request = job_agent_pb2.InitializeJobEnvRequest( job_description=json.dumps(job_description_data)) diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py index 
ce7fd2363ca58..2f532c97ab4bf 100644 --- a/dashboard/modules/node/node_head.py +++ b/dashboard/modules/node/node_head.py @@ -4,7 +4,6 @@ import json import aiohttp.web from aioredis.pubsub import Receiver -from grpc.experimental import aio as aiogrpc import ray._private.utils import ray._private.gcs_utils as gcs_utils @@ -68,7 +67,8 @@ async def _update_stubs(self, change): address = "{}:{}".format(node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])) options = (("grpc.enable_http_proxy", 0), ) - channel = aiogrpc.insecure_channel(address, options=options) + channel = ray._private.utils.init_grpc_channel( + address, options, asynchronous=True) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) self._stubs[node_id] = stub diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py index 8200b7df321f8..beebb29cbfdf9 100644 --- a/dashboard/modules/reporter/reporter_head.py +++ b/dashboard/modules/reporter/reporter_head.py @@ -4,7 +4,6 @@ import os import aiohttp.web from aioredis.pubsub import Receiver -from grpc.experimental import aio as aiogrpc import ray import ray.dashboard.modules.reporter.reporter_consts as reporter_consts @@ -38,8 +37,8 @@ async def _update_stubs(self, change): node_id, ports = change.new ip = DataSource.node_id_to_ip[node_id] options = (("grpc.enable_http_proxy", 0), ) - channel = aiogrpc.insecure_channel( - f"{ip}:{ports[1]}", options=options) + channel = ray._private.utils.init_grpc_channel( + f"{ip}:{ports[1]}", options=options, asynchronous=True) stub = reporter_pb2_grpc.ReporterServiceStub(channel) self._stubs[ip] = stub diff --git a/dashboard/utils.py b/dashboard/utils.py index d957563937838..c6175d9fa58d8 100644 --- a/dashboard/utils.py +++ b/dashboard/utils.py @@ -17,7 +17,6 @@ from collections import namedtuple from collections.abc import MutableMapping, Mapping, Sequence from typing import Any - from google.protobuf.json_format import MessageToDict import 
ray.dashboard.consts as dashboard_consts diff --git a/doc/source/configure.rst b/doc/source/configure.rst index 5e93b2c6e4f82..186255d855373 100644 --- a/doc/source/configure.rst +++ b/doc/source/configure.rst @@ -234,6 +234,28 @@ to localhost when the ray is started using ``ray.init``. See the `Redis security documentation `__ for more information. +TLS Authentication +------------------ + +Ray can be configured to use TLS on its gRPC channels. +This means that connecting to the Ray client on the head node will +require an appropriate set of credentials and also that data exchanged between +various processes (client, head, workers) will be encrypted. + +Enabling TLS will cause a performance hit due to the extra overhead of mutual +authentication and encryption. +Testing has shown that this overhead is large for small workloads and becomes +relatively smaller for large workloads. +The exact overhead will depend on the nature of your workload. + +TLS is enabled by setting environment variables. + +- ``RAY_USE_TLS``: Either 1 or 0 to use/not-use TLS. If this is set to 1 then all of the environment variables below must be set. Default: 0. +- ``RAY_TLS_SERVER_CERT``: Location of a `certificate file` which is presented to other endpoints so as to achieve mutual authentication. +- ``RAY_TLS_SERVER_KEY``: Location of a `private key file` which is the cryptographic means to prove to other endpoints that you are the authorized user of a given certificate. +- ``RAY_TLS_CA_CERT``: Location of a `CA certificate file` which allows TLS to decide whether an endpoint's certificate has been signed by the correct authority. 
+ + Java Applications ----------------- diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 2c410286d1d5c..e4fb68dd9ec67 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -15,11 +15,14 @@ from typing import Optional, Any, List, Dict from contextlib import redirect_stdout, redirect_stderr import yaml +import pytest +import tempfile import ray import ray._private.services import ray._private.utils import ray._private.gcs_utils as gcs_utils +from ray._private.tls_utils import generate_self_signed_tls_certs from ray.util.queue import Queue, _QueueActor, Empty from ray.scripts.scripts import main as ray_main try: @@ -691,3 +694,34 @@ def is_placement_group_removed(pg): if "state" not in table: return False return table["state"] == "REMOVED" + + +def setup_tls(): + """Sets up required environment variables for tls""" + if sys.platform == "darwin": + pytest.skip("Cryptography doesn't install in Mac build pipeline") + cert, key = generate_self_signed_tls_certs() + temp_dir = tempfile.mkdtemp("ray-test-certs") + cert_filepath = os.path.join(temp_dir, "server.crt") + key_filepath = os.path.join(temp_dir, "server.key") + with open(cert_filepath, "w") as fh: + fh.write(cert) + with open(key_filepath, "w") as fh: + fh.write(key) + + os.environ["RAY_USE_TLS"] = "1" + os.environ["RAY_TLS_SERVER_CERT"] = cert_filepath + os.environ["RAY_TLS_SERVER_KEY"] = key_filepath + os.environ["RAY_TLS_CA_CERT"] = cert_filepath + + return key_filepath, cert_filepath, temp_dir + + +def teardown_tls(key_filepath, cert_filepath, temp_dir): + os.remove(key_filepath) + os.remove(cert_filepath) + os.removedirs(temp_dir) + del os.environ["RAY_USE_TLS"] + del os.environ["RAY_TLS_SERVER_CERT"] + del os.environ["RAY_TLS_SERVER_KEY"] + del os.environ["RAY_TLS_CA_CERT"] diff --git a/python/ray/_private/tls_utils.py b/python/ray/_private/tls_utils.py new file mode 100644 index 0000000000000..0e4746201bc73 --- /dev/null +++ 
b/python/ray/_private/tls_utils.py @@ -0,0 +1,90 @@ +import datetime +import os +import socket + +import grpc + + +def generate_self_signed_tls_certs(): + """Create self-signed key/cert pair for testing. + + This method requires the library ``cryptography`` be installed. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + raise ImportError( + "Using `Security.temporary` requires `cryptography`, please " + "install it using either pip or conda") + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend()) + key_contents = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + ray_interal = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) + # This is the same logic used by the GCS server to acquire a + # private/internal IP address to listen on. 
If we just use localhost + + # 127.0.0.1 then we won't be able to connect to the GCS and will get + # an error like "No match found for server name: 192.168.X.Y" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + private_ip_address = s.getsockname()[0] + s.close() + altnames = x509.SubjectAlternativeName([ + x509.DNSName(socket.gethostbyname( + socket.gethostname())), # Probably 127.0.0.1 + x509.DNSName("127.0.0.1"), + x509.DNSName(private_ip_address), # 192.168.*.* + x509.DNSName("localhost"), + ]) + now = datetime.datetime.utcnow() + cert = (x509.CertificateBuilder().subject_name(ray_interal).issuer_name( + ray_interal).add_extension(altnames, critical=False).public_key( + key.public_key()).serial_number( + x509.random_serial_number()).not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)).sign( + key, hashes.SHA256(), default_backend())) + + cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() + + return cert_contents, key_contents + + +def add_port_to_grpc_server(server, address): + if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + server_cert_chain, private_key, ca_cert = load_certs_from_env() + credentials = grpc.ssl_server_credentials( + [(private_key, server_cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None) + return server.add_secure_port(address, credentials) + else: + return server.add_insecure_port(address) + + +def load_certs_from_env(): + tls_env_vars = [ + "RAY_TLS_SERVER_CERT", "RAY_TLS_SERVER_KEY", "RAY_TLS_CA_CERT" + ] + if any(v not in os.environ for v in tls_env_vars): + raise RuntimeError( + "If the environment variable RAY_USE_TLS is set to true " + "then RAY_TLS_SERVER_CERT, RAY_TLS_SERVER_KEY and " + "RAY_TLS_CA_CERT must also be set.") + + with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f: + server_cert_chain = f.read() + with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f: + private_key = f.read() + with 
open(os.environ["RAY_TLS_CA_CERT"], "rb") as f: + ca_cert = f.read() + + return server_cert_chain, private_key, ca_cert diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index e03784dc522a5..b67734bc9e0fd 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -13,9 +13,11 @@ import tempfile import threading import time -from typing import Optional +from typing import Optional, Sequence, Tuple, Any import uuid +import grpc import warnings +from grpc.experimental import aio as aiogrpc import inspect from inspect import signature @@ -25,6 +27,7 @@ import ray import ray._private.gcs_utils as gcs_utils import ray.ray_constants as ray_constants +from ray._private.tls_utils import load_certs_from_env # Import psutil after ray so the packaged version is used. import psutil @@ -1109,6 +1112,24 @@ def validate_namespace(namespace: str): "Pass None to not specify a namespace.") +def init_grpc_channel(address: str, + options: Optional[Sequence[Tuple[str, Any]]] = None, + asynchronous: bool = False): + grpc_module = aiogrpc if asynchronous else grpc + if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + server_cert_chain, private_key, ca_cert = load_certs_from_env() + credentials = grpc.ssl_channel_credentials( + certificate_chain=server_cert_chain, + private_key=private_key, + root_certificates=ca_cert) + channel = grpc_module.secure_channel( + address, credentials, options=options) + else: + channel = grpc_module.insecure_channel(address, options=options) + + return channel + + def check_dashboard_dependencies_installed() -> bool: """Returns True if Ray Dashboard dependencies are installed. 
diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index c291558968aa2..b19a8d04c7032 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -12,8 +12,6 @@ from multiprocessing.synchronize import Event from typing import Optional -import grpc - try: import prometheus_client except ImportError: @@ -40,6 +38,7 @@ from ray.experimental.internal_kv import _internal_kv_put, \ _internal_kv_initialized, _internal_kv_get, _internal_kv_del from ray._raylet import connect_to_gcs, disconnect_from_gcs +import ray._private.utils logger = logging.getLogger(__name__) @@ -151,9 +150,9 @@ def __init__(self, self.gcs_client = connect_to_gcs(ip, int(port), redis_password) # Initialize the gcs stub for getting all node resource usage. gcs_address = self.redis.get("GcsServerAddress").decode("utf-8") - options = (("grpc.enable_http_proxy", 0), ) - gcs_channel = grpc.insecure_channel(gcs_address, options=options) + gcs_channel = ray._private.utils.init_grpc_channel( + gcs_address, options) self.gcs_node_resources_stub = \ gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(gcs_channel) diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index 304788eebb052..7df4016e1a982 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -2,6 +2,7 @@ import ray._private.services as services import ray.worker import ray._private.profiling as profiling +import ray._private.utils as utils from ray import ray_constants from ray.state import GlobalState @@ -41,7 +42,6 @@ def memory_summary(address=None, def get_store_stats(state, node_manager_address=None, node_manager_port=None): """Returns a formatted string describing memory usage in the cluster.""" - import grpc from ray.core.generated import node_manager_pb2 from ray.core.generated import node_manager_pb2_grpc @@ -60,13 +60,15 @@ def get_store_stats(state, 
node_manager_address=None, node_manager_port=None): else: raylet_address = "{}:{}".format(node_manager_address, node_manager_port) - channel = grpc.insecure_channel( + + channel = utils.init_grpc_channel( raylet_address, options=[ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), ], ) + stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) reply = stub.FormatGlobalMemoryInfo( node_manager_pb2.FormatGlobalMemoryInfoRequest( @@ -80,20 +82,20 @@ def node_stats(node_manager_address=None, include_memory_info=True): """Returns NodeStats object describing memory usage in the cluster.""" - import grpc from ray.core.generated import node_manager_pb2 from ray.core.generated import node_manager_pb2_grpc # We can ask any Raylet for the global memory info. assert (node_manager_address is not None and node_manager_port is not None) raylet_address = "{}:{}".format(node_manager_address, node_manager_port) - channel = grpc.insecure_channel( + channel = utils.init_grpc_channel( raylet_address, options=[ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), ], ) + stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) node_stats = stub.GetNodeStats( node_manager_pb2.GetNodeStatsRequest( diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index a575621e688e1..dec530cb3022b 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -16,7 +16,6 @@ import ray import psutil -import grpc import ray._private.services as services import ray.ray_constants as ray_constants import ray._private.utils @@ -1775,7 +1774,8 @@ def healthcheck(address, redis_password, component): try: gcs_address = redis_client.get("GcsServerAddress").decode("utf-8") options = (("grpc.enable_http_proxy", 0), ) - channel = grpc.insecure_channel(gcs_address, options=options) + channel = ray._private.utils.init_grpc_channel( + gcs_address, 
options) stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(channel) request = gcs_service_pb2.CheckAliveRequest() reply = stub.CheckAlive( diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index f9e0732df43b9..e7a62d4b4db50 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -93,6 +93,7 @@ py_test_module_list( "test_tempfile.py", "test_tensorflow.py", "test_threaded_actor.py", + "test_tls_auth.py", "test_ray_debugger.py", ], size = "medium", diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index a30e1d29b6667..62a2a81edbb88 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -6,11 +6,13 @@ import pytest import subprocess import json +import time import ray from ray.cluster_utils import Cluster from ray._private.services import REDIS_EXECUTABLE, _start_redis_instance -from ray._private.test_utils import init_error_pubsub +from ray._private.test_utils import init_error_pubsub, setup_tls, teardown_tls +import ray.util.client.server.server as ray_client_server import ray._private.gcs_utils as gcs_utils @@ -230,6 +232,14 @@ def call_ray_start_with_external_redis(request): subprocess.check_call(["ray", "stop"]) +@pytest.fixture +def init_and_serve(): + server_handle, _ = ray_client_server.init_and_serve("localhost:50051") + yield server_handle + ray_client_server.shutdown_with_server(server_handle.grpc_server) + time.sleep(2) + + @pytest.fixture def call_ray_stop_only(): yield @@ -287,6 +297,15 @@ def log_pubsub(): p.close() +@pytest.fixture +def use_tls(request): + if request.param: + key_filepath, cert_filepath, temp_dir = setup_tls() + yield request.param + if request.param: + teardown_tls(key_filepath, cert_filepath, temp_dir) + + """ Object spilling test fixture """ diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py index 74c4cca200fea..a474d88ebe724 100644 --- a/python/ray/tests/test_client_init.py +++ b/python/ray/tests/test_client_init.py @@ 
-40,14 +40,6 @@ def get(self): return self.val -@pytest.fixture -def init_and_serve(): - server_handle, _ = ray_client_server.init_and_serve("localhost:50051") - yield server_handle - ray_client_server.shutdown_with_server(server_handle.grpc_server) - time.sleep(2) - - @pytest.fixture def init_and_serve_lazy(): cluster = ray.cluster_utils.Cluster() diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py index a0813828e0797..fc83c9b32bce9 100644 --- a/python/ray/tests/test_metrics.py +++ b/python/ray/tests/test_metrics.py @@ -9,6 +9,7 @@ from ray.core.generated import node_manager_pb2_grpc from ray._private.test_utils import (RayTestTimeoutException, wait_until_succeeded_without_exception) +from ray._private.utils import init_grpc_channel import psutil # We must import psutil after ray because we bundle it with ray. @@ -20,7 +21,7 @@ def test_worker_stats(shutdown_only): raylet_address = "{}:{}".format(raylet["NodeManagerAddress"], ray.nodes()[0]["NodeManagerPort"]) - channel = grpc.insecure_channel(raylet_address) + channel = init_grpc_channel(raylet_address) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) def try_get_node_stats(num_retry=5, timeout=2): diff --git a/python/ray/tests/test_multi_tenancy.py b/python/ray/tests/test_multi_tenancy.py index c9b16b3db6518..f2913a50c05ba 100644 --- a/python/ray/tests/test_multi_tenancy.py +++ b/python/ray/tests/test_multi_tenancy.py @@ -3,7 +3,6 @@ import sys import time -import grpc import pytest import numpy as np @@ -12,13 +11,14 @@ from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc from ray._private.test_utils import (wait_for_condition, run_string_as_driver, run_string_as_driver_nonblocking) +from ray._private.utils import init_grpc_channel def get_workers(): raylet = ray.nodes()[0] raylet_address = "{}:{}".format(raylet["NodeManagerAddress"], raylet["NodeManagerPort"]) - channel = grpc.insecure_channel(raylet_address) + channel = 
init_grpc_channel(raylet_address) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) return [ worker for worker in stub.GetNodeStats( diff --git a/python/ray/tests/test_tls_auth.py b/python/ray/tests/test_tls_auth.py new file mode 100644 index 0000000000000..9b3d418c70d19 --- /dev/null +++ b/python/ray/tests/test_tls_auth.py @@ -0,0 +1,153 @@ +# coding: utf-8 +import logging +import os +import sys +import subprocess + +import pytest + +from ray._private.test_utils import run_string_as_driver + +logger = logging.getLogger(__name__) + + +def build_env(): + env = os.environ.copy() + if sys.platform == "win32" and "SYSTEMROOT" not in env: + env["SYSTEMROOT"] = r"C:\Windows" + + return env + + +@pytest.mark.skipif( + sys.platform == "darwin", + reason=( + "Cryptography (TLS dependency) doesn't install in Mac build pipeline")) +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_init_with_tls(use_tls): + # Run as a new process to pick up environment variables set + # in the use_tls fixture + run_string_as_driver( + """ +import ray +try: + ray.init() +finally: + ray.shutdown() + """, + env=build_env()) + + +@pytest.mark.skipif( + sys.platform == "darwin", + reason=( + "Cryptography (TLS dependency) doesn't install in Mac build pipeline")) +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_put_get_with_tls(use_tls): + run_string_as_driver( + """ +import ray +ray.init() +try: + for i in range(100): + value_before = i * 10**6 + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = i * 10**6 * 1.0 + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = "h" * i + object_ref = ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after + + for i in range(100): + value_before = [1] * i + object_ref = 
ray.put(value_before) + value_after = ray.get(object_ref) + assert value_before == value_after +finally: + ray.shutdown() + """, + env=build_env()) + + +@pytest.mark.skipif( + sys.platform == "darwin", + reason=( + "Cryptography (TLS dependency) doesn't install in Mac build pipeline")) +@pytest.mark.parametrize("use_tls", [True], indirect=True, scope="module") +def test_submit_with_tls(use_tls): + run_string_as_driver( + """ +import ray +ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) + +@ray.remote +def f(n): + return list(range(n)) + +id1, id2, id3 = f._remote(args=[3], num_returns=3) +assert ray.get([id1, id2, id3]) == [0, 1, 2] + +@ray.remote +class Actor: + def __init__(self, x, y=0): + self.x = x + self.y = y + + def method(self, a, b=0): + return self.x, self.y, a, b + +a = Actor._remote( + args=[0], kwargs={"y": 1}, num_gpus=1, resources={"Custom": 1}) + +id1, id2, id3, id4 = a.method._remote( + args=["test"], kwargs={"b": 2}, num_returns=4) +assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] + """, + env=build_env()) + + +@pytest.mark.skipif( + sys.platform == "darwin", + reason=( + "Cryptography (TLS dependency) doesn't install in Mac build pipeline")) +@pytest.mark.parametrize("use_tls", [True], indirect=True) +def test_client_connect_to_tls_server(use_tls, call_ray_start): + tls_env = build_env() # use_tls fixture sets TLS environment variables + without_tls_env = {k: v for k, v in tls_env.items() if "TLS" not in k} + + # Attempt to connect without TLS + with pytest.raises(subprocess.CalledProcessError) as exc_info: + run_string_as_driver( + """ +from ray.util.client import ray as ray_client +ray_client.connect("localhost:10001") + """, + env=without_tls_env) + assert "ConnectionError" in exc_info.value.output.decode("utf-8") + + # Attempt to connect with TLS + out = run_string_as_driver( + """ +import ray +from ray.util.client import ray as ray_client +ray_client.connect("localhost:10001") +print(ray.is_initialized()) + """, + 
env=tls_env) + assert out == "True\n" + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 6baa22e6cbd43..0510ab3732c30 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -29,6 +29,7 @@ from ray._private.parameter import RayParams from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server +from ray._private.tls_utils import add_port_to_grpc_server from ray._private.utils import (detect_fate_sharing_support, check_dashboard_dependencies_installed) @@ -119,7 +120,7 @@ def __init__(self, self._free_ports: List[int] = list( range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT)) - self._runtime_env_channel = grpc.insecure_channel( + self._runtime_env_channel = ray._private.utils.init_grpc_channel( f"localhost:{runtime_env_agent_port}") self._runtime_env_stub = runtime_env_agent_pb2_grpc.RuntimeEnvServiceStub( # noqa: E501 self._runtime_env_channel) @@ -196,7 +197,7 @@ def create_specific_server(self, client_id: str) -> SpecificServer: server = SpecificServer( port=port, process_handle_future=futures.Future(), - channel=grpc.insecure_channel( + channel=ray._private.utils.init_grpc_channel( f"localhost:{port}", options=GRPC_OPTIONS)) self.servers[client_id] = server return server @@ -765,7 +766,7 @@ def serve_proxier(connection_str: str, data_servicer, server) ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server( logs_servicer, server) - server.add_insecure_port(connection_str) + add_port_to_grpc_server(server, connection_str) server.start() return ClientServerHandle( task_servicer=task_servicer, diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 81f97822248e9..27a10d18e3b11 100644 --- a/python/ray/util/client/server/server.py +++ 
b/python/ray/util/client/server/server.py @@ -35,6 +35,7 @@ from ray.ray_constants import env_integer from ray.util.placement_group import PlacementGroup from ray._private.client_mode_hook import disable_client_hook +from ray._private.tls_utils import add_port_to_grpc_server logger = logging.getLogger(__name__) @@ -686,7 +687,7 @@ def default_connect_handler(job_config: JobConfig = None, data_servicer, server) ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server( logs_servicer, server) - server.add_insecure_port(connection_str) + add_port_to_grpc_server(server, connection_str) current_handle = ClientServerHandle( task_servicer=task_servicer, data_servicer=data_servicer, diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index ec3e57d739572..5b441350cd67f 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -35,6 +35,7 @@ from ray.util.client.dataclient import DataClient from ray.util.client.logsclient import LogstreamClient from ray.util.debug import log_once +import ray._private.utils from ray._private.runtime_env.working_dir import upload_working_dir_if_needed if TYPE_CHECKING: @@ -100,7 +101,8 @@ def __init__( self.server = None self._conn_state = grpc.ChannelConnectivity.IDLE self._converted: Dict[str, ClientStub] = {} - self._secure = secure + self._secure = secure or os.environ.get("RAY_USE_TLS", + "0").lower() in ("1", "true") self._conn_str = conn_str self._connection_retries = connection_retries @@ -159,6 +161,13 @@ def _connect_channel(self, reconnecting=False) -> None: if self._secure: if self._credentials is not None: credentials = self._credentials + elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + server_cert_chain, private_key, ca_cert = ray._private.utils \ + .load_certs_from_env() + credentials = grpc.ssl_channel_credentials( + certificate_chain=server_cert_chain, + private_key=private_key, + root_certificates=ca_cert) else: credentials = 
grpc.ssl_channel_credentials() self.channel = grpc.secure_channel( diff --git a/python/requirements.txt b/python/requirements.txt index 4177701a25d87..40922f30ccc16 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -85,3 +85,4 @@ smart_open[s3] tqdm async-exit-stack async-generator +cryptography>=3.0.0 diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index c005acbbbd146..bd69c4ea3bcf2 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -492,3 +492,11 @@ RAY_CONFIG(bool, scheduler_avoid_gpu_nodes, true) /// Whether to skip running local GC in runtime env. RAY_CONFIG(bool, runtime_env_skip_local_gc, false) + +/// Whether or not to use TLS. +RAY_CONFIG(bool, USE_TLS, false) + +/// Location of TLS credentials +RAY_CONFIG(std::string, TLS_SERVER_CERT, "") +RAY_CONFIG(std::string, TLS_SERVER_KEY, "") +RAY_CONFIG(std::string, TLS_CA_CERT, "") diff --git a/src/ray/rpc/common.cc b/src/ray/rpc/common.cc new file mode 100644 index 0000000000000..7526c1e6efc6f --- /dev/null +++ b/src/ray/rpc/common.cc @@ -0,0 +1,29 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "ray/rpc/common.h" + +#include <fstream> +#include <sstream> + +namespace ray::rpc { + +std::string ReadCert(const std::string &cert_filepath) { + std::ifstream t(cert_filepath); + std::stringstream buffer; + buffer << t.rdbuf(); // Copies the whole file; stays empty if the open failed. + return buffer.str(); +} + +} // namespace ray::rpc diff --git a/src/ray/rpc/common.h b/src/ray/rpc/common.h new file mode 100644 index 0000000000000..314e1eccf382c --- /dev/null +++ b/src/ray/rpc/common.h @@ -0,0 +1,22 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include <string> + +namespace ray::rpc { + +// Utility to read cert file from a particular location +std::string ReadCert(const std::string &cert_filepath); + +} // namespace ray::rpc diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 6ca3b1f47f68b..2670bc0674cde 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -22,6 +22,7 @@ #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/rpc/client_call.h" +#include "ray/rpc/common.h" namespace ray { namespace rpc { @@ -43,23 +44,24 @@ namespace rpc { template <class GrpcService> class GrpcClient { public: - GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager) - : client_call_manager_(call_manager) { + GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, + bool use_tls = false) + : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ChannelArguments argument; // Disable http proxy since it disrupts local connections. TODO(ekl) we should make // this configurable, or selectively set it for known local connections only.
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0); argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - std::shared_ptr<grpc::Channel> channel = - grpc::CreateCustomChannel(address + ":" + std::to_string(port), - grpc::InsecureChannelCredentials(), argument); + + std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port); + stub_ = GrpcService::NewStub(channel); } GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, - int num_threads) - : client_call_manager_(call_manager) { + int num_threads, bool use_tls = false) + : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ResourceQuota quota; quota.SetMaxThreads(num_threads); grpc::ChannelArguments argument; @@ -67,9 +69,9 @@ class GrpcClient { argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0); argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - std::shared_ptr<grpc::Channel> channel = - grpc::CreateCustomChannel(address + ":" + std::to_string(port), - grpc::InsecureChannelCredentials(), argument); + + std::shared_ptr<grpc::Channel> channel = BuildChannel(argument, address, port); + stub_ = GrpcService::NewStub(channel); } @@ -98,6 +100,34 @@ class GrpcClient { ClientCallManager &client_call_manager_; /// The gRPC-generated stub. std::unique_ptr<typename GrpcService::Stub> stub_; + /// Whether to use TLS.
+ bool use_tls_; + + std::shared_ptr<grpc::Channel> BuildChannel(const grpc::ChannelArguments &argument, + const std::string &address, int port) { + std::shared_ptr<grpc::Channel> channel; // To address:port; TLS or plaintext. + if (use_tls_ || ::RayConfig::instance().USE_TLS()) { // Caller or config opts in. + std::string server_cert_file = + std::string(::RayConfig::instance().TLS_SERVER_CERT()); + std::string server_key_file = std::string(::RayConfig::instance().TLS_SERVER_KEY()); + std::string root_cert_file = std::string(::RayConfig::instance().TLS_CA_CERT()); + std::string server_cert_chain = ReadCert(server_cert_file); + std::string private_key = ReadCert(server_key_file); + std::string cacert = ReadCert(root_cert_file); + // Trust the configured CA and present the configured key/cert pair to the server. + grpc::SslCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = cacert; + ssl_opts.pem_private_key = private_key; + ssl_opts.pem_cert_chain = server_cert_chain; + auto ssl_creds = grpc::SslCredentials(ssl_opts); + channel = grpc::CreateCustomChannel(address + ":" + std::to_string(port), ssl_creds, + argument); + } else { + channel = grpc::CreateCustomChannel(address + ":" + std::to_string(port), + grpc::InsecureChannelCredentials(), argument); + } + return channel; + } }; } // namespace rpc diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 470f1851d3076..7c69bee606646 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -19,6 +19,7 @@ #include #include "ray/common/ray_config.h" +#include "ray/rpc/common.h" #include "ray/rpc/grpc_server.h" #include "ray/stats/metric.h" #include "ray/util/util.h" @@ -65,8 +66,24 @@ void GrpcServer::Run() { RayConfig::instance().grpc_keepalive_timeout_ms()); builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0); - // TODO(hchen): Add options for authentication.
- builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); + if (RayConfig::instance().USE_TLS()) { + // Create credentials from locations specified in config + std::string rootcert = ReadCert(RayConfig::instance().TLS_CA_CERT()); + std::string servercert = ReadCert(RayConfig::instance().TLS_SERVER_CERT()); + std::string serverkey = ReadCert(RayConfig::instance().TLS_SERVER_KEY()); + grpc::SslServerCredentialsOptions::PemKeyCertPair pkcp = {serverkey, servercert}; + grpc::SslServerCredentialsOptions ssl_opts( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + ssl_opts.pem_root_certs = rootcert; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + + // Create server credentials + std::shared_ptr<grpc::ServerCredentials> server_creds; + server_creds = grpc::SslServerCredentials(ssl_opts); + builder.AddListeningPort(server_address, server_creds, &port_); + } else { + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); + } // Register all the services to this server. if (services_.empty()) { RAY_LOG(WARNING) << "No service is found when start grpc server " << name_;