This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Upgrade to Pytorch 1.8 #411

Merged
merged 27 commits into main from antonsc/pytorch18 on Mar 15, 2021
Commits (27)
e7d098d
package upgrade
ant0nsc Mar 6, 2021
470551e
upgrade torchprof
ant0nsc Mar 7, 2021
83311fd
removing a test that now fails because AML is behind
ant0nsc Mar 7, 2021
52a3882
forking torchprof for now
ant0nsc Mar 8, 2021
eb8cf34
adding bias term
ant0nsc Mar 8, 2021
6a03041
updated pytorch lightning
ant0nsc Mar 8, 2021
676032d
updated metrics location
ant0nsc Mar 8, 2021
3c7abbd
avoiding the legacy use_ddp flag
ant0nsc Mar 8, 2021
22c6deb
trying to fix
ant0nsc Mar 8, 2021
8749465
update pillow to avoid component governance
ant0nsc Mar 9, 2021
14c4b9d
fix metrics problems
ant0nsc Mar 10, 2021
6b5584a
switch to new torchprof
ant0nsc Mar 10, 2021
a69dfd6
mypy
ant0nsc Mar 10, 2021
9e7b58a
exclude time from metrics for scalar models
ant0nsc Mar 10, 2021
e95a1ac
fix tolerance issues
ant0nsc Mar 10, 2021
63e200b
project file
ant0nsc Mar 10, 2021
30cba02
cleanup
ant0nsc Mar 10, 2021
54e5055
test fixes
ant0nsc Mar 10, 2021
68f76d4
CHANGELOG.md
ant0nsc Mar 10, 2021
34ee256
test fixes
ant0nsc Mar 10, 2021
45be974
test fixes
ant0nsc Mar 10, 2021
c0c0020
Merge branch 'main' into antonsc/pytorch18
ant0nsc Mar 11, 2021
32c144f
Merge remote-tracking branch 'origin/main' into antonsc/pytorch18
ant0nsc Mar 12, 2021
97756d2
downgrade PL to 1.1.8
ant0nsc Mar 12, 2021
f74c3c4
Merge branch 'antonsc/pytorch18' of https://github.com/microsoft/Inne…
ant0nsc Mar 12, 2021
bee0272
Merge remote-tracking branch 'origin/main' into antonsc/pytorch18
ant0nsc Mar 12, 2021
78d59fe
PR comments
ant0nsc Mar 12, 2021
Files changed
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,8 @@ nodes in AzureML. Example: Add `--num_nodes=2` to the commandline arguments to t
- ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) When registering a model, the name of the
Python execution environment is added as a tag. This tag is read when running inference, and the execution environment
is re-used.
+ - ([#411](https://github.com/microsoft/InnerEye-DeepLearning/pull/411)) Upgraded to PyTorch 1.8.0, PyTorch-Lightning
+ 1.1.8 and AzureML SDK 1.23.0

### Fixed

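The changelog entry pins three framework versions at once. As a quick way to confirm an environment actually matches the entry, here is a minimal sketch of a runtime check; the version strings are taken from the changelog text above, nothing else is assumed from the repo:

```python
# Minimal sanity check for the pins named in the changelog entry.
# The comparison strings come from the entry above, not from environment.yml.
import azureml.core
import pytorch_lightning
import torch

assert torch.__version__.startswith("1.8.0"), torch.__version__
assert pytorch_lightning.__version__ == "1.1.8", pytorch_lightning.__version__
assert azureml.core.VERSION.startswith("1.23"), azureml.core.VERSION
```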
18 changes: 0 additions & 18 deletions InnerEye/Azure/azure_runner.py
@@ -15,12 +15,10 @@
from typing import Any, Dict, List, Optional

from azureml.core import Dataset, Environment, Experiment, Run, ScriptRunConfig
- from azureml.core.conda_dependencies import CondaDependencies
from azureml.core.datastore import Datastore
from azureml.core.runconfig import MpiConfiguration, RunConfiguration
from azureml.core.workspace import WORKSPACE_DEFAULT_BLOB_STORE_NAME
from azureml.data import FileDataset
- from azureml.train.dnn import PyTorch

from InnerEye.Azure import azure_util
from InnerEye.Azure.azure_config import AzureConfig, ParserResult, SourceConfig
@@ -219,22 +217,6 @@ def get_or_create_dataset(azure_config: AzureConfig,
return azureml_dataset


- def pytorch_version_from_conda_dependencies(conda_dependencies: CondaDependencies) -> Optional[str]:
- """
- Given a CondaDependencies object, look for a spec of the form "pytorch=...", and return
- whichever supported version is compatible with the value, or None if there isn't one.
- """
- supported_versions = PyTorch.get_supported_versions()
- for spec in conda_dependencies.conda_packages:
- components = spec.split("=")
- if len(components) == 2 and components[0] == "pytorch":
- version = components[1]
- for supported in supported_versions:
- if version.startswith(supported) or supported.startswith(version):
- return supported
- return None
-
-
def get_or_create_python_environment(azure_config: AzureConfig,
source_config: SourceConfig,
environment_name: str = "",
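The deleted helper existed only to reconcile conda specs with `azureml.train.dnn.PyTorch.get_supported_versions()`; with the estimator import gone, AzureML picks up PyTorch straight from the conda environment. A hedged sketch of the estimator-free submission path implied by the remaining `Environment`/`ScriptRunConfig` imports; the file name, compute target, and experiment name below are illustrative placeholders, not values from the repo:

```python
# Sketch: submit a run with an Environment built from a conda spec, instead of
# the deprecated PyTorch estimator. All concrete names here are placeholders.
from azureml.core import Environment, Experiment, ScriptRunConfig, Workspace

workspace = Workspace.from_config()
env = Environment.from_conda_specification(name="InnerEye", file_path="environment.yml")
script_run = ScriptRunConfig(source_directory=".",
                             script="InnerEye/ML/runner.py",
                             compute_target="gpu-cluster",
                             environment=env)
run = Experiment(workspace, "pytorch18-upgrade").submit(script_run)
```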
7 changes: 3 additions & 4 deletions InnerEye/ML/dataset/scalar_dataset.py
@@ -28,7 +28,6 @@
from InnerEye.ML.utils.features_util import FeatureStatistics
from InnerEye.ML.utils.transforms import Compose3D, Transform3D

-
T = TypeVar('T', bound=ScalarDataSource)


@@ -167,8 +166,8 @@ def load_single_data_source(subject_rows: pd.DataFrame,
additional scalar values should be read from. THe keys should map each feature to its channels.
:param numerical_columns: The names of all columns where additional scalar values should be read from.
:param categorical_data_encoder: Encoding scheme for categorical data.
- :param is_classification_dataset: If the current dataset is classification or not.
- from.
+ :param is_classification_dataset: If True, the dataset will be used in a classification model. If False,
+ assume that the dataset will be used in a regression model.
:param transform_labels: a label transformation or a list of label transformation to apply to the labels.
If a list is provided, the transformations are applied in order from left to right.
:param sequence_position_numeric: Numeric position of the data source in a data sequence. Assumed to be
@@ -330,7 +329,7 @@ def __init__(self,
:param subject_column: The name of the column that contains the subject identifier
:param channel_column: The name of the column that contains the row identifier ("channels")
that are expected to be loaded from disk later because they are large images.
- :param is_classification_dataset: If the current dataset is classification or not from.
+ :param is_classification_dataset: If the current dataset is classification or not.
:param categorical_data_encoder: Encoding scheme for categorical data.
"""
self.categorical_data_encoder = categorical_data_encoder
24 changes: 13 additions & 11 deletions InnerEye/ML/lightning_base.py
@@ -105,10 +105,21 @@ def close_all_loggers(self) -> None:
self.train_epoch_metrics_logger.flush()
self.val_epoch_metrics_logger.flush()

+ @property
+ def use_sync_dist(self) -> bool:
+ """
+ Returns True if metric logging should use sync_dist=True. This is read off from the use_ddp flag of the trainer.
+ """
+ # For PL from version 1.2.0 on: self.trainer.accelerator_connector.use_ddp
+ return self.trainer.use_ddp
+
def on_train_epoch_start(self) -> None:
self.train_timers.reset()

def training_epoch_end(self, outputs: List[Any]) -> None:
+ # Write out all the metrics that have been accumulated in the StoringLogger in the previous epoch.
+ # Metrics for the very last epoch are written in on_train_end
+ self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch - 1)
self.training_or_validation_epoch_end(is_training=True)

def on_validation_epoch_start(self) -> None:
@@ -141,14 +152,6 @@ def validation_epoch_end(self, outputs: List[Any]) -> None:
self.random_state.restore_random_state()
self.training_or_validation_epoch_end(is_training=False)

- @rank_zero_only
- def on_epoch_end(self) -> None:
- """
- This hook is called once per epoch, before on_train_epoch_end. Use it to write out all the metrics
- that have been accumulated in the StoringLogger in the previous epoch.
- """
- self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch - 1)
-
@rank_zero_only
def on_train_end(self) -> None:
"""
@@ -157,6 +160,7 @@ def on_train_end(self) -> None:
"""
self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch)

+ @rank_zero_only
def read_epoch_results_from_logger_and_store(self, epoch: int) -> None:
"""
Reads the metrics for the previous epoch from the StoringLogger, and writes them to disk, broken down by
@@ -167,8 +171,6 @@ def read_epoch_results_from_logger_and_store(self, epoch: int) -> None:
for is_training, prefix in [(True, TRAIN_PREFIX), (False, VALIDATION_PREFIX)]:
metrics = self.storing_logger.extract_by_prefix(epoch, prefix)
self.store_epoch_results(metrics, epoch, is_training)
- else:
- print(f"Skipping, no results for {epoch}")

@rank_zero_only
def training_or_validation_epoch_end(self, is_training: bool) -> None:
@@ -234,7 +236,7 @@ def log_on_epoch(self,
if isinstance(value, numbers.Number):
value = torch.tensor(value, dtype=torch.float, device=self.device)
prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
- sync_dist = self.use_ddp if sync_dist_override is None else sync_dist_override
+ sync_dist = self.use_sync_dist if sync_dist_override is None else sync_dist_override
self.log(prefix + metric_name, value,
sync_dist=sync_dist,
on_step=False, on_epoch=True,
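The new `use_sync_dist` property fences off a Lightning API that moved in 1.2.0, as the inline comment notes. A version-tolerant variant could look like the sketch below; only the two attribute paths come from the diff, and the `hasattr` bridge is an assumption:

```python
# Sketch of reading the DDP flag on either side of the PL 1.2.0 API move.
# Only the two attribute paths are taken from the diff above.
from pytorch_lightning import LightningModule

class SyncDistMixin(LightningModule):
    @property
    def use_sync_dist(self) -> bool:
        if hasattr(self.trainer, "use_ddp"):
            return self.trainer.use_ddp  # PL < 1.2.0
        return self.trainer.accelerator_connector.use_ddp  # PL >= 1.2.0
```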
22 changes: 16 additions & 6 deletions InnerEye/ML/lightning_metrics.py
@@ -10,8 +10,7 @@
import torch.nn.functional as F
from pytorch_lightning import metrics
from pytorch_lightning.metrics import Metric
- from pytorch_lightning.metrics.functional import roc
- from pytorch_lightning.metrics.functional.classification import accuracy, auc, auroc, precision_recall_curve
+ from pytorch_lightning.metrics.functional import accuracy, auc, auroc, precision_recall_curve, roc
from torch.nn import ModuleList

from InnerEye.Common.metrics_constants import AVERAGE_DICE_SUFFIX, MetricType, TRAIN_PREFIX, VALIDATION_PREFIX
@@ -162,6 +161,9 @@ def _get_metrics_at_optimal_cutoff(self) -> Tuple[torch.Tensor, torch.Tensor, to
if torch.unique(targets).numel() == 1:
return torch.tensor(np.nan), torch.tensor(np.nan), torch.tensor(np.nan), torch.tensor(np.nan)
fpr, tpr, thresholds = roc(preds, targets)
+ assert isinstance(fpr, torch.Tensor)
+ assert isinstance(tpr, torch.Tensor)
+ assert isinstance(thresholds, torch.Tensor)
optimal_idx = torch.argmax(tpr - fpr)
optimal_threshold = thresholds[optimal_idx]
acc = accuracy(preds > optimal_threshold, targets)
@@ -246,21 +248,22 @@ def compute(self) -> torch.Tensor:
if torch.unique(targets).numel() == 1:
return torch.tensor(np.nan)
prec, recall, _ = precision_recall_curve(preds, targets)
- return auc(recall, prec)
+ return auc(recall, prec) # type: ignore


class BinaryCrossEntropyWithLogits(ScalarMetricsBase):
"""
Computes the cross entropy for binary classification.
- This metric must be computed off the output logits
+ This metric must be computed off the model output logits.
"""

def __init__(self) -> None:
super().__init__(name=MetricType.CROSS_ENTROPY.value, compute_from_logits=True)

def compute(self) -> torch.Tensor:
preds, targets = self._get_preds_and_targets()
- return F.binary_cross_entropy_with_logits(input=preds, target=targets)
+ # All classification metrics work with integer targets, but this one does not. Convert to float.
+ return F.binary_cross_entropy_with_logits(input=preds, target=targets.to(dtype=preds.dtype))


class MetricForMultipleStructures(torch.nn.Module):
@@ -320,5 +323,12 @@ def compute_all(self) -> Iterator[Tuple[str, torch.Tensor]]:
of (metric name, metric value) tuples. This will automatically also call .reset() on the metrics.
The first returned metric is the average across all structures, then come the per-structure values.
"""
- for d in iter(self):
+ for d in self:
yield d.name, d.compute() # type: ignore
+
+ def reset(self) -> None:
+ """
+ Calls the .reset() method on all the metrics that the present object holds.
+ """
+ for d in self:
+ d.reset()
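`_get_metrics_at_optimal_cutoff` above picks the threshold that maximizes `tpr - fpr` (Youden's J statistic). A standalone illustration of that logic, assuming the pytorch_lightning 1.1.x functional metrics imported at the top of this file; the input values are made up:

```python
# Standalone illustration of the optimal-cutoff logic; toy inputs only.
import torch
from pytorch_lightning.metrics.functional import accuracy, roc

preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
targets = torch.tensor([0, 0, 1, 1])
fpr, tpr, thresholds = roc(preds, targets)
optimal_idx = torch.argmax(tpr - fpr)  # Youden's J statistic
optimal_threshold = thresholds[optimal_idx]
print(accuracy((preds > optimal_threshold).int(), targets))
```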
20 changes: 12 additions & 8 deletions InnerEye/ML/lightning_models.py
@@ -148,12 +148,14 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None:
"""
Writes all training or validation metrics that were aggregated over the epoch to the loggers.
"""
- dice = list((self.train_dice if is_training else self.val_dice).compute_all())
- for name, value in dice:
+ dice = self.train_dice if is_training else self.val_dice
+ for name, value in dice.compute_all():
self.log(name, value)
- voxel_count = list((self.train_voxels if is_training else self.val_voxels).compute_all())
- for name, value in voxel_count:
+ dice.reset()
+ voxel_count = self.train_voxels if is_training else self.val_voxels
+ for name, value in voxel_count.compute_all():
self.log(name, value)
+ voxel_count.reset()
super().training_or_validation_epoch_end(is_training=is_training)


@@ -303,8 +305,9 @@ def compute_and_log_metrics(self,
if masked is not None:
_logits = masked.model_outputs.data
_posteriors = self.logits_to_posterior(_logits)
- # Image encoders already prepare images in float16, but the labels are not yet in that dtype
- _labels = masked.labels.data.to(dtype=_posteriors.dtype)
+ # Classification metrics expect labels as integers, but they are float throughout the rest of the code
+ labels_dtype = torch.int if self.is_classification_model else _posteriors.dtype
+ _labels = masked.labels.data.to(dtype=labels_dtype)
_subject_ids = masked.subject_ids
assert _subject_ids is not None
for metric in metric_list:
@@ -339,13 +342,14 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None:
for metric in metric_list:
if metric.has_predictions:
# Sequence models can have no predictions at all for particular positions, depending on the data.
- # Hence, only log if anything really has been accumula
+ # Hence, only log if anything has been accumulated.
self.log(name=prefix + metric.name + target_suffix, value=metric.compute())
+ metric.reset()
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger
logger.flush()
super().training_or_validation_epoch_end(is_training)

- def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: # type: ignore
"""
For sequence models, transfer the nested lists of items to the given GPU device.
For all other models, this relies on the superclass to move the batch of data to the GPU.
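The `transfer_batch_to_device` override exists because sequence models carry nested lists of items that the default Lightning hook does not descend into. A hedged sketch of the recursive-move pattern the docstring describes; the helper name and structure are illustrative, not InnerEye's implementation:

```python
# Sketch of recursively moving nested containers of tensors to a device.
from typing import Any

import torch

def move_to_device(batch: Any, device: torch.device) -> Any:
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(item, device) for item in batch)
    return batch  # leave non-tensor leaves (strings, ints, ...) untouched
```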
5 changes: 4 additions & 1 deletion InnerEye/ML/metrics.py
@@ -331,13 +331,16 @@ def store_epoch_metrics(metrics: DictStrFloat,
epoch: int,
file_logger: DataframeLogger) -> None:
"""
- Writes all metrics into a CSV file, with an additional columns for epoch number.
+ Writes all metrics (apart from ones that measure run time) into a CSV file,
+ with an additional columns for epoch number.
:param file_logger: An instance of DataframeLogger, for logging results to csv.
:param epoch: The epoch corresponding to the results.
:param metrics: The metrics of the specified epoch, averaged along its batches.
"""
logger_row = {}
for key, value in metrics.items():
+ if key == MetricType.SECONDS_PER_BATCH.value or key == MetricType.SECONDS_PER_EPOCH.value:
+ continue
if key in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys():
logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[key].value] = value
else:
6 changes: 4 additions & 2 deletions InnerEye/ML/model_training.py
@@ -18,6 +18,7 @@
from InnerEye.Common.metrics_constants import TRAIN_PREFIX, VALIDATION_PREFIX
from InnerEye.Common.resource_monitor import ResourceMonitor
from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, cleanup_checkpoint_folder
+ from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.deep_learning_config import VISUALIZATION_FOLDER
from InnerEye.ML.lightning_base import TrainingAndValidationDataLightning
from InnerEye.ML.lightning_helpers import create_lightning_model
@@ -186,8 +187,9 @@ def model_train(config: ModelConfigBase,
ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization")
# Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
# want training to depend on how many patients we visualized, and hence set the random seed again right after.
- with logging_section("Visualizing the effect of sampling random crops for training"):
- visualize_random_crops_for_dataset(config)
+ if isinstance(config, SegmentationModelBase):
+ with logging_section("Visualizing the effect of sampling random crops for training"):
+ visualize_random_crops_for_dataset(config)

# Print out a detailed breakdown of layers, memory consumption and time.
generate_and_print_model_summary(config, lightning_model.model)
2 changes: 1 addition & 1 deletion InnerEye/ML/models/layers/weight_standardization.py
@@ -51,4 +51,4 @@ def standardize(weights: torch.Tensor) -> torch.Tensor:

def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
standardized_weights = WeightStandardizedConv2d.standardize(self.weight)
- return self._conv_forward(input, standardized_weights) # type: ignore
+ return self._conv_forward(input, standardized_weights, bias=None) # type: ignore
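The one-line change tracks PyTorch 1.8, where `Conv2d._conv_forward` takes an explicit `bias` argument. For context, a self-contained sketch of the weight-standardization pattern this layer implements; the per-filter statistics and epsilon are assumptions about a typical implementation, while `bias=None` mirrors the call in the diff:

```python
# Sketch of a weight-standardized Conv2d against the torch 1.8 signature.
# Statistics and epsilon are assumed; bias=None mirrors the diff above.
import torch
import torch.nn as nn

class WSConv2d(nn.Conv2d):
    @staticmethod
    def standardize(weights: torch.Tensor) -> torch.Tensor:
        # Normalize each output filter to zero mean and unit variance.
        mean = weights.mean(dim=(1, 2, 3), keepdim=True)
        std = weights.std(dim=(1, 2, 3), keepdim=True)
        return (weights - mean) / (std + 1e-5)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._conv_forward(input, WSConv2d.standardize(self.weight), bias=None)
```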
7 changes: 6 additions & 1 deletion InnerEye/ML/runner.py
@@ -265,8 +265,13 @@ def run_in_situ(self) -> None:
print_exception(ex, "Unable to run PyTest.")
error_messages.append(f"Unable to run PyTest: {ex}")
else:
- try:
+ # Set environment variables for multi-node training if needed.
+ # In particular, the multi-node environment variables should NOT be set in single node
+ # training, otherwise this might lead to errors with the c10 distributed backend
+ # (https://github.com/microsoft/InnerEye-DeepLearning/issues/395)
+ if self.azure_config.num_nodes > 1:
set_environment_variables_for_multi_node()
+ try:
logging_to_file(self.model_config.logs_folder / LOG_FILE_NAME)
try:
self.create_ml_runner().run()
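The guard makes the rendezvous setup for `torch.distributed` conditional on `num_nodes > 1` (see issue #395 linked in the comment). The helper's body is not part of this diff, so the sketch below is purely illustrative of what such a function typically sets; every variable name and default in it is an assumption:

```python
# Purely illustrative: the real set_environment_variables_for_multi_node is
# not shown in this diff. Every name and default here is an assumption.
import os

def set_environment_variables_for_multi_node() -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "6105")
    os.environ.setdefault("NODE_RANK", "0")
```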
4 changes: 2 additions & 2 deletions InnerEye/ML/visualizers/model_summary.py
@@ -10,8 +10,8 @@

import numpy as np
import torch
- import torchprof
from torch.utils.hooks import RemovableHandle
+ from torchprof.profile import Profile

from InnerEye.Common.common_util import logging_only_to_file
from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH
@@ -188,7 +188,7 @@ def print_summary() -> None:

# Register the forward-pass hooks, profile the model, and restore its state
self.model.apply(self._register_hook)
- with torchprof.Profile(self.model, use_cuda=self.use_gpu) as prof:
+ with Profile(self.model, use_cuda=self.use_gpu) as prof:
forward_preserve_state(self.model, input_tensors) # type: ignore

# Log the model summary: tensor shapes, num of parameters, memory requirement, and forward pass time
6 changes: 1 addition & 5 deletions InnerEye/ML/visualizers/patch_sampling.py
@@ -15,7 +15,6 @@
from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
from InnerEye.ML.dataset.sample import Sample
- from InnerEye.ML.deep_learning_config import DeepLearningConfig
from InnerEye.ML.plotting import resize_and_save, scan_with_transparent_overlay
from InnerEye.ML.utils import augmentation, io_util, ml_util
from InnerEye.ML.utils.config_util import ModelConfigLoader
@@ -108,17 +107,14 @@ def visualize_random_crops(sample: Sample,
return heatmap


- def visualize_random_crops_for_dataset(config: DeepLearningConfig,
- output_folder: Optional[Path] = None) -> None:
+ def visualize_random_crops_for_dataset(config: SegmentationModelBase, output_folder: Optional[Path] = None) -> None:
"""
For segmentation models only: This function generates visualizations of the effect of sampling random patches
for training. Visualizations are stored in both Nifti format, and as 3 PNG thumbnail files, in the output folder.
:param config: The model configuration.
:param output_folder: The folder in which the visualizations should be written. If not provided, use a subfolder
"patch_sampling" in the models's default output folder
"""
- if not isinstance(config, SegmentationModelBase):
- return
dataset_splits = config.get_dataset_splits()
# Load a sample using the full image data loader
full_image_dataset = FullImageDataset(config, dataset_splits.train)