Skip to content

Commit

Permalink
[Core] Add TLS/SSL support to gRPC channels (ray-project#18631)
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarknagg committed Oct 21, 2021
1 parent 6d23fb1 commit 5a05e89
Show file tree
Hide file tree
Showing 30 changed files with 519 additions and 62 deletions.
8 changes: 4 additions & 4 deletions dashboard/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def __init__(self,
assert self.ppid > 0
logger.info("Parent pid is %s", self.ppid)
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
self.grpc_port = self.server.add_insecure_port(
f"[::]:{self.dashboard_agent_port}")
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"[::]:{self.dashboard_agent_port}")
logger.info("Dashboard agent grpc address: %s:%s", self.ip,
self.grpc_port)
self.aioredis_client = None
options = (("grpc.enable_http_proxy", 0), )
self.aiogrpc_raylet_channel = aiogrpc.insecure_channel(
f"{self.ip}:{self.node_manager_port}", options=options)
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True)
self.http_session = None
ip, port = redis_address.split(":")
self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
Expand Down
10 changes: 6 additions & 4 deletions dashboard/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from grpc.experimental import aio as aiogrpc
import grpc

import ray._private.utils
import ray._private.services
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
Expand Down Expand Up @@ -56,7 +57,7 @@ async def get_gcs_address_with_retry(redis_client) -> str:

class GCSHealthCheckThread(threading.Thread):
def __init__(self, gcs_address: str):
self.grpc_gcs_channel = grpc.insecure_channel(
self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, options=GRPC_CHANNEL_OPTIONS)
self.gcs_heartbeat_info_stub = (
gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
Expand Down Expand Up @@ -116,7 +117,8 @@ def __init__(self, http_host, http_port, http_port_retries, redis_address,
ip, port = redis_address.split(":")
self.gcs_client = connect_to_gcs(ip, int(port), redis_password)
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
self.grpc_port = self.server.add_insecure_port("[::]:0")
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, "[::]:0")
logger.info("Dashboard head grpc address: %s:%s", self.ip,
self.grpc_port)

Expand Down Expand Up @@ -188,8 +190,8 @@ async def run(self):

# Waiting for GCS is ready.
gcs_address = await get_gcs_address_with_retry(self.aioredis_client)
self.aiogrpc_gcs_channel = aiogrpc.insecure_channel(
gcs_address, options=GRPC_CHANNEL_OPTIONS)
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)

self.health_check_thread = GCSHealthCheckThread(gcs_address)
self.health_check_thread.start()
Expand Down
7 changes: 4 additions & 3 deletions dashboard/modules/actor/actor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ async def _update_stubs(self, change):
address = "{}:{}".format(node_info["nodeManagerAddress"],
int(node_info["nodeManagerPort"]))
options = (("grpc.enable_http_proxy", 0), )
channel = aiogrpc.insecure_channel(address, options=options)
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub

Expand Down Expand Up @@ -180,8 +181,8 @@ async def kill_actor(self, req) -> aiohttp.web.Response:
return rest_response(success=False, message="Bad Request")
try:
options = (("grpc.enable_http_proxy", 0), )
channel = aiogrpc.insecure_channel(
f"{ip_address}:{port}", options=options)
channel = ray._private.utils.init_grpc_channel(
f"{ip_address}:{port}", options=options, asynchronous=True)
stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)

await stub.KillActor(
Expand Down
8 changes: 5 additions & 3 deletions dashboard/modules/event/event_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import asyncio
import logging
from typing import Union
from grpc.experimental import aio as aiogrpc

import ray._private.utils as utils
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.consts as dashboard_consts
from ray.ray_constants import env_bool
Expand Down Expand Up @@ -46,8 +46,10 @@ async def _connect_to_dashboard(self):
if dashboard_rpc_address:
logger.info("Report events to %s", dashboard_rpc_address)
options = (("grpc.enable_http_proxy", 0), )
channel = aiogrpc.insecure_channel(
dashboard_rpc_address, options=options)
channel = utils.init_grpc_channel(
dashboard_rpc_address,
options=options,
asynchronous=True)
return event_pb2_grpc.ReportEventServiceStub(channel)
except Exception:
logger.exception("Connect to dashboard failed.")
Expand Down
6 changes: 4 additions & 2 deletions dashboard/modules/job/job_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import aiohttp.web
from aioredis.pubsub import Receiver
from grpc.experimental import aio as aiogrpc

import ray._private.utils
import ray._private.gcs_utils as gcs_utils
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.modules.job import job_consts
Expand Down Expand Up @@ -52,7 +52,9 @@ async def submit_job(self, req) -> aiohttp.web.Response:
ip = DataSource.node_id_to_ip[node_id]
address = f"{ip}:{ports[1]}"
options = (("grpc.enable_http_proxy", 0), )
channel = aiogrpc.insecure_channel(address, options=options)
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True)

stub = job_agent_pb2_grpc.JobAgentServiceStub(channel)
request = job_agent_pb2.InitializeJobEnvRequest(
job_description=json.dumps(job_description_data))
Expand Down
4 changes: 2 additions & 2 deletions dashboard/modules/node/node_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import aiohttp.web
from aioredis.pubsub import Receiver
from grpc.experimental import aio as aiogrpc

import ray._private.utils
import ray._private.gcs_utils as gcs_utils
Expand Down Expand Up @@ -68,7 +67,8 @@ async def _update_stubs(self, change):
address = "{}:{}".format(node_info["nodeManagerAddress"],
int(node_info["nodeManagerPort"]))
options = (("grpc.enable_http_proxy", 0), )
channel = aiogrpc.insecure_channel(address, options=options)
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub

Expand Down
5 changes: 2 additions & 3 deletions dashboard/modules/reporter/reporter_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import aiohttp.web
from aioredis.pubsub import Receiver
from grpc.experimental import aio as aiogrpc

import ray
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
Expand Down Expand Up @@ -38,8 +37,8 @@ async def _update_stubs(self, change):
node_id, ports = change.new
ip = DataSource.node_id_to_ip[node_id]
options = (("grpc.enable_http_proxy", 0), )
channel = aiogrpc.insecure_channel(
f"{ip}:{ports[1]}", options=options)
channel = ray._private.utils.init_grpc_channel(
f"{ip}:{ports[1]}", options=options, asynchronous=True)
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
self._stubs[ip] = stub

Expand Down
1 change: 0 additions & 1 deletion dashboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from collections import namedtuple
from collections.abc import MutableMapping, Mapping, Sequence
from typing import Any

from google.protobuf.json_format import MessageToDict

import ray.dashboard.consts as dashboard_consts
Expand Down
22 changes: 22 additions & 0 deletions doc/source/configure.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,28 @@ to localhost when the ray is started using ``ray.init``.
See the `Redis security documentation <https://redis.io/topics/security>`__
for more information.
TLS Authentication
------------------
Ray can be configured to use TLS on it's gRPC channels.
This has means that connecting to the Ray client on the head node will
require an appropriate set of credentials and also that data exchanged between
various processes (client, head, workers) will be encrypted.
Enabling TLS will cause a performance hit due to the extra overhead of mutual
authentication and encryption.
Testing has shown that this overhead is large for small workloads and becomes
relatively smaller for large workloads.
The exact overhead will depend on the nature of your workload.
TLS is enabled by setting environment variables.
- ``RAY_USE_TLS``: Either 1 or 0 to use/not-use TLS. If this is set to 1 then all of the environment variables below must be set. Default: 0.
- ``RAY_TLS_SERVER_CERT``: Location of a `certificate file` which is presented to other endpoints so as to achieve mutual authentication.
- ``RAY_TLS_SERVER_KEY``: Location of a `private key file` which is the cryptographic means to prove to other endpoints that you are the authorized user of a given certificate.
- ``RAY_TLS_CA_CERT``: Location of a `CA certificate file` which allows TLS to decide whether an endpoint's certificate has been signed by the correct authority.
Java Applications
-----------------
Expand Down
34 changes: 34 additions & 0 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
from typing import Optional, Any, List, Dict
from contextlib import redirect_stdout, redirect_stderr
import yaml
import pytest
import tempfile

import ray
import ray._private.services
import ray._private.utils
import ray._private.gcs_utils as gcs_utils
from ray._private.tls_utils import generate_self_signed_tls_certs
from ray.util.queue import Queue, _QueueActor, Empty
from ray.scripts.scripts import main as ray_main
try:
Expand Down Expand Up @@ -691,3 +694,34 @@ def is_placement_group_removed(pg):
if "state" not in table:
return False
return table["state"] == "REMOVED"


def setup_tls():
"""Sets up required environment variables for tls"""
if sys.platform == "darwin":
pytest.skip("Cryptography doesn't install in Mac build pipeline")
cert, key = generate_self_signed_tls_certs()
temp_dir = tempfile.mkdtemp("ray-test-certs")
cert_filepath = os.path.join(temp_dir, "server.crt")
key_filepath = os.path.join(temp_dir, "server.key")
with open(cert_filepath, "w") as fh:
fh.write(cert)
with open(key_filepath, "w") as fh:
fh.write(key)

os.environ["RAY_USE_TLS"] = "1"
os.environ["RAY_TLS_SERVER_CERT"] = cert_filepath
os.environ["RAY_TLS_SERVER_KEY"] = key_filepath
os.environ["RAY_TLS_CA_CERT"] = cert_filepath

return key_filepath, cert_filepath, temp_dir


def teardown_tls(key_filepath, cert_filepath, temp_dir):
os.remove(key_filepath)
os.remove(cert_filepath)
os.removedirs(temp_dir)
del os.environ["RAY_USE_TLS"]
del os.environ["RAY_TLS_SERVER_CERT"]
del os.environ["RAY_TLS_SERVER_KEY"]
del os.environ["RAY_TLS_CA_CERT"]
90 changes: 90 additions & 0 deletions python/ray/_private/tls_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import datetime
import os
import socket

import grpc


def generate_self_signed_tls_certs():
"""Create self-signed key/cert pair for testing.
This method requires the library ``cryptography`` be installed.
"""
try:
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
except ImportError:
raise ImportError(
"Using `Security.temporary` requires `cryptography`, please "
"install it using either pip or conda")
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend())
key_contents = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
).decode()

ray_interal = x509.Name(
[x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")])
# This is the same logic used by the GCS server to acquire a
# private/interal IP address to listen on. If we just use localhost +
# 127.0.0.1 then we won't be able to connect to the GCS and will get
# an error like "No match found for server name: 192.168.X.Y"
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
private_ip_address = s.getsockname()[0]
s.close()
altnames = x509.SubjectAlternativeName([
x509.DNSName(socket.gethostbyname(
socket.gethostname())), # Probably 127.0.0.1
x509.DNSName("127.0.0.1"),
x509.DNSName(private_ip_address), # 192.168.*.*
x509.DNSName("localhost"),
])
now = datetime.datetime.utcnow()
cert = (x509.CertificateBuilder().subject_name(ray_interal).issuer_name(
ray_interal).add_extension(altnames, critical=False).public_key(
key.public_key()).serial_number(
x509.random_serial_number()).not_valid_before(now)
.not_valid_after(now + datetime.timedelta(days=365)).sign(
key, hashes.SHA256(), default_backend()))

cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode()

return cert_contents, key_contents


def add_port_to_grpc_server(server, address):
if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
server_cert_chain, private_key, ca_cert = load_certs_from_env()
credentials = grpc.ssl_server_credentials(
[(private_key, server_cert_chain)],
root_certificates=ca_cert,
require_client_auth=ca_cert is not None)
return server.add_secure_port(address, credentials)
else:
return server.add_insecure_port(address)


def load_certs_from_env():
tls_env_vars = [
"RAY_TLS_SERVER_CERT", "RAY_TLS_SERVER_KEY", "RAY_TLS_CA_CERT"
]
if any(v not in os.environ for v in tls_env_vars):
raise RuntimeError(
"If the environment variable RAY_USE_TLS is set to true "
"then RAY_TLS_SERVER_CERT, RAY_TLS_SERVER_KEY and "
"RAY_TLS_CA_CERT must also be set.")

with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f:
server_cert_chain = f.read()
with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f:
private_key = f.read()
with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f:
ca_cert = f.read()

return server_cert_chain, private_key, ca_cert
23 changes: 22 additions & 1 deletion python/ray/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import tempfile
import threading
import time
from typing import Optional
from typing import Optional, Sequence, Tuple, Any
import uuid
import grpc
import warnings
from grpc.experimental import aio as aiogrpc

import inspect
from inspect import signature
Expand All @@ -25,6 +27,7 @@
import ray
import ray._private.gcs_utils as gcs_utils
import ray.ray_constants as ray_constants
from ray._private.tls_utils import load_certs_from_env

# Import psutil after ray so the packaged version is used.
import psutil
Expand Down Expand Up @@ -1109,6 +1112,24 @@ def validate_namespace(namespace: str):
"Pass None to not specify a namespace.")


def init_grpc_channel(address: str,
options: Optional[Sequence[Tuple[str, Any]]] = None,
asynchronous: bool = False):
grpc_module = aiogrpc if asynchronous else grpc
if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
server_cert_chain, private_key, ca_cert = load_certs_from_env()
credentials = grpc.ssl_channel_credentials(
certificate_chain=server_cert_chain,
private_key=private_key,
root_certificates=ca_cert)
channel = grpc_module.secure_channel(
address, credentials, options=options)
else:
channel = grpc_module.insecure_channel(address, options=options)

return channel


def check_dashboard_dependencies_installed() -> bool:
"""Returns True if Ray Dashboard dependencies are installed.
Expand Down
Loading

0 comments on commit 5a05e89

Please sign in to comment.