Skip to content

Commit

Permalink
Fix the node rank after resorting nodes.
Browse files (browse the repository at this point in the history)
  • Loading branch information
workingloong committed Jun 24, 2024
1 parent 1b21ea0 commit 961e324
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
13 changes: 7 additions & 6 deletions dlrover/python/master/elastic_training/net_topology.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,6 +20,7 @@
 @dataclass
 class NodeTopologyMeta(object):
     node_id: int = 0
+    node_rank: int = 0
     process_num: int = 0
     node_ip: str = ""
     asw: str = ""
Expand Down Expand Up @@ -73,16 +74,16 @@ def sort(
 asw_nodes: Dict[str, List[NodeTopologyMeta]] = {}
 rank0_node = next(iter(nodes.values()))
 rank0_asw = rank0_node.asw
-for node_rank, meta in nodes.items():
+for _, meta in nodes.items():
     asw_nodes.setdefault(meta.asw, [])
-    asw_nodes[meta.asw].append((node_rank, meta))
+    asw_nodes[meta.asw].append(meta)

 sorted_nodes: Dict[int, NodeTopologyMeta] = OrderedDict()
 asw0_nodes = asw_nodes.pop(rank0_asw, [])
-for node_rank, node_meta in asw0_nodes:
-    sorted_nodes[node_rank] = node_meta
+for node_meta in asw0_nodes:
+    sorted_nodes[node_meta.node_rank] = node_meta

-for node_rank, node_metas in asw_nodes.values():
+for node_metas in asw_nodes.values():
     for node_meta in node_metas:
-        sorted_nodes[node_rank] = node_meta
+        sorted_nodes[node_meta.node_rank] = node_meta
 return sorted_nodes
2 changes: 2 additions & 0 deletions dlrover/python/master/elastic_training/rdzv_manager.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -238,6 +238,7 @@ def join_rendezvous(
 asw, psw = self._topology_querier.query(node_ip)
 meta = NodeTopologyMeta(
     node_id=node_id,
+    node_rank=node_rank,
     node_ip=node_ip,
     process_num=local_world_size,
     asw=asw,
Expand All @@ -263,6 +264,7 @@ def _map_node_rank_to_id(self, rank_dict):
 for node_rank, v in rank_dict.items():
     node_id = self._rdzv_nodes[node_rank].node_id
     id_dict[node_id] = v
+id_dict = dict(sorted(id_dict.items()))
 return id_dict

def num_nodes_waiting(self):
Expand Down

0 comments on commit 961e324

Please sign in to comment.