This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Fix multi-node bug in PL 1.2.8 (#437)
* Fix the bug in PL

* Add back the test

* Missing import

* CHANGELOG.md

* Fix it

* Only plugin if more than one gpu

* Only plugin if more than one gpu

* Mypy

* Mypy again
melanibe committed Apr 16, 2021
1 parent 28404f0 commit a155946
Showing 5 changed files with 108 additions and 27 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -54,15 +54,18 @@ created.
- ([#432](https://github.com/microsoft/InnerEye-DeepLearning/pull/432)) Upgraded to PyTorch-Lightning 1.2.7. Added an
end-to-end test for classification cross-validation. WARNING: upgrading the PL version causes multi-node
training to hang.
- ([#437](https://github.com/microsoft/InnerEye-DeepLearning/pull/437)) Upgrade to PyTorch-Lightning 1.2.8.

### Fixed
- ([#422](https://github.com/microsoft/InnerEye-DeepLearning/pull/422)) Documentation - clarified `setting_up_aml.md`
datastore creation instructions and fixed small typos in `hello_world_model.md`
- ([#432](https://github.com/microsoft/InnerEye-DeepLearning/pull/432)) Fixed cross-validation for classification
models. Fixed multi-GPU metrics aggregation. Added an end-to-end test for classification cross-validation. Added a fix
for a bug in the DDP setting when running multi-node with 1 GPU per node.
- ([#435](https://github.com/microsoft/InnerEye-DeepLearning/pull/435)) If parameter `model` in `AzureConfig` is not
set, display an error message and terminate the run.
- ([#437](https://github.com/microsoft/InnerEye-DeepLearning/pull/437)) Fixed the multi-node DDP bug in PL v1.2.8. Re-added
the end-to-end test for multi-node training.
### Removed

### Deprecated
83 changes: 81 additions & 2 deletions InnerEye/ML/model_training.py
@@ -4,14 +4,19 @@
# ------------------------------------------------------------------------------------------
import logging
import os
import subprocess
import sys
from pathlib import Path
from time import sleep
from typing import Optional, Tuple, TypeVar

import numpy as np
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from InnerEye.Azure.azure_util import RUN_CONTEXT
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, logging_section
@@ -63,7 +68,6 @@ def upload_output_file_as_temp(file_path: Path, outputs_folder: Path) -> None:
upload_name = TEMP_PREFIX + str(file_path.relative_to(outputs_folder))
RUN_CONTEXT.upload_file(upload_name, path_or_stream=str(file_path))


def create_lightning_trainer(config: ModelConfigBase,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
@@ -102,6 +106,7 @@ def create_lightning_trainer(config: ModelConfigBase,
# Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory).
# For unit tests, only "ddp_spawn" works
accelerator = "ddp" if num_gpus * num_nodes > 1 else None
plugins = [InnerEyeDDPPlugin(num_nodes=num_nodes, sync_batchnorm=True)] if num_gpus * num_nodes > 1 else None
logging.info(f"Using {num_gpus} GPUs with accelerator '{accelerator}'")
storing_logger = StoringLogger()
tensorboard_logger = TensorBoardLogger(save_dir=str(config.logs_folder), name="Lightning", version="")
@@ -140,7 +145,8 @@ def create_lightning_trainer(config: ModelConfigBase,
precision=precision,
sync_batchnorm=True,
terminate_on_nan=config.detect_anomaly,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
plugins=plugins
)
return trainer, storing_logger

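As a side note (not part of the commit), the sketch below illustrates how the conditional above chooses the accelerator and the custom plugin. The helper function is purely illustrative, and the plugin is represented by a string placeholder so the sketch stays self-contained.

# Illustrative helper (not in the repository): mirrors the accelerator/plugin choice in
# create_lightning_trainer above.
def choose_ddp_settings(num_gpus: int, num_nodes: int):
    use_ddp = num_gpus * num_nodes > 1
    accelerator = "ddp" if use_ddp else None
    # The custom DDP plugin is only attached when more than one GPU is in play,
    # matching the "Only plugin if more than one gpu" commits in the message above.
    plugins = ["InnerEyeDDPPlugin(num_nodes=..., sync_batchnorm=True)"] if use_ddp else None
    return accelerator, plugins

assert choose_ddp_settings(1, 1) == (None, None)   # single GPU, single node: no DDP
assert choose_ddp_settings(4, 1)[0] == "ddp"       # multi-GPU, single node
assert choose_ddp_settings(1, 2)[0] == "ddp"       # one GPU per node, two nodes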
@@ -302,3 +308,76 @@ def aggregate_and_create_subject_metrics_file(outputs_folder: Path) -> None:
# For all files but the first one, cut off the header line.
result_file.write(os.linesep + os.linesep.join(temp_file_contents.splitlines()[1:]))
result_file.close()


class InnerEyeDDPPlugin(DDPPlugin):
    """
    This is a temporary fix for the broken DDP plugin in PyTorch-Lightning v1.2.8.
    Hopefully we can remove it once it is fixed in PyTorch-Lightning.
    """

    def _call_children_scripts(self) -> None:
        # This is the only line changed compared to DDPPlugin
        assert self.local_rank == 0

        # The code below is the same as in the original DDPPlugin
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()  # type: ignore
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())  # type: ignore

        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())  # type: ignore
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())  # type: ignore

        path_lib = os.path.abspath

        # pull out the command used to run the script and resolve the absolute file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # run the command with the same Python interpreter that launched this process
        command = [sys.executable] + command

        # The visible devices tell us how many GPUs we want to use.
        # By the time execution reaches this point, the devices have already been scoped for
        # the trainer script, so we leave CUDA visible devices alone and instead forward the
        # selected GPUs via environment variables.
        if self.parallel_devices is None:
            raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):  # type: ignore
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove the env var if the global seed was not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start the process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders;
            # stagger them with a random delay of 1-5 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)
3 changes: 1 addition & 2 deletions Tests/AfterTraining/test_after_training.py
@@ -275,8 +275,7 @@ def test_register_and_score_model(test_output_dirs: OutputFolderForTests) -> Non
assert_nifti_content(str(expected_segmentation_path), expected_shape, image_header, [3], np.ubyte)


# @pytest.mark.after_training_2node
@pytest.mark.skip("2 nodes training hangs with PL 1.2.7")
@pytest.mark.after_training_2node
def test_training_2nodes(test_output_dirs: OutputFolderForTests) -> None:
"""
Test if a job running on 2 nodes trains correctly.
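For context, the sketch below shows how a custom mark such as `after_training_2node` is typically registered so that `pytest -m after_training_2node` selects only the re-enabled 2-node test. This mirrors common pytest practice, not necessarily the repository's own conftest.py.

# conftest.py sketch (assumption, not taken from this repository): register the custom mark
# so pytest does not warn about it and `-m after_training_2node` selects only these tests.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "after_training_2node: tests that run against a completed 2-node training job")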
40 changes: 20 additions & 20 deletions azure-pipelines/build-pr.yml
@@ -104,26 +104,26 @@ jobs:
test_run_title: tests_after_training_ensemble_run

# Train a model on 2 nodes
# - job: Train2Nodes
# variables:
# - name: model
# value: 'BasicModel2EpochsMoreData'
# - name: tag
# value: 'Train2Nodes'
# - name: more_switches
# value: '--log_level=DEBUG --num_nodes=2'
# pool:
# vmImage: 'ubuntu-18.04'
# steps:
# - template: train_template.yml
# parameters:
# wait_for_completion: 'True'
# pytest_mark: ''
# max_run_duration: '1h'
# - template: tests_after_training.yml
# parameters:
# pytest_mark: after_training_2node
# test_run_title: tests_after_training_2node_run
- job: Train2Nodes
variables:
- name: model
value: 'BasicModel2EpochsMoreData'
- name: tag
value: 'Train2Nodes'
- name: more_switches
value: '--log_level=DEBUG --num_nodes=2'
pool:
vmImage: 'ubuntu-18.04'
steps:
- template: train_template.yml
parameters:
wait_for_completion: 'True'
pytest_mark: ''
max_run_duration: '1h'
- template: tests_after_training.yml
parameters:
pytest_mark: after_training_2node
test_run_title: tests_after_training_2node_run

# Train a classification model in cross validation mode
- job: TrainGlaucomaCV
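For orientation only: the snippet below reconstructs the training invocation that the re-enabled Train2Nodes job effectively issues, based on the variables it sets. The real command is assembled by train_template.yml (not part of this diff), so the entry point and the --azureml flag are assumptions.

# Hedged reconstruction of the Train2Nodes invocation (entry point and --azureml assumed).
import subprocess
import sys

cmd = [sys.executable, "InnerEye/ML/runner.py",
       "--model=BasicModel2EpochsMoreData",      # the job's "model" variable
       "--azureml=True",                         # submit the run to AzureML (assumption)
       "--log_level=DEBUG", "--num_nodes=2"]     # the job's "more_switches" variable
subprocess.run(cmd, check=True)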
2 changes: 1 addition & 1 deletion environment.yml
@@ -44,7 +44,7 @@ dependencies:
- pytest-cov==2.10.1
- pytest-forked==1.3.0
- pytest-xdist==1.34.0
- pytorch-lightning==1.2.7
- pytorch-lightning==1.2.8
- rich==5.1.1
- rpdb==0.1.6
- scikit-image==0.17.2
