Skip to content

Commit

Permalink
Refactor arguments of ElasticDataset. (#401)
Browse files Browse the repository at this point in the history
* Refactor arguments of ElasticDataset

* Save dataset checkpoint with model checkpoint into a file

* Remove unused checkpoint_path

* Format codes

* Set the default value to the number of mini-batch per shard

* Polish the docstring

* Add an example in the doc string

* Polish the docstring

* Add docstring to build data meta

* Simplify docstring

* Format codes

* Checkout whether shards is empty
  • Loading branch information
workingloong committed May 11, 2023
1 parent e2027b3 commit 6149b50
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 74 deletions.
4 changes: 2 additions & 2 deletions dlrover/examples/torch_mnist_job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ spec:
--max_restarts=3 \
--rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT:2222 \
model_zoo/pytorch/mnist_cnn.py \
--training_data /data/mnist_png/training/elastic_ds.txt \
--validation_data /data/mnist_png/testing/elastic_ds.txt"
--training_data /data/mnist_png/training/ \
--validation_data /data/mnist_png/testing/ "
dlrover-master:
template:
spec:
Expand Down
4 changes: 2 additions & 2 deletions dlrover/examples/torch_mnist_master_backend_job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ spec:
&& torchrun --nnodes=1:$WORKER_NUM --nproc_per_node=1
--max_restarts=3 --rdzv_backend=dlrover-master \
model_zoo/pytorch/mnist_cnn.py \
--training_data /data/mnist_png/training/elastic_ds.txt \
--validation_data /data/mnist_png/testing/elastic_ds.txt"
--training_data /data/mnist_png/training/ \
--validation_data /data/mnist_png/testing/"
resources:
limits:
cpu: "1"
Expand Down
86 changes: 48 additions & 38 deletions dlrover/trainer/torch/elastic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from abc import ABCMeta, abstractmethod
from typing import Dict

import torch.distributed as dist
from torch.utils.data import Dataset
Expand All @@ -27,41 +28,51 @@ def get_rank():
return rank


def read_txt(path):
with open(path, "r") as fp:
content = fp.readlines()
return content


class ElasticDataset(Dataset, metaclass=ABCMeta):
def __init__(self, path, batch_size, epochs, shuffle, checkpoint_path=""):
"""Using ElasticDataset, the node can read samples without
duplicates with other nodes in an epoch. DLRover master
will dispatch the index of sample in a dataset to one node.
Args:
path: str, the path of dataset meta file. For example, if the image
is stored in a folder. The meta file should be a
text file where each line is the absolute path of a image.
batch_size: int, the size of batch samples to compute gradients
in a trainer process.
epochs: int, the number of epoch.
shuffle: bool, whether to shuffle samples in the dataset.
checkpoint_path: the path to save the checkpoint of shards
int the dataset.
"""
self.lines = read_txt(path)
self.dataset_size = len(self.lines)
self._checkpoint_path = checkpoint_path
"""Using ElasticDataset, the node can read samples without
duplicates with other nodes in an epoch. DLRover master
will dispatch the index of sample in a dataset to one node.
Users need to implement the read_sample to read data by the
sample index.
Example:
>>> dataset = ElasticDataset(1000, 32, 2, True)
>>> state = dataset.state_dict() # checkpoint
>>> dataset.load_state_dict(state)
>>> data_loader = DataLoader(
>>> dataset=dataset, batch_size=args.batch_size, num_workers=2,
>>> )
Args:
dataset_size: the number of samples in the dataset.
batch_size: int, the size of batch samples to compute gradients
in a trainer process.
epochs: int, the number of epoch.
shuffle: bool, whether to shuffle samples in the dataset.
name: str, the name of dataset.
"""

def __init__(
self,
dataset_size,
batch_size,
epochs,
shuffle,
name=None,
num_minibatches_per_shard=2,
):
self.dataset_size = dataset_size
if not name:
name = "dlrover-ds-" + str(time.time())
self._shard_client = IndexShardingClient(
dataset_name=path,
dataset_name=name,
batch_size=batch_size,
num_epochs=epochs,
dataset_size=self.dataset_size,
shuffle=shuffle,
storage_type="text",
num_minibatches_per_shard=num_minibatches_per_shard,
)
self.load_checkpoint()

def __len__(self):
return self._shard_client.get_total_sample_num()
Expand All @@ -78,29 +89,28 @@ def report_batch_done(self, batch_size=None):
report the batch completion."""
self._shard_client.report_batch_done(batch_size)

def save_checkpoint(self):
def state_dict(self):
"""
Checkpoint the shards which are not completed from the
DLRover job master.
"""
rank = get_rank()
if rank != 0 or not self._checkpoint_path:
if rank != 0:
return
checkpoint = self._shard_client.get_shard_checkpoint()
with open(self._checkpoint_path, "w") as f:
f.write(checkpoint)
shards = self._shard_client.get_shard_checkpoint()
return {"shards": shards}

def load_checkpoint(self):
def load_state_dict(self, state: Dict[str, str]):
"""
Restore the uncompleted shards from a checkpoint. The shard
client will send uncompleted shards to the DLRover job master.
The master will assign those shards to workers to restore training.
"""
rank = get_rank()
if rank == 0 and os.path.exists(self._checkpoint_path):
with open(self._checkpoint_path, "r") as f:
content = f.read()
self._shard_client.restore_shard_from_checkpoint(content)
if rank == 0 and state:
shards = state.get("shards", "")
if shards:
self._shard_client.restore_shard_from_checkpoint(shards)
dist.barrier() # Wait rank-0 to report checkpoint.
self._shard_client.set_max_shard_count()

Expand Down
88 changes: 56 additions & 32 deletions model_zoo/pytorch/mnist_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,52 @@
from dlrover.trainer.torch.elastic import ElasticTrainer
from dlrover.trainer.torch.elastic_dataset import ElasticDataset

CHEKPOINT_PATH = "model.pt"


def build_data_meta(folder):
"""Save the path of sample into a list and we can get the path
by the index of sample.
The directory structure of mnist is
|- root
|- 0
|- 1.png
|- 21.png
|- 1
|- 3.png
|- 6.png
the meta is a list [
("root/0/1.png", 0),
("root/0/21.png", 0),
("root/3.png", 1),
("root/1/6.png", 1),
]
"""
dataset_meta = []
for d in os.listdir(folder):
dir_path = os.path.join(folder, d)
if os.path.isdir(dir_path):
for f in os.listdir(dir_path):
if f.endswith(".png"):
file_path = os.path.join(dir_path, f)
dataset_meta.append([file_path, d])
return dataset_meta


class ElasticMnistDataset(ElasticDataset):
def __init__(self, path, batch_size, epochs, shuffle, checkpoint_path):
"""The dataset supports elastic training.
Args:
path: str, the path of dataset meta file. For example, if the image
is stored in a folder. The meta file should be a
text file where each line is the absolute path of a image.
batch_size: int, the size of batch samples to compute gradients
in a trainer process.
epochs: int, the number of epoch.
shuffle: bool, whether to shuffle samples in the dataset.
checkpoint_path: the path to save the checkpoint of shards
int the dataset.
"""
def __init__(self, path, batch_size, epochs, shuffle):
"""The dataset supports elastic training."""
self.data_meta = build_data_meta(path)
super(ElasticMnistDataset, self).__init__(
path,
len(self.data_meta),
batch_size,
epochs,
shuffle,
checkpoint_path,
)

def read_sample(self, index):
image_meta = self.lines[index]
image_path, label = image_meta.split(",")
image_path, label = self.data_meta[index]
image = cv2.imread(image_path)
image = np.array(image / 255.0, np.float32)
image = image.reshape(3, 28, 28)
Expand Down Expand Up @@ -116,14 +135,16 @@ def train(args):
use_cuda = not args.no_cuda and torch.cuda.is_available()
setup(use_cuda)
device = torch.device("cuda" if use_cuda else "cpu")
checkpoint = load_checkpoint(CHEKPOINT_PATH)

train_dataset = ElasticMnistDataset(
path=args.training_data,
batch_size=args.batch_size,
epochs=args.num_epochs,
shuffle=args.shuffle,
checkpoint_path="./train_dataset.ckpt",
)
if checkpoint:
train_dataset.load_state_dict(checkpoint.get("train_shards", {}))
train_loader = DataLoader(
dataset=train_dataset, batch_size=args.batch_size, num_workers=2
)
Expand All @@ -133,7 +154,6 @@ def train(args):
batch_size=args.batch_size,
epochs=1,
shuffle=False,
checkpoint_path="./test_dataset.ckpt",
)
test_loader = DataLoader(
dataset=test_dataset, batch_size=args.batch_size, num_workers=2
Expand All @@ -155,7 +175,9 @@ def train(args):

elastic_trainer = ElasticTrainer(model, train_loader)
optimizer, scheduler = elastic_trainer.prepare(optimizer, scheduler)
load_checkpoint(model, optimizer)
if checkpoint:
model.load_state_dict(checkpoint.get("model_state_dict", {}))
optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))

for _, (data, target) in enumerate(train_loader):
model.train()
Expand All @@ -175,29 +197,31 @@ def train(args):
elastic_trainer.num_steps > 0
and elastic_trainer.num_steps % 200 == 0
):
save_checkpoint(model, optimizer, train_dataset)
save_checkpoint(
CHEKPOINT_PATH, model, optimizer, train_dataset
)
if (
elastic_trainer.num_steps > 0
and elastic_trainer.num_steps % 10000 == 0
):
test(model, device, test_loader)


def save_checkpoint(model, optimizer, dataset: ElasticDataset):
model_checkpoint = {
def save_checkpoint(path, model, optimizer, dataset: ElasticDataset):
print("Save checkpoint.")
checkpoint = {
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"train_shards": dataset.state_dict(),
}
torch.save(model_checkpoint, "model.pt")
dataset.save_checkpoint()
torch.save(checkpoint, path)


def load_checkpoint(model, optimizer):
if not os.path.exists("model.pt"):
return
checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
def load_checkpoint(path):
if not os.path.exists(path):
return {}
checkpoint = torch.load(path)
return checkpoint


def test(model, device, test_loader):
Expand Down

0 comments on commit 6149b50

Please sign in to comment.