[pull] master from intelligent-machine-learning:master #60

Merged: 13 commits, Jun 4, 2024
Fix test cases.
workingloong committed Jun 3, 2024
commit d9c441ffdaa0fcfedef37384672a7aa3644f18fe
4 changes: 4 additions & 0 deletions dlrover/python/common/env_utils.py
@@ -34,6 +34,10 @@ def get_local_rank():
return int(os.getenv("LOCAL_RANK", 0))


def get_rank():
return int(os.getenv("RANK", 0))


def get_group_world_size():
return int(os.getenv("GROUP_WORLD_SIZE", 1))

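The helpers in env_utils.py read the environment variables that torchrun exports to every worker (RANK, LOCAL_RANK, GROUP_WORLD_SIZE), with defaults that keep single-process runs working. A minimal sketch of how the new get_rank() falls back, illustrative only and not part of the diff:

import os

def get_rank():
    # torchrun exports RANK for every worker; a plain single-process
    # run does not, so fall back to rank 0.
    return int(os.getenv("RANK", 0))

os.environ.pop("RANK", None)
print(get_rank())         # 0, the single-process default
os.environ["RANK"] = "3"  # what torchrun would set for worker 3
print(get_rank())         # 3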
46 changes: 26 additions & 20 deletions dlrover/trainer/torch/flash_checkpoint/replica.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 
 from abc import ABCMeta, abstractmethod
-from typing import Dict
+from typing import Dict, List
 
 import torch
 import torch.distributed as dist
@@ -31,9 +31,15 @@ def __init__(self, replica_count) -> None:
         self.local_rank = env_utils.get_local_rank()
         self.local_world_size = env_utils.get_local_world_size()
         self.node_rank = env_utils.get_node_rank()
-        self.rank = dist.get_rank()
         self.node_num = env_utils.get_node_num()
         self.current_device = torch.device("cpu")
+        self._rank_shms: Dict[int, SharedMemoryHandler] = {}
+        self._backup_ranks: List[int] = []
+        self._backup_group = None
+        if dist.is_initialized():
+            self.rank = dist.get_rank()
+        else:
+            self.rank = env_utils.get_rank()
 
     @staticmethod
     def create_replica_manager(shard_num, replica_count):
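The hunk above replaces an unconditional dist.get_rank() with a fallback, since dist.get_rank() raises when no default process group has been initialized (for example, in unit tests that never call init_process_group). A sketch of the pattern in isolation; resolve_rank is a hypothetical name, not part of the diff:

import os

import torch.distributed as dist

def resolve_rank() -> int:
    # Trust the process group only after it exists; otherwise fall
    # back to the RANK environment variable that torchrun exports.
    if dist.is_initialized():
        return dist.get_rank()
    return int(os.getenv("RANK", 0))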
@@ -74,10 +80,10 @@ def __init__(self, replica_count=0) -> None:
         super().__init__(replica_count)
 
         self.backup_ranks = self._get_backup_ranks(replica_count)
-        self._backup_group = dist.new_group(
-            backend="gloo", ranks=self.backup_ranks
-        )
-        self._rank_shms: Dict[int, SharedMemoryHandler] = {}
+        if dist.is_initialized() and replica_count > 0:
+            self._backup_group = dist.new_group(
+                backend="gloo", ranks=self.backup_ranks
+            )
 
     def _get_backup_ranks(self, replica_count):
         """
@@ -110,6 +116,11 @@ def backup(self, shm_handler: SharedMemoryHandler):
         The rank of a node in a backup group sends its checkpoint shard
         in the shared memory to other nodes and gets the checkpoint shards
         of other nodes by allgather.
+
+        Arguments:
+            shm_handler: The shared memory handler of the current rank on
+                this node.
+
         """
         if self.replica_count == 0:
             return
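The exchange the docstring describes, every rank in a backup group sending its serialized shard and receiving its peers' shards, maps onto dist.all_gather. Because shards differ in length, sizes are exchanged first and payloads padded to the maximum. This is a generic sketch under that assumption, not the module's actual implementation:

import torch
import torch.distributed as dist

def allgather_shards(shard, group=None):
    # Exchange variable-length uint8 shards among all ranks in the group.
    world_size = dist.get_world_size(group=group)

    # 1) Agree on sizes so every rank can pad to a common length.
    size = torch.tensor([shard.numel()], dtype=torch.int64)
    sizes = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(sizes, size, group=group)

    # 2) Pad to the maximum size, all_gather, then trim the padding.
    max_size = int(max(s.item() for s in sizes))
    padded = torch.zeros(max_size, dtype=torch.uint8)
    padded[: shard.numel()] = shard
    buffers = [torch.zeros(max_size, dtype=torch.uint8) for _ in range(world_size)]
    dist.all_gather(buffers, padded, group=group)
    return [buf[: int(s.item())] for buf, s in zip(buffers, sizes)]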
@@ -191,11 +202,8 @@ def gather(self, shm_handler: SharedMemoryHandler):
         the valid shard_1 from all ranks.
 
         Arguments:
-            ckpt_shards (dict): the key is the rank of checkpoint shard and the
-                value is the handle fo the shared memory to store the
-                checkpoint shard.
-            ckpt_metas (dict): the key is the rank of checkpoint shard and the
-                value is the meta dict of PyTorch checkpiont state dict.
+            shm_handler: The shared memory handler of the current rank on
+                this node.
 
         Returns:
             ByteTensor of the checkpoint shard.
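Before any shard can be fetched, the restore path has to learn which rank in the replica group still holds a valid copy. One generic way to do that, assuming validity is known locally as a boolean, is to all_gather a flag and pick the lowest rank that reports one; a sketch, not the module's code:

import torch
import torch.distributed as dist

def find_source_rank(has_valid_shard, group=None):
    # Returns the lowest group rank holding a valid shard, or -1 if none.
    world_size = dist.get_world_size(group=group)
    flag = torch.tensor([1 if has_valid_shard else 0], dtype=torch.int64)
    flags = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(flags, flag, group=group)
    for rank, f in enumerate(flags):
        if int(f.item()) == 1:
            return rank
    return -1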
@@ -246,10 +254,10 @@ class FullCkptReplicaManager(CkptReplicaManger):
     def __init__(self, replica_count=0) -> None:
         super().__init__(replica_count)
         self.backup_ranks = self._get_backup_ranks()
-        self._backup_group = dist.new_group(
-            backend="gloo", ranks=self.backup_ranks
-        )
-        self._rank_shms: Dict[int, SharedMemoryHandler] = {}
+        if dist.is_initialized() and replica_count > 0:
+            self._backup_group = dist.new_group(
+                backend="gloo", ranks=self.backup_ranks
+            )
 
     def _get_backup_ranks(self):
         backup_ranks = []
@@ -271,12 +279,10 @@ def gather(self, shm_handler: SharedMemoryHandler):
         node in a backup group. First, the method selects a source rank
         on a node whose shared memory has the complete checkpoint. Then
         that rank broadcasts its checkpoint to other ranks.
+
         Arguments:
-            ckpt_shards (dict): the key is the rank of checkpoint shard and the
-                value is the handle fo the shared memory to store the
-                checkpoint shard.
-            ckpt_metas (dict): the key is the rank of checkpoint shard and the
-                value is the meta dict of PyTorch checkpiont state dict.
+            shm_handler: The shared memory handler of the current rank on
+                this node.
 
         Returns:
             ByteTensor of the checkpoint shard.
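For full replicas, gather reduces to one source rank broadcasting its complete checkpoint bytes to the rest of the group. A sketch over the default process group; the two-step broadcast (size first, then payload) is an assumption made here so receivers can allocate a correctly sized buffer, and the function is not part of the diff:

import torch
import torch.distributed as dist

def broadcast_checkpoint(ckpt, src):
    # Broadcast a uint8 checkpoint tensor from rank src to all ranks.
    rank = dist.get_rank()

    # Receivers first learn the payload size to allocate a buffer.
    size = torch.tensor([ckpt.numel() if rank == src else 0], dtype=torch.int64)
    dist.broadcast(size, src=src)

    if rank != src:
        ckpt = torch.zeros(int(size.item()), dtype=torch.uint8)
    dist.broadcast(ckpt, src=src)
    return ckpt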