Refactor arguments of ElasticDataset. #401

Merged
Save dataset checkpoint with model checkpoint into a file
workingloong committed May 10, 2023
commit 0c51b1a5f2cd5778cc2c3588b7fe102fa22c332c
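In short, this commit replaces ElasticDataset.save_checkpoint() / load_checkpoint(), which wrote the uncompleted shards to a separate file at checkpoint_path, with a state_dict() / load_state_dict() pair, so the shard state can be stored in the same file as the model checkpoint. Below is a rough sketch of the save side for orientation only; save_all is a made-up helper name and the real changes are in the diffs that follow.

import torch


def save_all(path, model, optimizer, dataset):
    # Before this commit the shard state went to a second file:
    #     torch.save({"model_state_dict": ...}, path)
    #     dataset.save_checkpoint()  # wrote shards to its own checkpoint_path
    # After this commit a single file holds model, optimizer and shard state.
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            # {"shards": ...} on rank 0, None on the other ranks.
            "train_shards": dataset.state_dict(),
        },
        path,
    )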
25 changes: 7 additions & 18 deletions dlrover/trainer/torch/elastic_dataset.py
@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
 import time
 from abc import ABCMeta, abstractmethod

@@ -28,12 +27,6 @@ def get_rank():
     return rank


-def read_txt(path):
-    with open(path, "r") as fp:
-        content = fp.readlines()
-    return content
-
-
 class ElasticDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
@@ -70,7 +63,6 @@ def __init__(
             shuffle=shuffle,
             storage_type="text",
         )
-        self.load_checkpoint()

     def __len__(self):
         return self._shard_client.get_total_sample_num()
@@ -87,29 +79,26 @@ def report_batch_done(self, batch_size=None):
         report the batch completion."""
         self._shard_client.report_batch_done(batch_size)

-    def save_checkpoint(self):
+    def state_dict(self):
         """
         Checkpoint the shards which are not completed from the
         DLRover job master.
         """
         rank = get_rank()
-        if rank != 0 or not self._checkpoint_path:
+        if rank != 0:
             return
-        checkpoint = self._shard_client.get_shard_checkpoint()
-        with open(self._checkpoint_path, "w") as f:
-            f.write(checkpoint)
+        shards = self._shard_client.get_shard_checkpoint()
+        return {"shards": shards}

-    def load_checkpoint(self):
+    def load_state_dict(self, state):
         """
         Restore the uncompleted shards from a checkpoint. The shard
         client will send uncompleted shards to the DLRover job master.
         The master will assign those shards to workers to restore training.
         """
         rank = get_rank()
-        if rank == 0 and os.path.exists(self._checkpoint_path):
-            with open(self._checkpoint_path, "r") as f:
-                content = f.read()
-                self._shard_client.restore_shard_from_checkpoint(content)
+        if rank == 0:
+            self._shard_client.restore_shard_from_checkpoint(state["shards"])
         dist.barrier()  # Wait rank-0 to report checkpoint.
         self._shard_client.set_max_shard_count()

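The restore side, as a minimal sketch (not part of this commit; restore_all is a made-up helper): every rank loads the file and calls load_state_dict(), rank 0 forwards the uncompleted shards to the DLRover job master, and the other ranks simply wait on the dist.barrier() inside the method. The mnist_cnn.py changes below do the same thing inline in train().

import os

import torch


def restore_all(path, model, optimizer, dataset):
    # Nothing to restore on a fresh start.
    if not os.path.exists(path):
        return
    state = torch.load(path)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    # Must run on every rank: rank 0 reports the shards to the job master,
    # the rest block on the barrier inside ElasticDataset.load_state_dict().
    dataset.load_state_dict(state["train_shards"])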
34 changes: 22 additions & 12 deletions model_zoo/pytorch/mnist_cnn.py
@@ -29,6 +29,9 @@
 from dlrover.trainer.torch.elastic_dataset import ElasticDataset


+CHEKPOINT_PATH = "model.pt"
+
+
 def build_data_meta(folder):
     dataset_meta = []
     for d in os.listdir(folder):
@@ -126,6 +129,7 @@ def train(args):
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     setup(use_cuda)
     device = torch.device("cuda" if use_cuda else "cpu")
+    checkpoint = load_checkpoint(CHEKPOINT_PATH)

     train_dataset = ElasticMnistDataset(
         path=args.training_data,
@@ -134,6 +138,8 @@
         shuffle=args.shuffle,
         checkpoint_path="./train_dataset.ckpt",
     )
+    if checkpoint:
+        train_dataset.load_state_dict(checkpoint["train_shards"])
     train_loader = DataLoader(
         dataset=train_dataset, batch_size=args.batch_size, num_workers=2
     )
@@ -165,7 +171,9 @@ def train(args):

     elastic_trainer = ElasticTrainer(model, train_loader)
     optimizer, scheduler = elastic_trainer.prepare(optimizer, scheduler)
-    load_checkpoint(model, optimizer)
+    if checkpoint:
+        model.load_state_dict(checkpoint["model_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

     for _, (data, target) in enumerate(train_loader):
         model.train()
@@ -185,29 +193,31 @@ def train(args):
                 elastic_trainer.num_steps > 0
                 and elastic_trainer.num_steps % 200 == 0
             ):
-                save_checkpoint(model, optimizer, train_dataset)
+                save_checkpoint(
+                    CHEKPOINT_PATH, model, optimizer, train_dataset
+                )
             if (
                 elastic_trainer.num_steps > 0
                 and elastic_trainer.num_steps % 10000 == 0
             ):
                 test(model, device, test_loader)


-def save_checkpoint(model, optimizer, dataset: ElasticDataset):
-    model_checkpoint = {
+def save_checkpoint(path, model, optimizer, dataset: ElasticDataset):
+    print("Save checkpoint.")
+    checkpoint = {
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict(),
+        "train_shards": dataset.state_dict(),
     }
-    torch.save(model_checkpoint, "model.pt")
-    dataset.save_checkpoint()
+    torch.save(checkpoint, path)


-def load_checkpoint(model, optimizer):
-    if not os.path.exists("model.pt"):
-        return
-    checkpoint = torch.load("model.pt")
-    model.load_state_dict(checkpoint["model_state_dict"])
-    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+def load_checkpoint(path):
+    if not os.path.exists(path):
+        return {}
+    checkpoint = torch.load(path)
+    return checkpoint


 def test(model, device, test_loader):
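One caveat when adapting the example: ElasticDataset.state_dict() returns None on every rank except 0, so if each worker ran save_checkpoint (the surrounding loop is not fully shown in this diff), non-zero ranks would write a file whose "train_shards" entry is None. A possible guard, reusing save_checkpoint from mnist_cnn.py above and assuming the process group initialized in setup(); this helper is not part of the commit.

import torch.distributed as dist


def maybe_save_checkpoint(path, model, optimizer, dataset):
    # Only rank 0 receives the shard checkpoint from the job master,
    # so only rank 0 persists the combined file.
    if dist.get_rank() == 0:
        save_checkpoint(path, model, optimizer, dataset)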