Skip to content

Commit

Permalink
[autoscaler v2] add test for node provider (ray-project#35593)
Browse files Browse the repository at this point in the history
Why are these changes needed?
add tests for node_provider v2 and refactor the mock code
  • Loading branch information
scv119 authored May 26, 2023
1 parent d5b7f49 commit ce16a2e
Show file tree
Hide file tree
Showing 10 changed files with 522 additions and 375 deletions.
16 changes: 12 additions & 4 deletions python/ray/autoscaler/_private/node_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,17 @@ def __init__(
self.node_types = node_types
self.index = str(index) if index is not None else ""

def launch_node(self, config: Dict[str, Any], count: int, node_type: str):
def launch_node(
self, config: Dict[str, Any], count: int, node_type: str
) -> Optional[Dict]:
self.log("Got {} nodes to launch.".format(count))
self._launch_node(config, count, node_type)
created_nodes = self._launch_node(config, count, node_type)
self.pending.dec(node_type, count)
return created_nodes

def _launch_node(self, config: Dict[str, Any], count: int, node_type: str):
def _launch_node(
self, config: Dict[str, Any], count: int, node_type: str
) -> Optional[Dict]:
if self.node_types:
assert node_type, node_type

Expand Down Expand Up @@ -100,8 +105,9 @@ def _launch_node(self, config: Dict[str, Any], count: int, node_type: str):

error_msg = None
full_exception = None
created_nodes = {}
try:
self.provider.create_node_with_resources(
created_nodes = self.provider.create_node_with_resources(
node_config, node_tags, count, resources
)
except NodeLaunchException as node_launch_exception:
Expand Down Expand Up @@ -158,6 +164,8 @@ def _launch_node(self, config: Dict[str, Any], count: int, node_type: str):
if full_exception is not None:
self.log(full_exception)

return created_nodes

def log(self, statement):
# launcher_class is "BaseNodeLauncher", or "NodeLauncher" if called
# from that subclass.
Expand Down
8 changes: 8 additions & 0 deletions python/ray/autoscaler/v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,11 @@ py_test(
"//:ray_lib",
]
)

py_test(
name = "test_node_provider",
size = "small",
srcs = ["tests/test_node_provider.py"],
tags = ["team:core"],
deps = ["//:ray_lib",],
)
3 changes: 3 additions & 0 deletions python/ray/autoscaler/v2/instance_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def get_node_type_specific_config(
def get_config(self, config_name, default=None) -> Any:
return self._node_configs.get(config_name, default)

def get_raw_config_mutable(self) -> Dict[str, Any]:
return self._node_configs

@property
def restart_only(self) -> bool:
return self._node_configs.get("restart_only", False)
Expand Down
45 changes: 20 additions & 25 deletions python/ray/autoscaler/v2/instance_manager/node_provider.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Set, override
from typing import Dict, List, Set

from ray.autoscaler._private.node_launcher import BaseNodeLauncher
from ray.autoscaler.node_provider import NodeProvider as NodeProviderV1
from ray.autoscaler.tags import TAG_RAY_USER_NODE_TYPE
from ray.autoscaler.v2.instance_manager.config import NodeProviderConfig
from ray.core.generated.instance_manager_pb2 import Instance, InstanceType
from ray.core.generated.instance_manager_pb2 import Instance

logger = logging.getLogger(__name__)

Expand All @@ -17,14 +17,14 @@ class NodeProvider(metaclass=ABCMeta):
"""

@abstractmethod
def create_nodes(self, instance_type: InstanceType, count: int) -> List[str]:
def create_nodes(self, instance_type_name: str, count: int) -> List[str]:
"""Create new nodes synchronously, returns all non-terminated nodes in the cluster.
Note that create_nodes could fail partially.
"""
pass

@abstractmethod
def async_terminate_nodes(self, cloud_instance_ids: List[str]) -> None:
def terminate_node(self, cloud_instance_id: str) -> None:
"""
Terminate nodes asynchronously, returns immediately."""
pass
Expand All @@ -37,7 +37,7 @@ def get_non_terminated_nodes(
pass

@abstractmethod
def get_nodes_by_cloud_id(
def get_nodes_by_cloud_instance_id(
self,
cloud_instance_ids: List[str],
) -> Dict[str, Instance]:
Expand Down Expand Up @@ -80,37 +80,32 @@ def _filter_instances(
filtered[instance_id] = instance
return filtered

@override
def create_nodes(self, instance_type: InstanceType, count: int) -> List[Instance]:
result = self._node_launcher.launch_node(
self._config.get_node_config(instance_type.name),
def create_nodes(self, instance_type_name: str, count: int) -> List[Instance]:
created_nodes = self._node_launcher.launch_node(
self._config.get_raw_config_mutable(),
count,
instance_type.name,
instance_type_name,
)
# TODO: we should handle failures where the instance type is
# not available
if result:
if created_nodes:
return [
self._get_instance(cloud_instance_id)
for cloud_instance_id in result.keys()
for cloud_instance_id in created_nodes.keys()
]
return []

@override
def async_terminate_nodes(self, clould_instance_ids: List[str]) -> None:
self._provider.terminate_node(clould_instance_ids)
def terminate_node(self, clould_instance_id: str) -> None:
self._provider.terminate_node(clould_instance_id)

@override
def is_readonly(self) -> bool:
return self._provider.is_readonly()

@override
def get_non_terminated_nodes(self):
clould_instance_ids = self._provider.non_terminated_nodes()
return self.get_nodes_by_id(clould_instance_ids)
clould_instance_ids = self._provider.non_terminated_nodes({})
return self.get_nodes_by_cloud_instance_id(clould_instance_ids)

@override
def get_nodes_by_cloud_id(
def get_nodes_by_cloud_instance_id(
self,
cloud_instance_ids: List[str],
) -> Dict[str, Instance]:
Expand All @@ -123,12 +118,12 @@ def _get_instance(self, cloud_instance_id: str) -> Instance:
instance = Instance()
instance.cloud_instance_id = cloud_instance_id
if self._provider.is_running(cloud_instance_id):
instance.state = Instance.STARTING
instance.status = Instance.STARTING
elif self._provider.is_terminated(cloud_instance_id):
instance.state = Instance.STOPPED
instance.status = Instance.STOPPED
else:
instance.state = Instance.INSTANCE_STATUS_UNSPECIFIED
instance.interal_ip = self._provider.internal_ip(cloud_instance_id)
instance.status = Instance.INSTANCE_STATUS_UNSPECIFIED
instance.internal_ip = self._provider.internal_ip(cloud_instance_id)
instance.external_ip = self._provider.external_ip(cloud_instance_id)
instance.instance_type = self._provider.node_tags(cloud_instance_id)[
TAG_RAY_USER_NODE_TYPE
Expand Down
107 changes: 107 additions & 0 deletions python/ray/autoscaler/v2/tests/test_node_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# coding: utf-8
import os
import sys
import unittest

import pytest # noqa

from ray._private.test_utils import load_test_config
from ray.autoscaler._private.event_summarizer import EventSummarizer
from ray.autoscaler._private.node_launcher import BaseNodeLauncher
from ray.autoscaler._private.node_provider_availability_tracker import (
NodeProviderAvailabilityTracker,
)
from ray.autoscaler.node_launch_exception import NodeLaunchException
from ray.autoscaler.v2.instance_manager.config import NodeProviderConfig
from ray.autoscaler.v2.instance_manager.node_provider import NodeProviderAdapter
from ray.core.generated.instance_manager_pb2 import Instance
from ray.tests.autoscaler_test_utils import MockProvider


class FakeCounter:
def dec(self, *args, **kwargs):
pass


class NodeProviderTest(unittest.TestCase):
def setUp(self):
self.base_provider = MockProvider()
self.availability_tracker = NodeProviderAvailabilityTracker()
self.node_launcher = BaseNodeLauncher(
self.base_provider,
FakeCounter(),
EventSummarizer(),
self.availability_tracker,
)
self.instance_config_provider = NodeProviderConfig(
load_test_config("test_ray_complex.yaml")
)
self.node_provider = NodeProviderAdapter(
self.base_provider, self.node_launcher, self.instance_config_provider
)

def test_node_providers_pass_through(self):
nodes = self.node_provider.create_nodes("worker_nodes1", 1)
assert len(nodes) == 1
assert nodes[0] == Instance(
instance_type="worker_nodes1",
cloud_instance_id="0",
internal_ip="172.0.0.0",
external_ip="1.2.3.4",
status=Instance.INSTANCE_STATUS_UNSPECIFIED,
)
self.assertEqual(len(self.base_provider.mock_nodes), 1)
self.assertEqual(self.node_provider.get_non_terminated_nodes(), {"0": nodes[0]})
nodes1 = self.node_provider.create_nodes("worker_nodes", 2)
assert len(nodes1) == 2
assert nodes1[0] == Instance(
instance_type="worker_nodes",
cloud_instance_id="1",
internal_ip="172.0.0.1",
external_ip="1.2.3.4",
status=Instance.INSTANCE_STATUS_UNSPECIFIED,
)
assert nodes1[1] == Instance(
instance_type="worker_nodes",
cloud_instance_id="2",
internal_ip="172.0.0.2",
external_ip="1.2.3.4",
status=Instance.INSTANCE_STATUS_UNSPECIFIED,
)
self.assertEqual(
self.node_provider.get_non_terminated_nodes(),
{"0": nodes[0], "1": nodes1[0], "2": nodes1[1]},
)
self.assertEqual(
self.node_provider.get_nodes_by_cloud_instance_id(["0"]),
{
"0": nodes[0],
},
)
self.node_provider.terminate_node("0")
self.assertEqual(
self.node_provider.get_non_terminated_nodes(),
{"1": nodes1[0], "2": nodes1[1]},
)
self.assertFalse(self.node_provider.is_readonly())

def test_create_node_failure(self):
self.base_provider.error_creates = NodeLaunchException(
"hello", "failed to create node", src_exc_info=None
)
self.assertEqual(self.node_provider.create_nodes("worker_nodes1", 1), [])
self.assertEqual(len(self.base_provider.mock_nodes), 0)
self.assertTrue(
"worker_nodes1" in self.availability_tracker.summary().node_availabilities
)
self.assertEqual(
self.node_provider.get_non_terminated_nodes(),
{},
)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))
Loading

0 comments on commit ce16a2e

Please sign in to comment.