Skip to content

Commit

Permalink
Refactor arguments of ElasticDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
workingloong committed May 10, 2023
1 parent 7a166b3 commit de0b5a8
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 17 deletions.
4 changes: 2 additions & 2 deletions dlrover/examples/torch_mnist_job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ spec:
--max_restarts=3 \
--rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT:2222 \
model_zoo/pytorch/mnist_cnn.py \
--training_data /data/mnist_png/training/elastic_ds.txt \
--validation_data /data/mnist_png/testing/elastic_ds.txt"
--training_data /data/mnist_png/training/ \
--validation_data /data/mnist_png/testing/ "
dlrover-master:
template:
spec:
Expand Down
4 changes: 2 additions & 2 deletions dlrover/examples/torch_mnist_master_backend_job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ spec:
&& torchrun --nnodes=1:$WORKER_NUM --nproc_per_node=1
--max_restarts=3 --rdzv_backend=dlrover-master \
model_zoo/pytorch/mnist_cnn.py \
--training_data /data/mnist_png/training/elastic_ds.txt \
--validation_data /data/mnist_png/testing/elastic_ds.txt"
--training_data /data/mnist_png/training/ \
--validation_data /data/mnist_png/testing/"
resources:
limits:
cpu: "1"
Expand Down
23 changes: 16 additions & 7 deletions dlrover/trainer/torch/elastic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.

import os
import time
from abc import ABCMeta, abstractmethod

import torch.distributed as dist
Expand All @@ -34,27 +35,35 @@ def read_txt(path):


class ElasticDataset(Dataset, metaclass=ABCMeta):
def __init__(self, path, batch_size, epochs, shuffle, checkpoint_path=""):
def __init__(
self,
dataset_size,
batch_size,
epochs,
shuffle,
checkpoint_path="",
name=None,
):
"""Using ElasticDataset, the node can read samples without
duplicates with other nodes in an epoch. DLRover master
will dispatch the index of sample in a dataset to one node.
Args:
path: str, the path of dataset meta file. For example, if the image
is stored in a folder. The meta file should be a
text file where each line is the absolute path of a image.
dataset_size: the number of samples in the dataset.
batch_size: int, the size of batch samples to compute gradients
in a trainer process.
epochs: int, the number of epoch.
shuffle: bool, whether to shuffle samples in the dataset.
checkpoint_path: the path to save the checkpoint of shards
in the dataset.
name: str, the name of dataset.
"""
self.lines = read_txt(path)
self.dataset_size = len(self.lines)
self.dataset_size = dataset_size
self._checkpoint_path = checkpoint_path
if not name:
name = "dlrover-ds-" + str(time.time())
self._shard_client = IndexShardingClient(
dataset_name=path,
dataset_name=name,
batch_size=batch_size,
num_epochs=epochs,
dataset_size=self.dataset_size,
Expand Down
22 changes: 16 additions & 6 deletions model_zoo/pytorch/mnist_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,42 @@
from dlrover.trainer.torch.elastic_dataset import ElasticDataset


def build_data_meta(folder):
    """Collect ``[image_path, label]`` pairs for the PNG images in a folder.

    Every immediate sub-directory of ``folder`` is treated as one class:
    its name becomes the label of the ``.png`` files located directly
    inside it. Files placed directly in ``folder`` and non-PNG files are
    ignored.

    Both directory listings are sorted. ``os.listdir`` returns entries in
    arbitrary order, and samples are later looked up by their position in
    the returned list, so the position of a sample must be the same in
    every process that scans the same folder.

    Args:
        folder: str, the root directory that contains one sub-directory
            per label.

    Returns:
        A list of ``[file_path, label]`` lists ordered by label and then
        by file name.
    """
    dataset_meta = []
    for label in sorted(os.listdir(folder)):
        dir_path = os.path.join(folder, label)
        if not os.path.isdir(dir_path):
            continue
        for file_name in sorted(os.listdir(dir_path)):
            if file_name.endswith(".png"):
                file_path = os.path.join(dir_path, file_name)
                dataset_meta.append([file_path, label])
    return dataset_meta


class ElasticMnistDataset(ElasticDataset):
def __init__(self, path, batch_size, epochs, shuffle, checkpoint_path):
"""The dataset supports elastic training.
Args:
path: str, the path of dataset meta file. For example, if the image
is stored in a folder. The meta file should be a
text file where each line is the absolute path of a image.
path: str, the root folder of the dataset. Each sub-directory is
named after a label and holds the PNG images of that label.
batch_size: int, the size of batch samples to compute gradients
in a trainer process.
epochs: int, the number of epoch.
shuffle: bool, whether to shuffle samples in the dataset.
checkpoint_path: the path to save the checkpoint of shards
in the dataset.
"""
self.data_meta = build_data_meta(path)
super(ElasticMnistDataset, self).__init__(
path,
len(self.data_meta),
batch_size,
epochs,
shuffle,
checkpoint_path,
)

def read_sample(self, index):
image_meta = self.lines[index]
image_path, label = image_meta.split(",")
image_path, label = self.data_meta[index]
image = cv2.imread(image_path)
image = np.array(image / 255.0, np.float32)
image = image.reshape(3, 28, 28)
Expand Down

0 comments on commit de0b5a8

Please sign in to comment.