From ed70d9f61319baa9c0de4dd9988169ef7375e09d Mon Sep 17 00:00:00 2001
From: Qinlong Wang
Date: Fri, 5 Jan 2024 22:05:36 +0800
Subject: [PATCH] Add a simple example of Flash Checkpoint. (#930)

* Fix a typo.

* Add a simple demo of Flash Checkpoint.

* Add a simple example of Flash Checkpoint.
---
 README.md | 2 +
 docs/blogs/flash_checkpoint_cn.md | 2 +-
 examples/pytorch/fcp_demo.py | 118 +++++++++++++++++++++++++++
 examples/pytorch/nanogpt/ds_train.py | 2 +-
 setup.py | 2 +-
 5 files changed, 123 insertions(+), 3 deletions(-)
 create mode 100644 examples/pytorch/fcp_demo.py

diff --git a/README.md b/README.md
index 413be4716..54831bf86 100644
--- a/README.md
+++ b/README.md
@@ -186,6 +186,8 @@ Please refer to the [DEVELOPMENT](docs/developer_guide.md)
 
 ## Quick Start
 
+[An Example of Flash Checkpoint.](examples/pytorch/fcp_demo.py)
+
 [Train a PyTorch Model on Kubernetes.](docs/tutorial/torch_elasticjob_on_k8s.md)
 
 [Train a GPT Model on Kubernetes.](docs/tutorial/torch_ddp_nanogpt.md)
diff --git a/docs/blogs/flash_checkpoint_cn.md b/docs/blogs/flash_checkpoint_cn.md
index 864d2e42d..8125c7b5d 100644
--- a/docs/blogs/flash_checkpoint_cn.md
+++ b/docs/blogs/flash_checkpoint_cn.md
@@ -237,7 +237,7 @@ if args.save and iteration % save_memory_interval == 0:
                     opt_param_scheduler,
                     storage_type=StorageType.MEMORY,)
 ```
-**Note**: Flash Checkpoint's resumable checkpointing and content hot loading require launching the training script with `dlrover-run`. If another launcher such as `torchrun` is used,
+**Note**: Flash Checkpoint's resumable checkpointing and in-memory hot loading require launching the training script with `dlrover-run`. If another launcher such as `torchrun` is used,
 only the asynchronous persistence feature is available. `dlrover-run` is used in the same way as `torchrun`; for example, the following command starts single-node multi-GPU training:
 
 ```bash
diff --git a/examples/pytorch/fcp_demo.py b/examples/pytorch/fcp_demo.py
new file mode 100644
index 000000000..ee4b66cf0
--- /dev/null
+++ b/examples/pytorch/fcp_demo.py
@@ -0,0 +1,118 @@
+# Copyright 2024 The DLRover Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This demo shows how to use Flash Checkpoint in a DDP job.
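+The checkpointer saves the model and optimizer states to memory every 50
+steps and persists them to disk every 200 steps.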
+We can start the DDP job with:
+
+```
+pip install dlrover[torch] -U
+dlrover-run --max_restarts=2 --nproc_per_node=2 fcp_demo.py
+```
+"""
+
+import os
+from datetime import timedelta
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from dlrover.trainer.torch.flash_checkpoint.ddp import (
+    DdpCheckpointer,
+    StorageType,
+)
+
+
+class Net(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(Net, self).__init__()
+        self.fc1 = nn.Linear(input_dim, 2048)
+        self.fc2 = nn.Linear(2048, 1024)
+        self.fc3 = nn.Linear(1024, 512)
+        self.fc4 = nn.Linear(512, 16)
+        self.fc5 = nn.Linear(16, output_dim)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = F.relu(self.fc3(x))
+        x = F.relu(self.fc4(x))
+        return self.fc5(x)
+
+
+if __name__ == "__main__":
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        dist.init_process_group("nccl", timeout=timedelta(seconds=120))
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        dist.init_process_group("gloo", timeout=timedelta(seconds=120))
+    input_dim = 1024
+    batch_size = 2048
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+    x = torch.rand(batch_size, input_dim).to(device)
+    y = torch.rand(batch_size, 1).to(device)
+
+    model = Net(input_dim, 1)
+    if use_cuda:
+        local_rank = int(os.environ["LOCAL_RANK"])
+        print(f"Running basic DDP example on local rank {local_rank}.")
+        model = model.to(local_rank)
+        model = DDP(model, device_ids=[local_rank])
+    else:
+        model = DDP(model)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.5)
+    criterion = nn.MSELoss()
+
+    checkpointer = DdpCheckpointer("/tmp/fcp_demo_ckpt")
+
+    # Load the checkpoint if one exists.
+    state_dict = checkpointer.load_checkpoint()
+    if "model" in state_dict:
+        model.load_state_dict(state_dict["model"])
+    if "optimizer" in state_dict:
+        optimizer.load_state_dict(state_dict["optimizer"])
+
+    # Resume the step from the checkpoint; start from 0 if there is none.
+    step = state_dict.get("step", 0)
+
+    for _ in range(1000):
+        step += 1
+        prediction = model(x)
+        loss = criterion(prediction, y)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        if step % 50 == 0:
+            state_dict = {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "step": step,
+            }
+            # Save checkpoint to memory.
+            checkpointer.save_checkpoint(
+                step, state_dict, storage_type=StorageType.MEMORY
+            )
+            print("step {} loss: {:.3f}".format(step, loss.item()))
+        if step % 200 == 0:
+            state_dict = {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "step": step,
+            }
+            # Save checkpoint to storage.
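+            # With StorageType.DISK, the checkpoint is persisted to the
+            # directory passed to DdpCheckpointer ("/tmp/fcp_demo_ckpt");
+            # the persistence is asynchronous, so it does not block training.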
+            checkpointer.save_checkpoint(
+                step, state_dict, storage_type=StorageType.DISK
+            )
diff --git a/examples/pytorch/nanogpt/ds_train.py b/examples/pytorch/nanogpt/ds_train.py
index e54d6e4fe..a379bb051 100644
--- a/examples/pytorch/nanogpt/ds_train.py
+++ b/examples/pytorch/nanogpt/ds_train.py
@@ -15,7 +15,7 @@
 """
 The start command on a local node:
-dlrover-run --nnodes=1 --max_restarts=2 --nproc_per_node=2 \
+dlrover-run --max_restarts=2 --nproc_per_node=2 \
 ds_train.py --n_layer 36 --n_head 20 --n_embd 1280 \
 --data_dir './' --ds_config ./ds_config.json \
 --epochs 50 --save_memory_interval 50 --save_storage_interval 500
diff --git a/setup.py b/setup.py
index 9cd1c10d1..88c2759ba 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@
     url="https://github.com/intelligent-machine-learning/dlrover",
     install_requires=install_requires,
     extras_require=extra_require,
-    python_requires=">=3.8",
+    python_requires=">=3.6",
     packages=find_packages(),
     package_data={
         "": [