
Commit

Merge remote-tracking branch 'origin/master' into add_output_for_version
BalaBalaYi committed Jul 13, 2024
2 parents 684aea2 + 30dd770 commit c21501a
Showing 14 changed files with 112 additions and 33 deletions.
7 changes: 6 additions & 1 deletion dlrover/python/common/grpc.py
@@ -110,15 +110,20 @@ def find_free_port_for_hccl(start=60000) -> int:
if end > 65000:
end = 65000
logger.info(f"Try to find available port for hccl from {start}")
checking_port = 0
while True:
try:
cur_end = cur_start + AscendConstants.NPU_PER_NODE
for port in range(cur_start, cur_end):
checking_port = port
find_free_port(port)
logger.info(f"Find available port start from: {cur_start}")
break
except OSError:
cur_start = cur_start + AscendConstants.NPU_PER_NODE
if checking_port > 0:
cur_start = checking_port + 1
else:
cur_start = cur_start + AscendConstants.NPU_PER_NODE
if cur_start > end:
cur_start = 0
break
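The fix above resumes scanning right after the failing port instead of skipping a whole NPU block. Below is a minimal sketch of the same probe-and-retry idea, assuming find_free_port(port) raises OSError when a port cannot be bound and that AscendConstants.NPU_PER_NODE is 16 (derived from the test expectation further below); the helper names are illustrative, not the dlrover API.

import socket

NPU_PER_NODE = 16  # assumption: mirrors AscendConstants.NPU_PER_NODE

def probe_port(port):
    # stands in for find_free_port(port): raises OSError if the port is taken
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", port))

def find_block_start(start=60000, end=65000):
    cur_start, checking_port = start, 0
    while True:
        try:
            # every port in the NPU block must be bindable
            for port in range(cur_start, cur_start + NPU_PER_NODE):
                checking_port = port
                probe_port(port)
            return cur_start
        except OSError:
            # resume right after the occupied port rather than jumping a full block
            cur_start = checking_port + 1 if checking_port > 0 else cur_start + NPU_PER_NODE
            if cur_start > end:
                return 0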
5 changes: 4 additions & 1 deletion dlrover/python/elastic_agent/master_client.py
@@ -70,10 +70,13 @@ class MasterClient(Singleton):
_instance_lock = threading.Lock()

def __init__(self, master_addr, node_id, node_type, timeout=5):
logger.info(
f"Build master client with master_addr: {master_addr}, "
f"node_id: {node_id}, node_type: {node_type}."
)
self._timeout = timeout
self._master_addr = master_addr
self._channel = grpc.build_channel(master_addr)
logger.info("dlrover master addr is %s" % self._master_addr)
self._stub = elastic_training_pb2_grpc.MasterStub(self._channel)
self._node_id = node_id
self._node_type = node_type
4 changes: 2 additions & 2 deletions dlrover/python/elastic_agent/torch/training.py
@@ -750,7 +750,7 @@ def sync_training_ports(self):
port = 0
logger.info("synchronize worker training ports...")
count = 0
max_count = 60
max_count = 120
while True:
if count >= max_count:
logger.error(
@@ -763,7 +763,7 @@ def sync_training_ports(self):
port = find_free_port_for_hccl(start_port)
if port == 0:
logger.error(
"fail to find available ports between 60000 and 70000"
f"fail to find available ports from {start_port}"
)
break
resp = self._client.sync_training_ports(port)
52 changes: 45 additions & 7 deletions dlrover/python/master/node/dist_job_manager.py
@@ -362,14 +362,29 @@ def _monitor_node_heart_beat(self):

def _get_dead_node_event(self, window_interval=300) -> List[NodeEvent]:
now = time.time()
dead_events = []
dead_events: List[NodeEvent] = []
for _, nodes in self._job_nodes.items():
for _, node in nodes.items():
if (
node.heartbeat_time > 0
and now - node.heartbeat_time > window_interval
and node.start_time
and node.create_time
and node.status == NodeStatus.RUNNING
):
if (
node.heartbeat_time <= node.start_time
or node.heartbeat_time <= node.create_time
):
logger.warning(
f"Skip dead node judgement for "
f"node: {node.id}-{node.name} "
f"because heartbeat time < create/start time. "
f"Reset heartbeat time to 0."
)
node.heartbeat_time = 0
continue

event_node = copy.deepcopy(node)
event_node.status = NodeStatus.FAILED
event_node.exit_reason = NodeExitReason.NO_HEARTBEAT
@@ -388,8 +403,11 @@ def _get_dead_node_event(self, window_interval=300) -> List[NodeEvent]:
TrainingExceptionLevel.NODE_ERROR,
)
logger.warning(
f"The node {node.name} has not sent a heartbeat "
f"for over {window_interval} seconds."
f"The node {node.id}-{node.name} has not sent a "
f"heartbeat for over {window_interval} seconds, "
f"last heartbeat: {node.heartbeat_time}, "
f"created at: {node.create_time}, "
f"started at: {node.start_time}."
)
return dead_events

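The added guard only trusts a heartbeat sent after the node's current create/start time; an older heartbeat is treated as a leftover from a previous launch and reset rather than reported as a dead node. A minimal standalone sketch of that predicate, assuming all timestamps are epoch seconds (the function name and status value are illustrative):

import time

def is_dead_node(node, window_interval=300):
    # node is assumed to expose heartbeat_time, create_time, start_time and status
    now = time.time()
    if not (node.heartbeat_time > 0
            and now - node.heartbeat_time > window_interval
            and node.start_time
            and node.create_time
            and node.status == "RUNNING"):
        return False
    if (node.heartbeat_time <= node.start_time
            or node.heartbeat_time <= node.create_time):
        # stale heartbeat from before this launch: reset instead of reporting dead
        node.heartbeat_time = 0
        return False
    return True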
@@ -472,6 +490,22 @@ def _process_event(self, event: NodeEvent):
return
else:
cur_node = self._job_nodes[node_type][node_id]
logger.info(
f"Update node({cur_node.id}), "
f"name: {cur_node.name}->{event.node.name}, "
f"start_time: {cur_node.start_time}"
f"->{event.node.start_time}, "
f"create_time: {cur_node.create_time}"
f"->{event.node.create_time}, "
f"host_name: {cur_node.host_name}"
f"->{event.node.host_name},"
f"host_ip: {cur_node.host_ip}"
f"->{event.node.host_ip}, "
f"restart_training: {cur_node.restart_training}"
f"->{event.node.restart_training}, "
f"relaunch_count: {cur_node.relaunch_count}"
f"->{event.node.relaunch_count}"
)
cur_node.update_info(
name=event.node.name,
start_time=event.node.start_time,
@@ -854,10 +888,14 @@ def verify_restarting_worker_training(self, node_type, node_id):
return self._worker_manager.verify_restarting_training(node_id)

def collect_node_heart_beat(self, node_type, node_id, timestamp):
node = self._job_nodes[node_type][node_id]
if node.heartbeat_time == 0:
logger.info(f"Start receiving heartbeat from node {node.name}")
node.heartbeat_time = timestamp
with self._lock:
node = self._job_nodes[node_type][node_id]
if node.heartbeat_time == 0:
logger.info(
f"Start receiving heartbeat from node {node_id}"
f"-{node.name}"
)
node.heartbeat_time = timestamp


def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
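Taking the job manager's lock makes the check-then-set on heartbeat_time atomic when many workers report at once (the concurrency test added further below drives 10000 concurrent reports). A minimal sketch of the guarded pattern, assuming a plain dict of node objects (the class and method names are illustrative):

import threading

class HeartbeatRegistry:
    def __init__(self, nodes):
        self._nodes = nodes  # {node_id: node}
        self._lock = threading.Lock()

    def collect(self, node_id, timestamp):
        with self._lock:
            node = self._nodes[node_id]
            if node.heartbeat_time == 0:
                # without the lock, two racing calls could both see 0 here
                # and emit the "first heartbeat" log twice
                print(f"Start receiving heartbeat from node {node_id}")
            node.heartbeat_time = timestamp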
14 changes: 5 additions & 9 deletions dlrover/python/master/node/training_node.py
@@ -22,7 +22,6 @@
from typing import Dict, List

from dlrover.python.common.constants import (
AscendConstants,
DistributionStrategy,
NodeResourceLimit,
NodeStatus,
@@ -392,9 +391,7 @@ def sync_node_training_port(self, node_id, port) -> SyncNodeTrainingPorts:
next_check_port=self._next_check_node_training_port,
)
self._recv_node_training_ports[node_id] = port
logger.info(
f"recv ports from: {self._recv_node_training_ports.keys()}"
)
logger.info(f"recv ports: {self._recv_node_training_ports.keys()}")
if len(self._recv_node_training_ports) == self._n_node:
min_port = 0
max_port = 0
@@ -404,13 +401,12 @@ def sync_node_training_port(self, node_id, port) -> SyncNodeTrainingPorts:
if max_port < recv_port:
max_port = recv_port
if min_port != max_port:
self._recv_node_training_ports.clear()
self._next_check_node_training_port = max_port
logger.info(
f"fail to sync node training ports: "
f"{self._recv_node_training_ports}"
)
self._recv_node_training_ports.clear()
self._next_check_node_training_port = (
max_port + AscendConstants.NPU_PER_NODE
f"{self._recv_node_training_ports}, "
f"next sync port: {max_port}"
)
return SyncNodeTrainingPorts(
training_port=0,
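With this change, a disagreement between nodes retries from max_port itself rather than max_port + NPU_PER_NODE, since find_free_port_for_hccl already skips past occupied ports on its own. A minimal sketch of the agreement check, assuming each node reports one candidate port per round (the function name and return convention are illustrative):

def check_port_agreement(reported_ports, n_node):
    # reported_ports: {node_id: candidate_port}; returns (agreed_port, next_check_port)
    if len(reported_ports) < n_node:
        return 0, 0
    min_port = min(reported_ports.values())
    max_port = max(reported_ports.values())
    if min_port != max_port:
        # disagreement: discard this round and ask nodes to retry from the highest candidate
        reported_ports.clear()
        return 0, max_port
    return min_port, 0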
4 changes: 3 additions & 1 deletion dlrover/python/master/stats/stats_backend.py
@@ -15,6 +15,8 @@

import yaml

from dlrover.python.common.log import default_logger as logger


def parse_json_file(file_path):
data = None
@@ -43,7 +45,7 @@ def load(self):
elif self.file_path.endswith("yaml"):
data = parse_yaml_file(self.file_path)
else:
print("error") # to do: logging
logger.error("Invalid file format.")
self.data = data
return data

1 change: 0 additions & 1 deletion dlrover/python/tests/test_elastic_training_agent.py
@@ -315,7 +315,6 @@ def test_get_free_port(self):
s.bind(("", 10000))
os.environ["HOST_PORTS"] = "10000"
port = agent._get_free_port()
print(port)
s.close()
self.assertTrue(port != 10000)

1 change: 0 additions & 1 deletion dlrover/python/tests/test_elasticjob_scaler.py
@@ -89,5 +89,4 @@ def test_generate_scaler_crd_by_plan(self):
],
"psHosts": ["test-ps-0:2222", "test-ps-1:2222"],
}
print(scaler_crd.spec.to_dict())
self.assertDictEqual(scaler_crd.spec.to_dict(), expected_dict)
2 changes: 1 addition & 1 deletion dlrover/python/tests/test_grpc_utils.py
@@ -67,7 +67,7 @@ def test_find_free_port_for_hccl(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 60003))
port = find_free_port_for_hccl()
self.assertEqual(port, 60016)
self.assertEqual(port, 60004)


if __name__ == "__main__":
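The new expectation follows from the retry change in grpc.py above: with port 60003 occupied, the old logic skipped a full NPU block of 16 ports and first succeeded at 60016 (the old expected value), whereas the fixed logic resumes right after the occupied port and succeeds at 60004 (assuming the 16 ports starting there are free on the test host).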
47 changes: 44 additions & 3 deletions dlrover/python/tests/test_job_manager.py
@@ -14,6 +14,7 @@
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from unittest import mock

@@ -58,6 +59,7 @@
from dlrover.python.master.watcher.base_watcher import Node, NodeEvent
from dlrover.python.scheduler.job import LocalJobArgs
from dlrover.python.tests.test_utils import (
MockK8sAllreduceJobArgs,
MockK8sPSJobArgs,
create_task_manager,
mock_k8s_client,
@@ -274,11 +276,20 @@ def test_get_dead_node_event(self):
node.status = NodeStatus.RUNNING
events = manager._get_dead_node_event()
self.assertEqual(len(events), 0)
for node in manager._job_nodes[NodeType.WORKER].values():
for index, node in enumerate(
manager._job_nodes[NodeType.WORKER].values()
):
node.status = NodeStatus.RUNNING
node.heartbeat_time = time.time() - 500
now = time.time()
node.heartbeat_time = now - 500
if index == 0:
node.create_time = now - 400
node.start_time = now - 300
else:
node.create_time = now - 700
node.start_time = now - 600
events = manager._get_dead_node_event()
self.assertEqual(len(events), 3)
self.assertEqual(len(events), 2)

def test_relaunch_training_master(self):
params = MockK8sPSJobArgs()
@@ -521,6 +532,36 @@ def test_start_and_stop(self):
self.assertIn("node_heart_beat_monitor", active_threads_name)
manager.stop()

def test_concurrency_heart_beat_collecting(self):
params = MockK8sAllreduceJobArgs()
worker_size = 10000
params.initilize(worker_size)
manager = create_job_manager(params, SpeedMonitor())
manager.start()

self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), worker_size)
for i, node in manager._job_nodes[NodeType.WORKER].items():
self.assertEqual(node.id, i)
self.assertEqual(node.heartbeat_time, 0)
futures = []
with ThreadPoolExecutor(max_workers=100) as executor:
for i in range(worker_size):
futures.append(
executor.submit(
manager.collect_node_heart_beat, NodeType.WORKER, i, i
)
)

for future in futures:
future.result()

self.assertEqual(len(futures), worker_size)
for i, node in manager._job_nodes[NodeType.WORKER].items():
self.assertEqual(node.id, i)
self.assertEqual(node.heartbeat_time, i)

manager.stop()


class LocalJobManagerTest(unittest.TestCase):
def test_local_job_manager(self):
2 changes: 0 additions & 2 deletions dlrover/python/tests/test_sharding_client.py
@@ -69,7 +69,6 @@ def test_sharding_client(self):
for i in range(loop):
shard = data_shard_service.fetch_shard()
if i < 10 and i != 5:
print(f"i : {i}, shard : {shard.start}")
self.assertIsNotNone(shard)
data_shard_service.report_batch_done(task_ids=[i])
elif i == 5:
@@ -103,7 +102,6 @@ def test_index_sharding_client(self):
shuffled = False
for i in range(len(indices)):
if i != indices[i]:
print(i, indices[i])
shuffled = True
break
self.assertFalse(shuffled)
1 change: 0 additions & 1 deletion dlrover/python/tests/test_task_manager.py
@@ -70,7 +70,6 @@ def test_dataset_checkpoint(self):
checkpoint_str = checkpoint.to_json()

checkpoint_dict = json.loads(checkpoint_str)
print(checkpoint_dict)
self.assertDictEqual(
checkpoint_dict,
{
4 changes: 2 additions & 2 deletions dlrover/python/tests/test_utils.py
@@ -114,9 +114,9 @@ def __init__(self):
PlatformType.KUBERNETES, "default", "test"
)

def initilize(self):
def initilize(self, worker_count=16):
worker_resource = NodeGroupResource(
16, NodeResource(1, 4096, "a100", 8)
worker_count, NodeResource(1, 4096, "a100", 8)
)
self.node_args[NodeType.WORKER] = NodeArgs(
worker_resource, True, 3, 0, ""
1 change: 0 additions & 1 deletion dlrover/python/util/state/store_mananger.py
@@ -62,7 +62,6 @@ def store_type(self):
def build_store(self):
if self.memory_store is None:
self.memory_store = MemoryStore(self, self.jobname, "test")
print(self.memory_store)
return self.memory_store

@classmethod
