[docs][train] Make Train example titles and headings more consistent #39606

Merged 18 commits on Sep 14, 2023
Changes from 1 commit
copy editing lightning-mnist and torch-fashion-mnist examples
Signed-off-by: angelinalg <[email protected]>
angelinalg committed Sep 7, 2023
commit a4ad14f84a4aa21d2be09413d71c30b8955bba32
@@ -9,7 +9,7 @@
     "\n",
     "# Train a Pytorch Lightning Image Classifier\n",
     "\n",
-    "This example introduces how to train a Pytorch Lightning Module using Ray Train {class}`TorchTrainer <ray.train.torch.TorchTrainer>`. We will demonstrate how to train a basic neural network on the MNIST dataset with distributed data parallelism.\n"
+    "This example introduces how to train a Pytorch Lightning Module using Ray Train {class}`TorchTrainer <ray.train.torch.TorchTrainer>`. It demonstrates how to train a basic neural network on the MNIST dataset with distributed data parallelism.\n"
    ]
   },
   {
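For context, the high-level pattern this notebook builds toward is roughly the following. This is a minimal sketch, assuming Ray 2.7 or later with ray[train], torch, and pytorch-lightning installed; it is not the notebook's exact code.

# Minimal sketch of the high-level pattern; the per-worker training
# function is developed over the rest of the example.
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_func(config):
    ...  # per-worker Lightning training logic, defined later


trainer = TorchTrainer(
    train_func,
    train_loop_config={"lr": 1e-3, "max_epochs": 10, "batch_size": 128},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()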
@@ -49,9 +49,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Prepare Dataset and Module\n",
+    "## Prepare a dataset and module\n",
     "\n",
-    "The Pytorch Lightning Trainer takes either `torch.utils.data.DataLoader` or `pl.LightningDataModule` as data inputs. You can keep using them without any changes for the Ray AIR LightningTrainer. "
+    "The Pytorch Lightning Trainer takes either `torch.utils.data.DataLoader` or `pl.LightningDataModule` as data inputs. You can continue using them without any changes for the Ray Train LightningTrainer. "
    ]
   },
   {
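A data setup that works unchanged under Ray Train might look like the following minimal sketch, assuming torchvision is available; the class name MNISTDataModule is illustrative, not taken from this PR.

# A plain LightningDataModule; it needs no changes to run under Ray Train.
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 128):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Split the 60k MNIST training images into train/val subsets
        mnist = MNIST(
            "./data", train=True, download=True, transform=transforms.ToTensor()
        )
        self.train_set, self.val_set = random_split(mnist, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)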
@@ -183,15 +183,15 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Define the Training Loop\n",
+    "## Define the training loop\n",
     "\n",
-    "Here we define a training loop for each worker. Compare with the original PyTorch Lightning code, there are 3 main differences:\n",
+    "This code defines a training loop for each worker. Comparing the training loop with the original PyTorch Lightning code, there are 3 main differences:\n",
     "\n",
     "- Distributed strategy: Use {class}`RayDDPStrategy <ray.train.lightning.RayDDPStrategy>`.\n",
     "- Cluster environment: Use {class}`RayLightningEnvironment <ray.train.lightning.RayLightningEnvironment>`.\n",
     "- Parallel devices: Always sets to `devices=\"auto\"` to use all available devices configured by ``TorchTrainer``.\n",
     "\n",
-    "Please refer to {ref}`Getting Started with PyTorch Lightning <train-pytorch-lightning>`.\n",
+    "See {ref}`Getting Started with PyTorch Lightning <train-pytorch-lightning>` for more information.\n",
     "\n",
     "\n",
     "For checkpoint reportining, Ray Train provides a minimal {class}`RayTrainReportCallback <ray.train.lightning.RayTrainReportCallback>` that reports metrics and checkpoint on each train epoch end. For more complex checkpoint logic, please implement custom callbacks as described in {ref}`Saving and Loading Checkpoint <train-checkpointing>` user guide."
14 changes: 7 additions & 7 deletions python/ray/train/examples/pytorch/torch_fashion_mnist_example.py
@@ -19,23 +19,23 @@ def get_dataloaders(batch_size):
     transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])
 
     with FileLock(os.path.expanduser("~/data.lock")):
-        # Download training data from open datasets.
+        # Download training data from open datasets
         training_data = datasets.FashionMNIST(
             root="~/data",
             train=True,
             download=True,
             transform=transform,
         )
 
-        # Download test data from open datasets.
+        # Download test data from open datasets
         test_data = datasets.FashionMNIST(
             root="~/data",
             train=False,
             download=True,
             transform=transform,
         )
 
-    # Create data loaders.
+    # Create data loaders
     train_dataloader = DataLoader(training_data, batch_size=batch_size)
     test_dataloader = DataLoader(test_data, batch_size=batch_size)
 
@@ -69,7 +69,7 @@ def train_func_per_worker(config: Dict):
     epochs = config["epochs"]
     batch_size = config["batch_size_per_worker"]
 
-    # Get dataloaders inside worker training function
+    # Get dataloaders inside the worker training function
     train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)
 
     # [1] Prepare Dataloader for distributed training
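The [1] step continues in lines collapsed from this diff. A sketch of what that preparation typically looks like with ray.train.torch.prepare_data_loader, which adds a DistributedSampler and moves batches to the device Ray Train assigns to the worker:

import ray.train.torch

train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)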
@@ -81,7 +81,7 @@ def train_func_per_worker(config: Dict):
     model = NeuralNetwork()
 
     # [2] Prepare and wrap your model with DistributedDataParallel
-    # Move the model the correct GPU/CPU device
+    # Move the model to the correct GPU/CPU device
     # ============================================================
     model = ray.train.torch.prepare_model(model)
 
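The training loop that follows prepare_model is also collapsed here. A hedged sketch of a typical per-epoch loop, reusing model, epochs, and train_dataloader from the surrounding function; loss_fn and the optimizer settings are assumptions, not the file's exact code.

import torch
import ray.train

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    for X, y in train_dataloader:  # batches are already on the right device
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()

    # [3] Report per-epoch metrics back to the TorchTrainer driver
    ray.train.report(metrics={"epoch": epoch, "loss": loss.item()})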
@@ -137,9 +137,9 @@ def train_fashion_mnist(num_workers=2, use_gpu=False):
         scaling_config=scaling_config,
     )
 
-    # [4] Start Distributed Training
+    # [4] Start distributed training
     # Run `train_func_per_worker` on all workers
-    # =============================================
+    # ==========================================
     result = trainer.fit()
     print(f"Training result: {result}")
 
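trainer.fit() returns a ray.train.Result. A small usage sketch of the fields available on it, per the Ray Train API around Ray 2.7:

# Assumes `result` is the ray.train.Result returned by trainer.fit() above.
print(result.metrics)     # last metrics dict reported from the workers
print(result.checkpoint)  # latest ray.train.Checkpoint, if one was reported
print(result.path)        # storage path holding results and checkpoints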