
Commit

Fix test cases.
workingloong committed Jun 3, 2024
1 parent 1bc4ff6 commit d9c441f
Showing 2 changed files with 30 additions and 20 deletions.
4 changes: 4 additions & 0 deletions dlrover/python/common/env_utils.py
@@ -34,6 +34,10 @@ def get_local_rank():
     return int(os.getenv("LOCAL_RANK", 0))
 
 
+def get_rank():
+    return int(os.getenv("RANK", 0))
+
+
 def get_group_world_size():
     return int(os.getenv("GROUP_WORLD_SIZE", 1))
 
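The new get_rank() helper mirrors get_local_rank() above: it reads the RANK environment variable that torchrun and the elastic agent export for each worker, defaulting to 0. A minimal usage sketch, assuming the dlrover package is importable; setting RANK by hand here is only for illustration:

import os

from dlrover.python.common import env_utils

# RANK is normally exported by the launcher; we set it manually only to
# show the value read back by env_utils.get_rank().
os.environ["RANK"] = "3"
print(env_utils.get_rank())  # -> 3

os.environ.pop("RANK")
print(env_utils.get_rank())  # -> 0, the default when RANK is unset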
46 changes: 26 additions & 20 deletions dlrover/trainer/torch/flash_checkpoint/replica.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 
 from abc import ABCMeta, abstractmethod
-from typing import Dict
+from typing import Dict, List
 
 import torch
 import torch.distributed as dist
@@ -31,9 +31,15 @@ def __init__(self, replica_count) -> None:
         self.local_rank = env_utils.get_local_rank()
         self.local_world_size = env_utils.get_local_world_size()
         self.node_rank = env_utils.get_node_rank()
-        self.rank = dist.get_rank()
         self.node_num = env_utils.get_node_num()
         self.current_device = torch.device("cpu")
+        self._rank_shms: Dict[int, SharedMemoryHandler] = {}
+        self._backup_ranks: List[int] = []
+        self._backup_group = None
+        if dist.is_initialized():
+            self.rank = dist.get_rank()
+        else:
+            self.rank = env_utils.get_rank()
 
     @staticmethod
     def create_replica_manager(shard_num, replica_count):
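With this change the base replica manager no longer requires an initialized process group at construction time: the global rank comes from dist.get_rank() when a group exists and from the new env_utils.get_rank() otherwise, which is presumably what the fixed test cases rely on. A minimal sketch of the same guard pattern under that assumption (resolve_rank is an illustrative name, not part of DLRover):

import os

import torch.distributed as dist

from dlrover.python.common import env_utils


def resolve_rank() -> int:
    # Prefer the process-group rank; fall back to the RANK env var so the
    # code also works in single-process unit tests.
    if dist.is_initialized():
        return dist.get_rank()
    return env_utils.get_rank()


# Single-process check: no process group has been initialized, so the
# env-var fallback is taken.
os.environ["RANK"] = "0"
assert not dist.is_initialized()
assert resolve_rank() == 0

The same dist.is_initialized() guard is applied before dist.new_group() in the two subclasses below, so no gloo backup group is created when replication is disabled or no process group exists.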
@@ -74,10 +80,10 @@ def __init__(self, replica_count=0) -> None:
         super().__init__(replica_count)
 
         self.backup_ranks = self._get_backup_ranks(replica_count)
-        self._backup_group = dist.new_group(
-            backend="gloo", ranks=self.backup_ranks
-        )
-        self._rank_shms: Dict[int, SharedMemoryHandler] = {}
+        if dist.is_initialized() and replica_count > 0:
+            self._backup_group = dist.new_group(
+                backend="gloo", ranks=self.backup_ranks
+            )
 
     def _get_backup_ranks(self, replica_count):
         """
@@ -110,6 +116,11 @@ def backup(self, shm_handler: SharedMemoryHandler):
         The rank of a node in a backup group sends its checkpoint shard
         in the shared memory to other nodes and gets the checkpoint shards
         of other nodes by allgather.
+
+        Arguments:
+            shm_handler: The shared memory handler of the current rank on
+                this node.
+
         """
         if self.replica_count == 0:
             return
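The docstring above describes an allgather of checkpoint shards across the gloo backup group. The exchange itself is implemented elsewhere in replica.py; the sketch below only illustrates one common way to allgather variable-sized byte buffers with torch.distributed (exchange lengths first, pad to the maximum, then trim). All names in it are illustrative rather than DLRover's API, and it assumes an initialized gloo group:

import io
from typing import List

import torch
import torch.distributed as dist


def allgather_state_bytes(state_dict, group=None) -> List[bytes]:
    # Serialize the local checkpoint shard to a uint8 tensor.
    buf = io.BytesIO()
    torch.save(state_dict, buf)
    local = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

    # Exchange sizes first so every rank can allocate equally sized tensors.
    world_size = dist.get_world_size(group=group)
    local_size = torch.tensor([local.numel()], dtype=torch.int64)
    sizes = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(sizes, local_size, group=group)

    # Pad to the maximum size, allgather, then trim the padding back off.
    max_size = int(max(s.item() for s in sizes))
    padded = torch.zeros(max_size, dtype=torch.uint8)
    padded[: local.numel()] = local
    gathered = [torch.zeros(max_size, dtype=torch.uint8) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)
    return [t[: int(s.item())].numpy().tobytes() for t, s in zip(gathered, sizes)]

In the real manager the bytes come from the SharedMemoryHandler rather than torch.save, and dist.all_gather_object is a simpler (if slower) alternative for small metadata.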
@@ -191,11 +202,8 @@ def gather(self, shm_handler: SharedMemoryHandler):
         the valid shard_1 from all ranks.
 
         Arguments:
-            ckpt_shards (dict): the key is the rank of checkpoint shard and the
-                value is the handle fo the shared memory to store the
-                checkpoint shard.
-            ckpt_metas (dict): the key is the rank of checkpoint shard and the
-                value is the meta dict of PyTorch checkpiont state dict.
+            shm_handler: The shared memory handler of the current rank on
+                this node.
 
         Returns:
             ByteTensor of the checkpoint shard.
@@ -246,10 +254,10 @@ class FullCkptReplicaManager(CkptReplicaManger):
     def __init__(self, replica_count=0) -> None:
         super().__init__(replica_count)
         self.backup_ranks = self._get_backup_ranks()
-        self._backup_group = dist.new_group(
-            backend="gloo", ranks=self.backup_ranks
-        )
-        self._rank_shms: Dict[int, SharedMemoryHandler] = {}
+        if dist.is_initialized() and replica_count > 0:
+            self._backup_group = dist.new_group(
+                backend="gloo", ranks=self.backup_ranks
+            )
 
     def _get_backup_ranks(self):
         backup_ranks = []
@@ -271,12 +279,10 @@ def gather(self, shm_handler: SharedMemoryHandler):
         node in a backup group. Firstly, the method selects a source rank
         in a node whose shared memory has the complete checkpoint. Then
         the rank broadcasts its checkpoint to other ranks.
 
         Arguments:
-            ckpt_shards (dict): the key is the rank of checkpoint shard and the
-                value is the handle fo the shared memory to store the
-                checkpoint shard.
-            ckpt_metas (dict): the key is the rank of checkpoint shard and the
-                value is the meta dict of PyTorch checkpiont state dict.
+            shm_handler: The shared memory handler of the current rank on
+                this node.
+
         Returns:
             ByteTensor of the checkpoint shard.
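Per its docstring, FullCkptReplicaManager.gather picks a source rank whose node holds a complete checkpoint and broadcasts it to the other ranks. A rough sketch of that select-then-broadcast pattern, assuming the default (world) process group; broadcast_full_checkpoint and its logic are illustrative, not DLRover's implementation:

from typing import Optional

import torch
import torch.distributed as dist


def broadcast_full_checkpoint(local_bytes: Optional[bytes]) -> bytes:
    # local_bytes is this rank's serialized checkpoint, or None if the
    # local copy is missing or incomplete.
    world_size = dist.get_world_size()

    # Step 1: agree on the lowest rank that owns a complete copy.
    has_ckpt = torch.tensor([1 if local_bytes else 0], dtype=torch.int64)
    flags = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(flags, has_ckpt)
    src = next(i for i, f in enumerate(flags) if int(f.item()) == 1)  # raises if none

    # Step 2: broadcast the payload size, then the payload itself.
    size = torch.tensor([len(local_bytes) if local_bytes else 0], dtype=torch.int64)
    dist.broadcast(size, src=src)
    if dist.get_rank() == src:
        payload = torch.frombuffer(bytearray(local_bytes), dtype=torch.uint8)
    else:
        payload = torch.zeros(int(size.item()), dtype=torch.uint8)
    dist.broadcast(payload, src=src)
    return payload.numpy().tobytes()

With a sub-group such as the gloo backup group, the source index produced by the group-local allgather must be mapped back to a global rank before calling dist.broadcast().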
