Add a simple example of Flash Checkpoint. (intelligent-machine-learning#930)

* Fix a typo.

* Add a simple demo of Flash Checkpoint.

* Add a simple example of Flash Checkpoint.
workingloong committed Jan 5, 2024
1 parent 563e7d4 commit ed70d9f
Showing 5 changed files with 123 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -186,6 +186,8 @@ Please refer to the [DEVELOPMENT](docs/developer_guide.md)

## Quick Start

+ [An Example of Flash Checkpoint.](examples/pytorch/fcp_demo.py)
+
[Train a PyTorch Model on Kubernetes.](docs/tutorial/torch_elasticjob_on_k8s.md)

[Train a GPT Model on Kubernetes.](docs/tutorial/torch_ddp_nanogpt.md)
2 changes: 1 addition & 1 deletion docs/blogs/flash_checkpoint_cn.md
@@ -237,7 +237,7 @@ if args.save and iteration % save_memory_interval == 0:
opt_param_scheduler, storage_type=StorageType.MEMORY,)
```

- **Note**: Flash Checkpoint's resumable checkpointing and content hot loading require launching the training script with `dlrover-run`. If the job is launched another way, e.g. with `torchrun`,
+ **Note**: Flash Checkpoint's resumable checkpointing and in-memory hot loading require launching the training script with `dlrover-run`. If the job is launched another way, e.g. with `torchrun`,
only the asynchronous persistence feature is available. `dlrover-run` is used the same way as `torchrun`; for example, to launch single-node multi-GPU training:

```bash
# Example launch command, following the form shown in the fcp_demo.py
# docstring added in this commit:
dlrover-run --max_restarts=2 --nproc_per_node=2 fcp_demo.py
```
118 changes: 118 additions & 0 deletions examples/pytorch/fcp_demo.py
@@ -0,0 +1,118 @@
# Copyright 2024 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This demo shows how to use Flash Checkpoint in a DDP job.
Start the job with:
```
pip install dlrover[torch] -U
dlrover-run --max_restarts=2 --nproc_per_node=2 fcp_demo.py
```
"""

import os
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

from dlrover.trainer.torch.flash_checkpoint.ddp import (
    DdpCheckpointer,
    StorageType,
)


class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_dim, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 16)
        self.fc5 = nn.Linear(16, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        return self.fc5(x)


if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        dist.init_process_group("nccl", timeout=timedelta(seconds=120))
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    else:
        dist.init_process_group("gloo", timeout=timedelta(seconds=120))
    input_dim = 1024
    batch_size = 2048

    device = torch.device("cuda" if use_cuda else "cpu")
    x = torch.rand(batch_size, input_dim).to(device)
    y = torch.rand(batch_size, 1).to(device)

    model = Net(input_dim, 1)
    if use_cuda:
        local_rank = int(os.environ["LOCAL_RANK"])
        print(f"Running basic DDP example on local rank {local_rank}.")
        model = model.to(local_rank)
        model = DDP(model, device_ids=[local_rank])
    else:
        model = DDP(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.5)
    criteria = nn.MSELoss()

    checkpointer = DdpCheckpointer("/tmp/fcp_demo_ckpt")

    # Load the latest checkpoint if one exists, so a restarted job
    # resumes from the last saved step instead of starting over.
    state_dict = checkpointer.load_checkpoint()
    if "model" in state_dict:
        model.load_state_dict(state_dict["model"])
    if "optimizer" in state_dict:
        optimizer.load_state_dict(state_dict["optimizer"])

    step = state_dict.get("step", 0)

    for _ in range(1000):
        step += 1
        prediction = model(x)
        loss = criteria(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            # Include the step so a resume from a memory checkpoint
            # restores the step counter as well.
            state_dict = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            # Save the checkpoint to memory: fast, and reloadable after a
            # process restart when the job is launched with dlrover-run.
            checkpointer.save_checkpoint(
                step, state_dict, storage_type=StorageType.MEMORY
            )
            print("step {} loss:{:.3f}".format(step, loss))
        if step % 200 == 0:
            state_dict = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            # Save the checkpoint to disk: slower, but it persists across
            # node failures.
            checkpointer.save_checkpoint(
                step, state_dict, storage_type=StorageType.DISK
            )
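The two save cadences reflect the intended division of labor in Flash Checkpoint: the in-memory saves every 50 steps are cheap enough to run frequently, while the disk saves every 200 steps are the slower, durable copies that survive a full node failure.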
2 changes: 1 addition & 1 deletion examples/pytorch/nanogpt/ds_train.py
@@ -15,7 +15,7 @@
"""
The start command on a local node:
- dlrover-run --nnodes=1 --max_restarts=2 --nproc_per_node=2 \
+ dlrover-run --max_restarts=2 --nproc_per_node=2 \
ds_train.py --n_layer 36 --n_head 20 --n_embd 1280 \
--data_dir './' --ds_config ./ds_config.json \
--epochs 50 --save_memory_interval 50 --save_storage_interval 500
2 changes: 1 addition & 1 deletion setup.py
@@ -45,7 +45,7 @@
url="https://github.com/intelligent-machine-learning/dlrover",
install_requires=install_requires,
extras_require=extra_require,
python_requires=">=3.8",
python_requires=">=3.6",
packages=find_packages(),
package_data={
"": [
