# Patch summary (commentary only; this text sits before the first
# "diff --git" header and is not part of any hunk).
#
# python/ray/_private/services.py
#   In get_address_info_from_redis_helper, the loop over client_table now
#   skips any entry whose "Alive" field is falsy before its
#   "NodeManagerAddress" is compared with the requested node IP address.
#
# python/ray/tests/test_component_failures_2.py
#   - The ray.test_utils import is reflowed into a parenthesised multi-line
#     form and additionally imports wait_for_condition.
#   - Adds test_get_address_info_after_raylet_died, which:
#       1. asserts that the "raylet_socket_name" returned by
#          get_address_info_from_redis (called with num_retries=1 and the
#          head node's IP address) equals the head node's raylet socket name;
#       2. kills the head node's raylet and waits until the first entry of
#          cluster.global_state.node_table() reports "Alive" as falsy;
#       3. expects the same lookup to raise RuntimeError;
#       4. adds a node to the cluster and asserts the lookup now returns
#          that new node's raylet socket name.
#
# NOTE(review): line breaks and indentation below were reconstructed from a
# copy of this patch that had been collapsed onto a single line. Confirm the
# context lines against the target files before applying.
# NOTE(review): patch tools are presumed to skip text ahead of the first
# "diff --git" header; confirm for the tool in use.
diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py
index f8324b28dd938..f1ae523414e73 100644
--- a/python/ray/_private/services.py
+++ b/python/ray/_private/services.py
@@ -251,6 +251,8 @@ def get_address_info_from_redis_helper(redis_address,
 
     relevant_client = None
     for client_info in client_table:
+        if not client_info["Alive"]:
+            continue
         client_node_ip_address = client_info["NodeManagerAddress"]
         if (client_node_ip_address == node_ip_address
                 or (client_node_ip_address == "127.0.0.1"
diff --git a/python/ray/tests/test_component_failures_2.py b/python/ray/tests/test_component_failures_2.py
index 2235c5745c593..6d108cac9e5d4 100644
--- a/python/ray/tests/test_component_failures_2.py
+++ b/python/ray/tests/test_component_failures_2.py
@@ -8,7 +8,11 @@
 import ray
 import ray.ray_constants as ray_constants
 from ray.cluster_utils import Cluster
-from ray.test_utils import RayTestTimeoutException, get_other_nodes
+from ray.test_utils import (
+    RayTestTimeoutException,
+    get_other_nodes,
+    wait_for_condition,
+)
 
 SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM
 
@@ -157,6 +161,29 @@ def test_raylet_failed(ray_start_cluster):
                            True)
 
 
+def test_get_address_info_after_raylet_died(ray_start_cluster_head):
+    cluster = ray_start_cluster_head
+
+    def get_address_info():
+        return ray._private.services.get_address_info_from_redis(
+            cluster.redis_address,
+            cluster.head_node.node_ip_address,
+            num_retries=1,
+            redis_password=cluster.redis_password)
+
+    assert get_address_info()[
+        "raylet_socket_name"] == cluster.head_node.raylet_socket_name
+
+    cluster.head_node.kill_raylet()
+    wait_for_condition(
+        lambda: not cluster.global_state.node_table()[0]["Alive"])
+    with pytest.raises(RuntimeError):
+        get_address_info()
+
+    node2 = cluster.add_node()
+    assert get_address_info()["raylet_socket_name"] == node2.raylet_socket_name
+
+
 if __name__ == "__main__":
     import pytest
     sys.exit(pytest.main(["-v", __file__]))