diff --git a/BUILD.bazel b/BUILD.bazel index 4b6e4a45b0e24..8d75c149b7736 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -2352,6 +2352,7 @@ pyx_library( deps = [ "//:core_worker_lib", "//:global_state_accessor_lib", + "//:gcs_server_lib", "//:raylet_lib", "//:redis_client", "//:src/ray/ray_exported_symbols.lds", diff --git a/python/ray/_private/node.py b/python/ray/_private/node.py index 4550cadcb3c59..9142e4bb62a68 100644 --- a/python/ray/_private/node.py +++ b/python/ray/_private/node.py @@ -24,8 +24,9 @@ import ray._private.services import ray._private.utils from ray._private import storage -from ray._raylet import GcsClient +from ray._raylet import GcsClient, get_session_key_from_storage from ray._private.resource_spec import ResourceSpec +from ray._private.services import serialize_config, get_address from ray._private.utils import open_log, try_to_create_directory, try_to_symlink # Logger for this module. It should be configured at the entry point @@ -177,9 +178,15 @@ def __init__( # Register the temp dir. if head: - # date including microsecond - date_str = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S_%f") - self._session_name = f"session_{date_str}_{os.getpid()}" + # We expect this the first time we initialize a cluster, but not during + # subsequent restarts of the head node. + maybe_key = self.check_persisted_session_name() + if maybe_key is None: + # date including microsecond + date_str = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S_%f") + self._session_name = f"session_{date_str}_{os.getpid()}" + else: + self._session_name = ray._private.utils.decode(maybe_key) else: if ray_params.session_name is None: assert not self._default_worker @@ -317,6 +324,29 @@ def __init__( self.validate_ip_port(self.gcs_address) self._record_stats() + def check_persisted_session_name(self): + if self._ray_params.external_addresses is None: + return None + self._redis_address = self._ray_params.external_addresses[0] + redis_ip_address, redis_port, enable_redis_ssl = get_address( + self._redis_address, + ) + # Address is ip:port or redis://ip:port + if int(redis_port) < 0: + raise ValueError( + f"Invalid Redis port provided: {redis_port}." + "The port must be a non-negative integer." + ) + + return get_session_key_from_storage( + redis_ip_address, + int(redis_port), + self._ray_params.redis_password, + enable_redis_ssl, + serialize_config(self._config), + b"session_name", + ) + @staticmethod def validate_ip_port(ip_port): """Validates the address is in the ip:port format""" @@ -1173,12 +1203,22 @@ def _write_cluster_info_to_kv(self): ray_usage_lib.put_cluster_metadata(self.get_gcs_client()) # Make sure GCS is up. - self.get_gcs_client().internal_kv_put( + added = self.get_gcs_client().internal_kv_put( b"session_name", self._session_name.encode(), - True, + False, ray_constants.KV_NAMESPACE_SESSION, ) + if not added: + curr_val = self.get_gcs_client().internal_kv_get( + b"session_name", ray_constants.KV_NAMESPACE_SESSION + ) + assert curr_val != self._session_name, ( + f"Session name {self._session_name} does not match " + f"persisted value {curr_val}. Perhaps there was an " + f"error connecting to Redis." + ) + self.get_gcs_client().internal_kv_put( b"session_dir", self._session_dir.encode(), @@ -1213,13 +1253,9 @@ def start_head_processes(self): logger.debug( f"Process STDOUT and STDERR is being " f"redirected to {self._logs_dir}." ) - assert self._redis_address is None assert self._gcs_address is None assert self._gcs_client is None - if self._ray_params.external_addresses is not None: - self._redis_address = self._ray_params.external_addresses[0] - self.start_gcs_server() assert self.get_gcs_client() is not None self._write_cluster_info_to_kv() diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 8a3dc10c9396c..40c415bd2a3c2 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -1294,6 +1294,25 @@ def read_log(filename, lines_to_read): return None, None +def get_address(redis_address): + parts = redis_address.split("://", 1) + enable_redis_ssl = False + if len(parts) == 1: + redis_ip_address, redis_port = parts[0].rsplit(":", 1) + else: + # rediss for SSL + if len(parts) != 2 or parts[0] not in ("redis", "rediss"): + raise ValueError( + f"Invalid redis address {redis_address}." + "Expected format is ip:port or redis://ip:port, " + "or rediss://ip:port for SSL." + ) + redis_ip_address, redis_port = parts[1].rsplit(":", 1) + if parts[0] == "rediss": + enable_redis_ssl = True + return redis_ip_address, redis_port, enable_redis_ssl + + def start_gcs_server( redis_address: str, log_dir: str, @@ -1339,21 +1358,12 @@ def start_gcs_server( f"--session-name={session_name}", ] if redis_address: - parts = redis_address.split("://", 1) - enable_redis_ssl = "false" - if len(parts) == 1: - redis_ip_address, redis_port = parts[0].rsplit(":", 1) - else: - if len(parts) != 2 or parts[0] not in ("redis", "rediss"): - raise ValueError(f"Invalid redis address {redis_address}") - redis_ip_address, redis_port = parts[1].rsplit(":", 1) - if parts[0] == "rediss": - enable_redis_ssl = "true" + redis_ip_address, redis_port, enable_redis_ssl = get_address(redis_address) command += [ f"--redis_address={redis_ip_address}", f"--redis_port={redis_port}", - f"--redis_enable_ssl={enable_redis_ssl}", + f"--redis_enable_ssl={'true' if enable_redis_ssl else 'false'}", ] if redis_password: command += [f"--redis_password={redis_password}"] diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index beda9a9fc23e0..a0faaac40514b 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -156,7 +156,7 @@ from ray.includes.libcoreworker cimport ( from ray.includes.ray_config cimport RayConfig from ray.includes.global_state_accessor cimport CGlobalStateAccessor -from ray.includes.global_state_accessor cimport RedisDelKeySync +from ray.includes.global_state_accessor cimport RedisDelKeySync, RedisGetKeySync from ray.includes.optional cimport ( optional, nullopt ) @@ -4579,3 +4579,25 @@ cdef void async_callback(shared_ptr[CRayObject] obj, def del_key_from_storage(host, port, password, use_ssl, key): return RedisDelKeySync(host, port, password, use_ssl, key) + + +def get_session_key_from_storage(host, port, password, use_ssl, config, key): + """ + Get the session key from the storage. + Intended to be used for session_name only. + Args: + host: The address of the owner (caller) of the + generator task. + port: The task ID of the generator task. + password: The redis password. + use_ssl: Whether to use SSL. + config: The Ray config. Used to get storage namespace. + key: The key to retrieve. + """ + cdef: + c_string data + result = RedisGetKeySync(host, port, password, use_ssl, config, key, &data) + if result: + return data + else: + return None diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index e9cd93e0e7a29..4c571f5d8390e 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -46,6 +46,72 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: const c_string &node_ip_address, c_string *node_to_connect) +cdef extern from * namespace "ray::gcs" nogil: + """ + #include + #include "ray/gcs/gcs_server/store_client_kv.h" + namespace ray { + namespace gcs { + + bool RedisGetKeySync(const std::string& host, + int32_t port, + const std::string& password, + bool use_ssl, + const std::string& config, + const std::string& key, + std::string* data) { + InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, + ray::RayLog::ShutDownRayLog, + "ray_init", + ray::RayLogLevel::WARNING, + "" /* log_dir */); + + RedisClientOptions options(host, port, password, false, use_ssl); + + std::string config_list; + RAY_CHECK(absl::Base64Unescape(config, &config_list)); + RayConfig::instance().initialize(config_list); + + instrumented_io_context io_service; + + auto redis_client = std::make_shared(options); + auto status = redis_client->Connect(io_service); + if(!status.ok()) { + RAY_LOG(ERROR) << "Failed to connect to redis: " << status.ToString(); + return false; + } + + auto cli = std::make_unique( + std::make_unique(std::move(redis_client))); + + bool ret_val = false; + cli->Get("session", key, [&](std::optional result) { + if (result.has_value()) { + *data = result.value(); + ret_val = true; + } else { + RAY_LOG(INFO) << "Failed to retrieve the key " << key + << " from persistent storage."; + ret_val = false; + } + }); + io_service.run_for(std::chrono::milliseconds(1000)); + + return ret_val; + } + + } + } + """ + c_bool RedisGetKeySync(const c_string& host, + c_int32_t port, + const c_string& password, + c_bool use_ssl, + const c_string& config, + const c_string& key, + c_string* data) + + cdef extern from * namespace "ray::gcs" nogil: """ #include diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 3f589089f83eb..fe4d3a59ab6b2 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -141,7 +141,6 @@ py_test_module_list( "test_multinode_failures_2.py", "test_node_manager.py", "test_object_assign_owner.py", - "test_placement_group.py", "test_placement_group_2.py", "test_placement_group_4.py", "test_placement_group_failover.py", @@ -224,6 +223,7 @@ py_test_module_list( py_test_module_list( files = [ "test_gcs_ha_e2e.py", + "test_gcs_ha_e2e_2.py", "test_memory_pressure.py", "test_node_labels.py", ], @@ -317,6 +317,7 @@ py_test_module_list( "test_reference_counting_2.py", "test_exit_observability.py", "test_usage_stats.py", + "test_placement_group.py", "test_placement_group_3.py", "test_placement_group_5.py", "test_cancel.py", diff --git a/python/ray/tests/conftest_docker.py b/python/ray/tests/conftest_docker.py index 9120022dbe5b9..dbd18ed233932 100644 --- a/python/ray/tests/conftest_docker.py +++ b/python/ray/tests/conftest_docker.py @@ -93,7 +93,11 @@ def print_logs(self): "9379", ], volumes={"{head_node_vol.name}": {"bind": "/tmp", "mode": "rw"}}, - environment={"RAY_REDIS_ADDRESS": "{redis.ips.primary}:6379"}, + environment={ + "RAY_REDIS_ADDRESS": "{redis.ips.primary}:6379", + "RAY_raylet_client_num_connect_attempts": "10", + "RAY_raylet_client_connect_timeout_milliseconds": "100", + }, wrapper_class=Container, ports={ "8000/tcp": None, @@ -118,7 +122,11 @@ def print_logs(self): "9379", ], volumes={"{worker_node_vol.name}": {"bind": "/tmp", "mode": "rw"}}, - environment={"RAY_REDIS_ADDRESS": "{redis.ips.primary}:6379"}, + environment={ + "RAY_REDIS_ADDRESS": "{redis.ips.primary}:6379", + "RAY_raylet_client_num_connect_attempts": "10", + "RAY_raylet_client_connect_timeout_milliseconds": "100", + }, wrapper_class=Container, ports={ "8000/tcp": None, diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index a4ba35d1756d9..ff7986c329e3d 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -355,22 +355,22 @@ def check_demands(n): @pytest.mark.skipif(enable_external_redis(), reason="Only valid in non redis env") def test_redis_not_available(monkeypatch, call_ray_stop_only): - monkeypatch.setenv("RAY_NUM_REDIS_GET_RETRIES", "2") + monkeypatch.setenv("RAY_redis_db_connect_retries", "5") monkeypatch.setenv("RAY_REDIS_ADDRESS", "localhost:12345") + p = subprocess.run( "ray start --head", shell=True, capture_output=True, ) assert "Could not establish connection to Redis" in p.stderr.decode() - assert "Please check" in p.stderr.decode() - assert "gcs_server.out for details" in p.stderr.decode() - assert "RuntimeError: Failed to start GCS" in p.stderr.decode() + assert "Please check " in p.stderr.decode() + assert "redis storage is alive or not." in p.stderr.decode() @pytest.mark.skipif(not enable_external_redis(), reason="Only valid in redis env") def test_redis_wrong_password(monkeypatch, external_redis, call_ray_stop_only): - monkeypatch.setenv("RAY_NUM_REDIS_GET_RETRIES", "2") + monkeypatch.setenv("RAY_redis_db_connect_retries", "5") p = subprocess.run( "ray start --head --redis-password=1234", shell=True, @@ -378,8 +378,6 @@ def test_redis_wrong_password(monkeypatch, external_redis, call_ray_stop_only): ) assert "RedisError: ERR AUTH called" in p.stderr.decode() - assert "Please check /tmp/ray/session" in p.stderr.decode() - assert "RuntimeError: Failed to start GCS" in p.stderr.decode() @pytest.mark.skipif(not enable_external_redis(), reason="Only valid in redis env") diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index dea003d5631ef..1d3611c7cba93 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -868,6 +868,36 @@ def check_raylet_healthy(): sleep(1) +def test_session_name(ray_start_cluster): + # Kill GCS and check that raylets kill themselves when not backed by Redis, + # and stay alive when backed by Redis. + # Raylets should kill themselves due to cluster ID mismatch in the + # non-persisted case. + cluster = ray_start_cluster + cluster.add_node() + cluster.wait_for_nodes() + + head_node = cluster.head_node + session_dir = head_node.get_session_dir_path() + + gcs_server_process = head_node.all_processes["gcs_server"][0].process + gcs_server_pid = gcs_server_process.pid + cluster.remove_node(head_node, allow_graceful=False) + # Wait to prevent the gcs server process becoming zombie. + gcs_server_process.wait() + wait_for_pid_to_exit(gcs_server_pid, 1000) + + # Add head node back + cluster.add_node() + head_node = cluster.head_node + new_session_dir = head_node.get_session_dir_path() + + if not enable_external_redis(): + assert session_dir != new_session_dir + else: + assert session_dir == new_session_dir + + @pytest.mark.parametrize( "ray_start_regular_with_external_redis", [ @@ -917,6 +947,34 @@ def check_raylet_healthy(): wait_for_condition(lambda: not check_raylet_healthy()) +def test_redis_logs(external_redis): + try: + import subprocess + + process = subprocess.Popen( + ["ray", "start", "--head"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + stdout, stderr = process.communicate(timeout=30) + print(stdout.decode()) + print(stderr.decode()) + assert "redis_context.cc" not in stderr.decode() + assert "redis_context.cc" not in stdout.decode() + assert "Resolve Redis address" not in stderr.decode() + assert "Resolve Redis address" not in stdout.decode() + # assert "redis_context.cc" not in result.output + finally: + from click.testing import CliRunner + import ray.scripts.scripts as scripts + + runner = CliRunner(env={"RAY_USAGE_STATS_PROMPT_ENABLED": "0"}) + runner.invoke( + scripts.stop, + [ + "--force", + ], + ) + + if __name__ == "__main__": import pytest diff --git a/python/ray/tests/test_gcs_ha_e2e.py b/python/ray/tests/test_gcs_ha_e2e.py index e9e64599c15f0..3a83afaa0c6b5 100644 --- a/python/ray/tests/test_gcs_ha_e2e.py +++ b/python/ray/tests/test_gcs_ha_e2e.py @@ -1,6 +1,8 @@ -import pytest import sys from time import sleep + +import pytest + from ray._private.test_utils import wait_for_condition from ray.tests.conftest_docker import * # noqa diff --git a/python/ray/tests/test_gcs_ha_e2e_2.py b/python/ray/tests/test_gcs_ha_e2e_2.py new file mode 100644 index 0000000000000..380f527be1719 --- /dev/null +++ b/python/ray/tests/test_gcs_ha_e2e_2.py @@ -0,0 +1,55 @@ +import pytest +import sys +from time import sleep +from ray._private.test_utils import wait_for_condition +from ray.tests.conftest_docker import * # noqa + + +@pytest.mark.skipif(sys.platform != "linux", reason="Only works on linux.") +def test_ray_session_name_preserved(docker_cluster): + get_nodes_script = """ +import ray +ray.init("auto") +print(ray._private.worker._global_node.session_name) +""" + head, worker = docker_cluster + + def get_session_name(to_head=True): + if to_head: + output = head.exec_run(cmd=f"python -c '{get_nodes_script}'") + else: + output = worker.exec_run(cmd=f"python -c '{get_nodes_script}'") + session_name = output.output.decode().strip().split("\n")[-1] + print("Output: ", output.output.decode().strip().split("\n")) + assert output.exit_code == 0 + return session_name + + # Make sure two nodes are alive + wait_for_condition(get_session_name, to_head=True) + session_name_head = get_session_name(to_head=True) + wait_for_condition(get_session_name, to_head=False) + session_name_worker = get_session_name(to_head=False) + assert session_name_head == session_name_worker + print("head killed") + head.kill() + + sleep(2) + + head.restart() + + wait_for_condition(get_session_name, to_head=True) + session_name_head_after_restart = get_session_name(to_head=True) + wait_for_condition(get_session_name, to_head=False) + session_name_worker_after_restart = get_session_name(to_head=False) + assert session_name_worker_after_restart == session_name_head_after_restart + assert session_name_head == session_name_head_after_restart + assert session_name_worker_after_restart == session_name_worker + + +if __name__ == "__main__": + import os + + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_ray_init.py b/python/ray/tests/test_ray_init.py index 38e351a8a025b..1e139d7490211 100644 --- a/python/ray/tests/test_ray_init.py +++ b/python/ray/tests/test_ray_init.py @@ -13,7 +13,7 @@ from ray.util.client.common import ClientObjectRef from ray.util.client.ray_client_helpers import ray_start_client_server from ray.util.client.worker import Worker -from ray._private.test_utils import wait_for_condition +from ray._private.test_utils import wait_for_condition, enable_external_redis @pytest.mark.skipif( @@ -234,6 +234,34 @@ def test_ray_init_using_hostname(ray_start_cluster): assert node_table[0].get("NodeManagerHostname", "") == hostname +def test_new_ray_instance_new_session_dir(shutdown_only): + ray.init() + session_dir = ray._private.worker._global_node.get_session_dir_path() + ray.shutdown() + ray.init() + if enable_external_redis(): + assert ray._private.worker._global_node.get_session_dir_path() == session_dir + else: + assert ray._private.worker._global_node.get_session_dir_path() != session_dir + + +def test_new_cluster_new_session_dir(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node() + ray.init(address=cluster.address) + session_dir = ray._private.worker._global_node.get_session_dir_path() + ray.shutdown() + cluster.shutdown() + cluster.add_node() + ray.init(address=cluster.address) + if enable_external_redis(): + assert ray._private.worker._global_node.get_session_dir_path() == session_dir + else: + assert ray._private.worker._global_node.get_session_dir_path() != session_dir + ray.shutdown() + cluster.shutdown() + + if __name__ == "__main__": import sys diff --git a/python/ray/tests/test_tempfile.py b/python/ray/tests/test_tempfile.py index 081df60044922..4bf838eefdabc 100644 --- a/python/ray/tests/test_tempfile.py +++ b/python/ray/tests/test_tempfile.py @@ -134,7 +134,7 @@ def check_all_log_file_exists(): assert sum(1 for filename in log_files if filename.startswith("worker")) == 4 socket_files = set(os.listdir(node.get_sockets_dir_path())) - assert socket_files == expected_socket_files + assert socket_files.issuperset(expected_socket_files) def test_tempdir_privilege(shutdown_only):