Skip to content

Commit

Permalink
Simplify mnist example. (intelligent-machine-learning#929)
Browse files Browse the repository at this point in the history
* Update setup.py

* Simplify mnist example

* Format code.

* Address review comments.

* Format code.

* Address review comments.
  • Loading branch information
youxingling committed Jan 5, 2024
1 parent 2acfe82 commit 563e7d4
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 34 deletions.
12 changes: 0 additions & 12 deletions examples/__init__.py

This file was deleted.

11 changes: 10 additions & 1 deletion examples/pytorch/mnist/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ with MNIST dataset.

## Prepare Data

You can let the training script create the dataset with `torchvision.datasets.MNIST()`, which downloads it automatically, or you can download it manually as follows:

- Download the dataset from [Kaggle MNIST Dataset](https://www.kaggle.com/datasets/hojjatk/mnist-dataset).
- Untar the dataset into a directory like `data/mnist_png`.

Expand All @@ -28,7 +30,14 @@ pip install dlrover -U
Then, we can use `dlrover-run` to start the training:

```bash
dlrover-run --standalone --nproc_per_node=${GPU_NUM} \
dlrover-run --nproc_per_node=${GPU_NUM} \
examples/pytorch/mnist/cnn_train.py --num_epochs 5
```

or, to train on the manually downloaded dataset:

```bash
dlrover-run --nproc_per_node=${GPU_NUM} \
examples/pytorch/mnist/cnn_train.py --num_epochs 5 \
--training_data data/mnist_png/training/ \
--validation_data data/mnist_png/testing/
Expand Down
44 changes: 29 additions & 15 deletions examples/pytorch/mnist/cnn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.distributed.elastic.multiprocessing.errors import record
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets, transforms

from dlrover.trainer.torch.elastic.dataloader import ElasticDataLoader
from dlrover.trainer.torch.elastic.sampler import ElasticDistributedSampler
Expand All @@ -47,7 +46,7 @@ def log_rank0(msg):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, 1)
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
Expand Down Expand Up @@ -98,22 +97,37 @@ def train(args):
"""
setup()

train_data = torchvision.datasets.ImageFolder(
root=args.training_data,
transform=transforms.ToTensor(),
)
if args.training_data:
train_dataset = datasets.ImageFolder(
root=args.training_data, transform=transforms.ToTensor()
)
else:
train_dataset = datasets.MNIST(
"./data",
train=True,
transform=transforms.ToTensor(),
download=True,
)

# Setup sampler for elastic training.
sampler = ElasticDistributedSampler(dataset=train_data)
sampler = ElasticDistributedSampler(dataset=train_dataset)
train_loader = ElasticDataLoader(
dataset=train_data,
dataset=train_dataset,
batch_size=args.batch_size,
sampler=sampler,
)

test_dataset = torchvision.datasets.ImageFolder(
root=args.validation_data,
transform=torchvision.transforms.ToTensor(),
)
if args.validation_data:
test_dataset = datasets.ImageFolder(
root=args.validation_data, transform=transforms.ToTensor()
)
else:
test_dataset = datasets.MNIST(
"./data",
train=False,
transform=transforms.ToTensor(),
download=True,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size)

model = Net()
Expand Down Expand Up @@ -279,9 +293,9 @@ def arg_parser():
default=False,
help="disable CUDA training",
)
parser.add_argument("--training_data", type=str, required=True)
parser.add_argument("--training_data", type=str, required=False)
parser.add_argument(
"--validation_data", type=str, default="", required=True
"--validation_data", type=str, default="", required=False
)
parser.add_argument("--save_model", action="store_true", required=False)
return parser
Expand Down
1 change: 1 addition & 0 deletions scripts/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ make -f dlrover/Makefile
# Create dlrover package
echo "Building the wheel for dlrover."
rm -rf ./build/lib

python setup.py --quiet bdist_wheel
8 changes: 2 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@
url="https://github.com/intelligent-machine-learning/dlrover",
install_requires=install_requires,
extras_require=extra_require,
python_requires=">=3.5",
packages=find_packages(
exclude=[
"model_zoo*",
]
),
python_requires=">=3.8",
packages=find_packages(),
package_data={
"": [
"proto/*",
Expand Down

0 comments on commit 563e7d4

Please sign in to comment.