
Commit

The restarted node can acquire the full checkpoint from the peer node. (#1145)

* The restarted node can acquire the checkpoint from the peer node.

* Add test cases.

* Address review comments.

* Fix to call the superclass __init__.
workingloong committed May 30, 2024
1 parent 93df4af commit aa5e6b6
Showing 3 changed files with 164 additions and 71 deletions.
2 changes: 1 addition & 1 deletion dlrover/python/master/main.py
@@ -50,7 +50,7 @@ def run(args):

worker = job_args.node_args[NodeType.WORKER].group_resource
worker.count = args.node_num
master = LocalJobMaster(_dlrover_context.master_port, job_args)
master = LocalJobMaster(args.port, job_args)
else:
from dlrover.python.master.dist_master import DistributedJobMaster

49 changes: 31 additions & 18 deletions dlrover/trainer/tests/torch/checkpoint_backup_test.py
@@ -28,8 +28,8 @@
SharedMemoryHandler,
)
from dlrover.trainer.torch.flash_checkpoint.ckpt_backup import (
ZeroCkptBackupManager,
get_backup_ranks,
FullCkptBackupManager,
ShardCkptBackupManager,
)

CHECKPOINT_DIR = "checkpoint"
@@ -69,18 +69,30 @@ def run_checkpoint_backup(rank, world_size):
}
shm_hanlder.save_state_dict(state_dict)

mock_func_path = (
"dlrover.trainer.torch.flash_checkpoint.ckpt_backup.get_backup_ranks"
)
with mock.patch(mock_func_path, return_value=[0, 1]):
back_manager = ZeroCkptBackupManager(
with mock.patch.object(
ShardCkptBackupManager, "_get_backup_ranks", return_value=[0, 1]
):
back_manager = ShardCkptBackupManager(
local_rank=rank, local_world_size=1, backup_group_size=2
)
back_manager.backup_ranks = list(range(world_size))
back_manager.backup(shm_hanlder)
peer_shm_handler = SharedMemoryHandler(2)
shm_hanlders = [shm_hanlder, peer_shm_handler]
shm_tensor, meta = back_manager._gather_owner_checkpoint(shm_hanlders)
if rank == 0:
shm_hanlders = [shm_hanlder, shm_hanlder]
else:
peer_shm_handler = SharedMemoryHandler(2)
shm_hanlders = [peer_shm_handler, peer_shm_handler]
shm_tensor, _ = back_manager._gather_owner_checkpoint(shm_hanlders)
if rank == 0 and shm_tensor.numel() != 1632:
raise ValueError("Test Failed!")

with mock.patch.object(
FullCkptBackupManager, "_get_backup_ranks", return_value=[0, 1]
):
back_manager = FullCkptBackupManager(
local_rank=rank, local_world_size=1
)
shm_tensor, _ = back_manager._gather_owner_checkpoint(shm_hanlder)
if rank == 0 and shm_tensor.numel() != 1632:
raise ValueError("Test Failed!")
cleanup()
@@ -90,15 +102,16 @@ class CheckpointBackupTest(unittest.TestCase):
def setUp(self) -> None:
shutil.rmtree(SOCKET_TMP_DIR, ignore_errors=True)

def test_get_backup_ranks(self):
@mock.patch("torch.distributed.new_group")
@mock.patch("torch.distributed.get_rank")
def test_get_backup_ranks(self, mock_get_rank, mock_new_group):
mock_get_rank.return_value = 1
shard_manager = ShardCkptBackupManager(0, 8)
self.assertListEqual(shard_manager.backup_ranks, [0, 8])

ranks = get_backup_ranks(
node_rank=5,
local_rank=1,
local_world_size=8,
group_size=3,
)
self.assertListEqual(ranks, [25, 33, 41])
os.environ["NODE_NUM"] = "4"
shard_manager = FullCkptBackupManager(0, 8)
self.assertListEqual(shard_manager.backup_ranks, [0, 8, 16, 24])

def test_backup_checkpoint(self):
world_size = 2
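For clarity, the expected lists in test_get_backup_ranks follow from the rank arithmetic in ckpt_backup.py: with local_world_size=8 and the default backup_group_size=2, local rank 0 of node 0 is backed up by ranks [0, 8]; with NODE_NUM=4, the full-checkpoint manager takes the first rank of every node, [0, 8, 16, 24]. A minimal standalone sketch of that arithmetic (plain Python; the helper names are illustrative, not DLRover APIs):

def shard_backup_ranks(node_rank, local_rank, local_world_size, group_size):
    # Nodes are grouped into blocks of `group_size`; the same local rank on
    # every node in the block backs up the others' checkpoint shards.
    group_index = node_rank // group_size
    return [
        (group_index * group_size + i) * local_world_size + local_rank
        for i in range(group_size)
    ]


def full_backup_ranks(node_num, local_world_size):
    # Every node already holds a full checkpoint, so the first rank of each
    # node is enough to restore any restarted peer.
    return [node_rank * local_world_size for node_rank in range(node_num)]


assert shard_backup_ranks(0, 0, local_world_size=8, group_size=2) == [0, 8]
assert full_backup_ranks(node_num=4, local_world_size=8) == [0, 8, 16, 24]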
184 changes: 132 additions & 52 deletions dlrover/trainer/torch/flash_checkpoint/ckpt_backup.py
@@ -12,7 +12,7 @@
# limitations under the License.

from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List
from typing import Dict, List

import torch
import torch.distributed as dist
@@ -25,38 +25,17 @@
)


def get_backup_ranks(node_rank, local_rank, local_world_size, group_size):
"""
Get the ranks to back up the checkpoint. Assuming each group has 3 nodes
(group_size=3) and each node has 8 ranks, the backup ranks of each local
rank are:
local rank 0: {0, 8, 16}
local rank 1: {1, 9, 17}
Arguments:
node_rank: the rank of node in the job.
local_rank: the local rank in a node.
local_world_size: the number of local ranks in a node.
group_size: the number of nodes in each backup group.
Returns:
A list of ranks.
"""
backup_ranks = []

group_index = node_rank // group_size
for i in range(group_size):
node_rank = group_index * group_size + i
rank = node_rank * local_world_size + local_rank
backup_ranks.append(rank)
return backup_ranks


class BackupManger(metaclass=ABCMeta):
def __init__(self, local_rank, local_world_size) -> None:
self.local_rank = local_rank
self.local_world_size = local_world_size
self.node_rank = env_utils.get_node_rank()
self.rank = dist.get_rank()
self.node_num = env_utils.get_node_num()
self.current_device = torch.device("cpu")

@abstractmethod
def backup(
self, shm_handler: SharedMemoryHandler, ckpt_meta: Dict[Any, Any]
):
def backup(self, shm_handler: SharedMemoryHandler):
"""
The nodes in a backup group back up each other's checkpoints.
"""
@@ -71,7 +50,7 @@ def gather(self):
pass


class ZeroCkptBackupManager(object):
class ShardCkptBackupManager(BackupManger):
"""
The manager selects a rank on another node to back up the checkpoint
shard of the current rank.
@@ -80,23 +59,40 @@ class ShardCkptBackupManager(BackupManger):
def __init__(
self, local_rank, local_world_size, backup_group_size=2
) -> None:
self.local_rank = local_rank
self.local_world_size = local_world_size
self.node_rank = env_utils.get_node_rank()
self.rank = dist.get_rank()
self.current_device = torch.device("cpu")
super().__init__(local_rank, local_world_size)

self.backup_ranks = get_backup_ranks(
self.node_rank,
self.local_rank,
self.local_world_size,
backup_group_size,
)
self.backup_ranks = self._get_backup_ranks(backup_group_size)
self._backup_group = dist.new_group(
backend="gloo", ranks=self.backup_ranks
)
self._rank_shms: Dict[int, SharedMemoryHandler] = {}

def _get_backup_ranks(self, backup_group_size):
"""
Get the ranks to back up the checkpoint. Assuming each backup group has
3 nodes (backup_group_size=3) and each node has 8 ranks, the backup
ranks of each local rank are:
local rank 0: {0, 8, 16}
local rank 1: {1, 9, 17}
Arguments:
backup_group_size: the number of nodes in each backup group.
Returns:
A list of ranks.
"""
backup_ranks = []

group_index = self.node_rank // backup_group_size
for i in range(backup_group_size):
node_rank = group_index * backup_group_size + i
rank = node_rank * self.local_world_size + self.local_rank
backup_ranks.append(rank)
return backup_ranks
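# Illustration (not part of this change): with backup_group_size=3 and
# local_world_size=8, node_rank=5 and local_rank=1 give group_index=5//3=1,
# so the group covers nodes {3, 4, 5} and the backup ranks are
# [3*8+1, 4*8+1, 5*8+1] == [25, 33, 41].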

def backup(self, shm_handler: SharedMemoryHandler):
"""
The rank of a node in a backup group sends its checkpoint shard
@@ -127,13 +123,16 @@ def _gather_peer_ckpt(self, buffer, meta_data):
# Allgather tensor sizes
dist.all_gather(shm_size_list, local_size, group=self._backup_group)

output_tensors = []
for tensor_size in shm_size_list:
output_tensor = torch.empty(
tensor_size, dtype=torch.uint8, device=self.current_device
)
output_tensors.append(output_tensor)
max_tensor_size = int(max(shm_size_list))
# Resize tensor to max size across all ranks.
byte_tensor.resize_(max_tensor_size)

output_tensors = [
torch.empty(
max_tensor_size, dtype=torch.uint8, device=self.current_device
)
for _ in shm_size_list
]
dist.all_gather(output_tensors, byte_tensor, group=self._backup_group)

output_meta_objs = [None for _ in range(group_size)]
@@ -200,14 +199,14 @@ def _gather_owner_checkpoint(
):
ckpt_shm_tensor = None
ckpt_meta = {}
for rank in self.backup_ranks:
shm_handler = shm_handlers[rank]
for i, rank in enumerate(self.backup_ranks):
shm_handler = shm_handlers[i]
if shm_handler.shared_memory:
assert shm_handler.shared_memory is not None
buffer = shm_handler.shared_memory.buf
meta_data = shm_handler.metadata.get()
else:
buffer = [1]
buffer = memoryview(b"")
meta_data = {}
shm_tensors, ckpt_metas = self._gather_peer_ckpt(buffer, meta_data)
if rank != self.rank:
@@ -218,3 +217,84 @@
ckpt_meta = meta
break
return ckpt_shm_tensor, ckpt_meta
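The padding in _gather_peer_ckpt above is needed because all_gather requires tensors of the same size on every rank, so each rank resizes its byte tensor to the largest shard before gathering. A minimal standalone sketch of this pattern (shown on the default gloo process group for simplicity, assuming it is already initialized; the helper name is illustrative, not a DLRover API):

import torch
import torch.distributed as dist


def all_gather_variable_bytes(local_bytes: torch.Tensor):
    """Gather variable-length uint8 tensors by padding them to the max size."""
    world_size = dist.get_world_size()
    # Exchange the true sizes first.
    local_size = torch.tensor([local_bytes.numel()], dtype=torch.int64)
    sizes = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(sizes, local_size)
    max_size = max(int(s.item()) for s in sizes)
    # Pad the local tensor so every rank contributes an equal-sized tensor.
    padded = torch.zeros(max_size, dtype=torch.uint8)
    padded[: local_bytes.numel()] = local_bytes
    outputs = [torch.empty(max_size, dtype=torch.uint8) for _ in range(world_size)]
    dist.all_gather(outputs, padded)
    # Trim the padding away using the exchanged sizes.
    return [t[: int(s.item())] for t, s in zip(outputs, sizes)]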


class FullCkptBackupManager(BackupManger):
"""
The node does not need to back up the checkpoint if each rank holds
the full checkpoint. The manager selects one rank whose shared memory
contains the full checkpoint and broadcasts it to the other ranks.
"""

def __init__(self, local_rank, local_world_size) -> None:
super().__init__(local_rank, local_world_size)
self.backup_ranks = self._get_backup_ranks()
self._backup_group = dist.new_group(
backend="gloo", ranks=self.backup_ranks
)
self._rank_shms: Dict[int, SharedMemoryHandler] = {}

def _get_backup_ranks(self):
backup_ranks = []
for node_rank in range(self.node_num):
rank = node_rank * self.local_world_size
backup_ranks.append(rank)
return backup_ranks

def backup(self, shm_handler: SharedMemoryHandler):
"""
Each rank in the node holds the full model checkpoint, so the shared
memory of every node is already a backup of the checkpoint.
"""
pass

def gather(self):
"""
The method gathers the full checkpoint from the shared memory of a peer
node in the backup group. First, it selects a source rank whose shared
memory holds the complete checkpoint; then that rank broadcasts its
checkpoint to the other ranks.
Returns:
ByteTensor of the checkpoint.
A dict of checkpoint metadata.
"""

if self.rank not in self.backup_ranks:
return

shm_handler = SharedMemoryHandler(local_rank=self.local_rank)
shm_handler.init_shared_memory()
shm_tensor, meta = self._gather_owner_checkpoint(shm_handler)
return shm_tensor, meta

def _gather_owner_checkpoint(self, shm_handler: SharedMemoryHandler):
flag = torch.tensor([1], dtype=torch.int8)
if not shm_handler.shared_memory:
flag = torch.tensor([0], dtype=torch.int8)
output_tensors = [
torch.empty(1, dtype=torch.int8) for _ in self.backup_ranks
]
dist.all_gather(output_tensors, flag)
src_rank = self.backup_ranks[0]
for i, flag in enumerate(output_tensors):
if flag == 1:
src_rank = self.backup_ranks[i]
break
if shm_handler.shared_memory:
byte_tensor = torch.ByteTensor(shm_handler.shared_memory.buf)
else:
byte_tensor = torch.ByteTensor([])
dist.broadcast(byte_tensor, src=src_rank)
ckpt_meta = shm_handler.metadata.get()
ckpt_metas = [ckpt_meta]
dist.broadcast_object_list(ckpt_metas, src=src_rank)
return byte_tensor, ckpt_metas[0]
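For reference, a minimal standalone sketch of the select-source-and-broadcast pattern that FullCkptBackupManager implements, written against the default process group with public torch.distributed calls only; unlike the method above it also broadcasts the payload size so receivers can allocate a matching buffer, and the helper name is illustrative, not a DLRover API:

import torch
import torch.distributed as dist


def broadcast_full_checkpoint(local_bytes: torch.Tensor, has_full_ckpt: bool):
    """Pick the first rank that reports a full checkpoint and broadcast its bytes."""
    world_size = dist.get_world_size()
    # Every rank advertises whether its shared memory holds a full checkpoint.
    flag = torch.tensor([1 if has_full_ckpt else 0], dtype=torch.int8)
    flags = [torch.zeros(1, dtype=torch.int8) for _ in range(world_size)]
    dist.all_gather(flags, flag)
    src = next(i for i, f in enumerate(flags) if int(f.item()) == 1)
    # Broadcast the payload size first so receivers can allocate a buffer.
    size = torch.tensor([local_bytes.numel()], dtype=torch.int64)
    dist.broadcast(size, src=src)
    if dist.get_rank() == src:
        buf = local_bytes
    else:
        buf = torch.empty(int(size.item()), dtype=torch.uint8)
    dist.broadcast(buf, src=src)
    return buf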
