
Commit

Merge pull request intelligent-machine-learning#1153 from workingloong/backup-ckpt

Restore checkpoint replica from the shared memory of other nodes.
samplise committed Jun 3, 2024
2 parents 9104f80 + 65550ba commit 86dcb5d
Showing 12 changed files with 242 additions and 106 deletions.
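For orientation, this change adds a `replica_count` knob to the flash-checkpoint APIs (see the `ddp.py`, `engine.py`, and `megatron.py` diffs below): each checkpoint shard written to local shared memory is also backed up into the shared memory of a peer node, and after a node restart the engine first tries to rebuild its local shared memory from that replica before falling back to persistent storage. A minimal round-trip sketch, assuming the `DdpCheckpointer` API shown in this diff; the path, step number, and toy model are illustrative only:

```python
import torch
import torch.nn as nn

from dlrover.trainer.torch.flash_checkpoint.ddp import (
    DdpCheckpointer,
    StorageType,
)

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# replica_count=2 asks the engine to keep each checkpoint shard in the shared
# memory of a backup rank on another node, in addition to the local copy.
checkpointer = DdpCheckpointer("/tmp/flash_ckpt", replica_count=2)

state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
# Fast in-memory save; persist to disk less frequently if desired.
checkpointer.save_checkpoint(100, state_dict, storage_type=StorageType.MEMORY)

# After a node restart, loading goes through the engine, which with this
# change first tries to restore the local shared memory from a surviving
# node's replica before reading from storage.
restored = checkpointer.load_checkpoint()
if restored:
    model.load_state_dict(restored["model"])
    optimizer.load_state_dict(restored["optimizer"])
```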
4 changes: 4 additions & 0 deletions dlrover/python/common/env_utils.py
@@ -34,6 +34,10 @@ def get_local_rank():
return int(os.getenv("LOCAL_RANK", 0))


def get_rank():
return int(os.getenv("RANK", 0))


def get_group_world_size():
return int(os.getenv("GROUP_WORLD_SIZE", 1))

5 changes: 4 additions & 1 deletion dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -622,7 +622,8 @@ def _check_shard_step_consistence(self, step, timeout=15):
ckpt_config = self._shm_handlers[i].get_checkpoint_config(
default_config
)
steps.append(ckpt_config.step)
if ckpt_config.step > 0:
steps.append(ckpt_config.step)
if all(i == step for i in steps):
return True
time.sleep(1)
@@ -819,6 +820,8 @@ def save_step_checkpoint(self, step: int):
ckpt_config = self._shm_handlers[i].get_checkpoint_config(
default_config
)
if ckpt_config.step == 0:
continue
future: Future = self._executor.submit(
self._save_shard,
step,
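The `ckpt_saver.py` change above skips shards whose checkpoint config still reports step 0, i.e. shared-memory blocks that have never been written, so an empty shard neither breaks the step-consistency check nor gets flushed to storage. A tiny standalone illustration of the rule (a sketch, not DLRover code):

```python
def consistent_step(shard_steps, step):
    """True if every shard that has actually been written holds the expected step."""
    written = [s for s in shard_steps if s > 0]  # step 0 means never written
    return all(s == step for s in written)

assert consistent_step([100, 0, 100], 100)       # an empty shard is ignored
assert not consistent_step([100, 99, 100], 100)  # a stale shard still fails
```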
30 changes: 15 additions & 15 deletions dlrover/trainer/tests/torch/checkpoint_backup_test.py
@@ -27,9 +27,9 @@
CheckpointConfig,
SharedMemoryHandler,
)
from dlrover.trainer.torch.flash_checkpoint.ckpt_backup import (
FullCkptBackupManager,
ShardCkptBackupManager,
from dlrover.trainer.torch.flash_checkpoint.replica import (
FullCkptReplicaManager,
ShardCkptReplicaManager,
)

CHECKPOINT_DIR = "checkpoint"
@@ -53,6 +53,8 @@ def cleanup():
def run_checkpoint_backup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
os.environ["LOCAL_RANK"] = str(rank)
os.environ["LOCAL_WORLD_SIZE"] = "1"

# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
@@ -70,11 +72,9 @@ def run_checkpoint_backup(rank, world_size):
shm_hanlder.save_state_dict(state_dict)

with mock.patch.object(
ShardCkptBackupManager, "_get_backup_ranks", return_value=[0, 1]
ShardCkptReplicaManager, "_get_backup_ranks", return_value=[0, 1]
):
back_manager = ShardCkptBackupManager(
local_rank=rank, local_world_size=1, backup_group_size=2
)
back_manager = ShardCkptReplicaManager(replica_count=2)
back_manager.backup_ranks = list(range(world_size))
back_manager.backup(shm_hanlder)
if rank == 0:
@@ -87,12 +87,10 @@
raise ValueError("Test Failed!")

with mock.patch.object(
FullCkptBackupManager, "_get_backup_ranks", return_value=[0, 1]
FullCkptReplicaManager, "_get_backup_ranks", return_value=[0, 1]
):
back_manager = FullCkptBackupManager(
local_rank=rank, local_world_size=1
)
shm_tensor, _ = back_manager._gather_owner_checkpoint(shm_hanlder)
back_manager = FullCkptReplicaManager(replica_count=1)
shm_tensor, _ = back_manager.gather(shm_hanlder)
if rank == 0 and shm_tensor.numel() != 1632:
raise ValueError("Test Failed!")
cleanup()
@@ -104,13 +102,15 @@ def setUp(self) -> None:

@mock.patch("torch.distributed.new_group")
@mock.patch("torch.distributed.get_rank")
def test_get_backup_ranks(self, mock_new_group, mock_get_rank):
def test_get_backup_ranks(self, _, mock_get_rank):
mock_get_rank.return_value = 1
shard_manager = ShardCkptBackupManager(0, 8)
os.environ["LOCAL_RANK"] = "0"
os.environ["LOCAL_WORLD_SIZE"] = "8"
shard_manager = ShardCkptReplicaManager(replica_count=2)
self.assertListEqual(shard_manager.backup_ranks, [0, 8])

os.environ["NODE_NUM"] = "4"
shard_manager = FullCkptBackupManager(0, 8)
shard_manager = FullCkptReplicaManager(replica_count=2)
self.assertListEqual(shard_manager.backup_ranks, [0, 8, 16, 24])

def test_backup_checkpoint(self):
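These tests reflect the rename from `*CkptBackupManager` to `*CkptReplicaManager`: the managers now read `LOCAL_RANK` and `LOCAL_WORLD_SIZE` from the environment (via `env_utils`) and take only `replica_count`. A construction sketch under those assumptions; a real run needs an initialized process group, which the unit test above replaces with mocks of `torch.distributed`:

```python
import os

from dlrover.trainer.torch.flash_checkpoint.replica import (
    ShardCkptReplicaManager,
)

# The managers derive their rank layout from the environment now, instead of
# taking local_rank / local_world_size as constructor arguments.
os.environ["LOCAL_RANK"] = "0"
os.environ["LOCAL_WORLD_SIZE"] = "8"

# replica_count=2 means each shard has one backup copy on a peer node;
# torch.distributed should already be initialized at this point.
manager = ShardCkptReplicaManager(replica_count=2)
```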
24 changes: 24 additions & 0 deletions dlrover/trainer/tests/torch/checkpoint_egine_test.py
@@ -16,6 +16,7 @@
import time
import unittest
from pathlib import Path
from unittest import mock

import torch
import torch.distributed as dist
@@ -49,6 +50,9 @@
from dlrover.trainer.torch.flash_checkpoint.megatron_engine import (
MegatronCheckpointEngine,
)
from dlrover.trainer.torch.flash_checkpoint.replica import (
FullCkptReplicaManager,
)


def run_rank_sync(rank, ranks, world_size, master_port):
@@ -247,6 +251,26 @@ def test_deepspeed_engine(self):
restored_step = int(f.read())
self.assertEqual(restored_step, step)

@mock.patch("torch.distributed.barrier")
def test_restore_memory_from_replica(self, mock_barrier):
buffer = memoryview(b"123456789")
meta = {"step": 100, "name": "test-weights"}
storage = PosixDiskStorage()
with tempfile.TemporaryDirectory() as tmpdir:
saving_engine = SimpleShardingCheckpointEngine(
tmpdir, storage, replica_count=1
)
saving_engine._local_rank = 7
with mock.patch.object(
FullCkptReplicaManager,
"gather",
return_value=(torch.ByteTensor(buffer), meta),
):
saving_engine._restore_memory_from_replica()
shm_metadata = saving_engine._shm_handler.metadata.get()
self.assertDictEqual(shm_metadata, meta)
mock_barrier.assert_called()


class CheckpointEngineTest(unittest.TestCase):
def setUp(self):
10 changes: 5 additions & 5 deletions dlrover/trainer/tests/torch/deepspeed_ckpt_test.py
@@ -48,10 +48,10 @@ def zero_optimization_stage(self):

def save_checkpoint(self, save_dir, tag, client_state, save_latest):
model_sd = self.model.state_dict()
model_path = os.path.join(save_dir, tag, "model_states.pt")
model_path = os.path.join(save_dir, str(tag), "model_states.pt")
torch.save(model_sd, model_path)
optimizer_sd = self.optimizer.state_dict()
optim_path = os.path.join(save_dir, tag, "optim_states.pt")
optim_path = os.path.join(save_dir, str(tag), "optim_states.pt")
torch.save(optimizer_sd, optim_path)

def load_checkpoint(
@@ -111,7 +111,7 @@ def test_save_load(self):
engine = MockDeepSpeedEngine(model, optimizer)
checkpointer = DeepSpeedCheckpointer(engine, tmpdirname)
checkpointer.save_checkpoint(
tmpdirname, str(step), storage_type=StorageType.MEMORY
tmpdirname, step, storage_type=StorageType.MEMORY
)
shm_handler = checkpointer._async_save_engine._shm_handler
self.assertFalse(shm_handler.no_checkpint_state())
@@ -122,13 +122,13 @@
checkpointer._async_save_engine._shm_handler.metadata.get()
)
ds_ckpt_config = tensor_meta["_DLORVER_CKPT_CONFIG"]
self.assertEqual(ds_ckpt_config.step, str(step))
self.assertEqual(ds_ckpt_config.step, step)
self.assertIsNotNone(tensor_meta["model_states"])
tracer_file = os.path.join(tmpdirname, "latest")
self.assertFalse(os.path.exists(tracer_file))

checkpointer.save_checkpoint(
tmpdirname, str(step), storage_type=StorageType.DISK
tmpdirname, step, storage_type=StorageType.DISK
)
# Wait asynchronously saving.
start = time.time()
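The DeepSpeed test above now passes the step as an integer tag and compares `ds_ckpt_config.step` against that int, so only path construction stringifies the tag. A hedged usage sketch of the checkpointer it exercises; the import path and call signature are assumed from the DLRover flash-checkpoint examples, and `model_engine` stands in for a real engine returned by `deepspeed.initialize(...)`:

```python
from dlrover.trainer.torch.flash_checkpoint.deepspeed import (
    DeepSpeedCheckpointer,
    StorageType,
)

checkpoint_dir = "/tmp/flash_ckpt_ds"
checkpointer = DeepSpeedCheckpointer(model_engine, checkpoint_dir)

step = 100  # pass the step as an int tag, as the updated test does
checkpointer.save_checkpoint(
    checkpoint_dir, step, storage_type=StorageType.MEMORY
)
```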
3 changes: 3 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/ddp.py
@@ -47,6 +47,7 @@ class DdpCheckpointer(Checkpointer):
ranks save checkpoints. The node rank 0 will skip the checkpoint
if some ranks do not finish saving checkpoint in the save_timeout
after the node rank 0 finishes saving checkpoint.
replica_count (int): the number of checkpoint replicas kept on other nodes.
Examples::
>>> checkpointer = DdpCheckpointer(
@@ -75,6 +76,7 @@ def __init__(
comm_backend="",
deletion_strategy=None,
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
replica_count=0,
):
self.checkpoint_dir = checkpoint_dir
if dist.is_initialized():
@@ -89,6 +91,7 @@
global_shard_num=global_shard_num,
comm_backend=comm_backend,
save_timeout=save_timeout,
replica_count=replica_count,
)

def save_checkpoint(
33 changes: 32 additions & 1 deletion dlrover/trainer/torch/flash_checkpoint/engine.py
@@ -37,6 +37,7 @@
ClassMeta,
SharedMemoryHandler,
)
from dlrover.trainer.torch.flash_checkpoint.replica import CkptReplicaManger


def _local_rank0_log(local_rank, message):
@@ -160,14 +161,15 @@ def __init__(
storage: CheckpointStorage,
comm_backend: str = "",
save_timeout: int = CheckpointConstant.SAVE_TIMEOUT,
replica_count=0,
):
if not self.saver_proc:
self.saver_proc = start_saver_process()

self.checkpoint_dir = checkpoint_dir
self.storage = storage
self._save_timeout = save_timeout
self._local_rank = int(os.getenv("LOCAL_RANK", 0))
self._local_rank = env_utils.get_local_rank()
self._cached_step = 0
self._restart_count = env_utils.get_torch_restart_count()
# queue for agent to save to storage, only lock rank 0 needs the queue.
@@ -197,6 +199,10 @@ def __init__(
self._init_sync_group(comm_backend)
self._notify_agent_to_create_saver()
self._update_saver_config()
shard_num = self.get_global_shard_num()
self._replica_manager = CkptReplicaManger.create_replica_manager(
shard_num, replica_count
)

def _init_sync_group(self, comm_backend):
if not dist.is_initialized():
@@ -320,9 +326,14 @@ def save_state_dict_to_memory(self, state_dict, conf: CheckpointConfig):
if acquired:
self._shm_lock.release()
self._cached_step = conf.step
self._replica_manager.backup(self._shm_handler)
return True

def get_state_dict_from_memory(self):
"""
Restore the checkpoint state dict from the shared memory.
"""
self._restore_memory_from_replica()
state_dict = {}
default_config = CheckpointConfig()
config = self._shm_handler.get_checkpoint_config(default_config)
@@ -335,6 +346,26 @@ def get_state_dict_from_memory(self):
logger.info(f"Load checkpoint at step {config.step} from memory.")
return config.step, state_dict

def _restore_memory_from_replica(self):
if not self._replica_manager.has_replica():
return
self._shm_handler.init_shared_memory()
byte_tensor, meta = self._replica_manager.gather(self._shm_handler)
print(byte_tensor, meta, self._shm_handler.shared_memory)
if (
byte_tensor is not None
and meta
and not self._shm_handler.shared_memory
):
shm_size = byte_tensor.size()[0]
self._shm_handler.init_shared_memory(create=True, size=shm_size)
self._shm_handler.metadata.set(meta)
logger.info(
f"Restore the checkpoint shard with size = {shm_size}"
"from the replica in the memory of the alive node."
)
dist.barrier()

@abstractmethod
def get_saving_ranks(self):
pass
9 changes: 8 additions & 1 deletion dlrover/trainer/torch/flash_checkpoint/full_ckpt_engine.py
@@ -57,13 +57,20 @@
global_shard_num=1,
comm_backend="",
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
replica_count=0,
):
if global_shard_num < local_shard_num:
global_shard_num = local_shard_num
logger.info(f"Set global_shard_num to {local_shard_num}.")
self._local_shard_num = local_shard_num
self._global_shard_num = global_shard_num
super().__init__(checkpoint_dir, storage, comm_backend, save_timeout)
super().__init__(
checkpoint_dir,
storage,
comm_backend,
save_timeout,
replica_count=replica_count,
)

def get_saving_ranks(self):
"""
7 changes: 7 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/megatron.py
@@ -58,6 +58,7 @@ def __init__(
storage=None,
comm_backend="",
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
replica_count=0,
):
self.state_dict = {}
self.paths = {}
@@ -68,6 +69,7 @@ def __init__(
storage=self.storage,
comm_backend=comm_backend,
save_timeout=save_timeout,
replica_count=replica_count,
)

def save(self, state_dict, path: str):
@@ -144,6 +146,7 @@ def save_checkpoint(
storage=None,
comm_backend="",
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
replica_count=0,
):
"""
Synchronously save the checkpointing state dict into the CPU memory.
@@ -154,13 +157,15 @@
use a PosixStorage instance if the storage is not defined.
comm_backend (str): the backend to synchronize when saving the
checkpoint to the memory.
replica_count (int): the number of checkpoint replicas kept on other nodes.
"""
args = get_args()
saver = MegatronCheckpointer.singleton_instance(
args.save,
storage=storage,
comm_backend=comm_backend,
save_timeout=save_timeout,
replica_count=replica_count,
)
sig = inspect.signature(megatron_save)
if storage_type == StorageType.MEMORY:
@@ -219,6 +224,7 @@ def load_checkpoint(
storage=None,
comm_backend="",
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
replica_count=0,
):
"""Load the checkpointing state dict. The method firstly
load the state dict from the CPU memory and then from the storage.
@@ -231,6 +237,7 @@
storge=storage,
comm_backend=comm_backend,
save_timeout=save_timeout,
replica_count=replica_count,
)
torch.load = checkpointer.load
iteration = megatron_load(
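For Megatron-LM, the documented integration pattern is to swap Megatron's native checkpoint helpers for the DLRover wrappers shown above; `replica_count` is the new knob threaded through in this diff. A sketch under that assumption; the positional arguments are Megatron's own training-loop locals, shown only to illustrate where the parameter goes:

```python
# In Megatron-LM's training code, replace
#   from megatron.checkpointing import save_checkpoint, load_checkpoint
# with the DLRover flash-checkpoint wrappers:
from dlrover.trainer.torch.flash_checkpoint.megatron import (
    load_checkpoint,
    save_checkpoint,
)

# Inside the training loop (iteration, model, optimizer, opt_param_scheduler
# are Megatron's locals). replica_count=1 keeps one in-memory replica of each
# checkpoint shard on a peer node so training can resume from memory even if
# this node is replaced.
save_checkpoint(
    iteration,
    model,
    optimizer,
    opt_param_scheduler,
    replica_count=1,
)
```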
9 changes: 8 additions & 1 deletion dlrover/trainer/torch/flash_checkpoint/megatron_engine.py
@@ -40,6 +40,7 @@ def __init__(
storage,
comm_backend="",
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
replica_count=0,
):
if dist.is_initialized():
try:
@@ -60,7 +61,13 @@
self._pp_world_size = 1
self._tp_world_size = 1

super().__init__(checkpoint_dir, storage, comm_backend, save_timeout)
super().__init__(
checkpoint_dir,
storage,
comm_backend,
save_timeout,
replica_count=replica_count,
)

def get_saving_ranks(self):
"""