This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Update PL to 1.3.8 #531

Merged
27 commits merged on Jul 13, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ module on test data with partial ground truth files. (Also [522](https://github.
jobs that run in AzureML.

### Changed
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, together with torchmetrics and pl-bolts, and changed the relevant metrics and SSL code APIs accordingly.
- ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) Renamed command line option 'perform_training_set_inference' to 'inference_on_train_set'. Replaced command line option 'perform_validation_and_test_set_inference' with the pair of options 'inference_on_val_set' and 'inference_on_test_set'.
- ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
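The changelog entry is terse, so as a rough orientation: for the metrics part of this upgrade, most of the file changes below amount to swapping an import path, since the metric base classes moved out of PL into the standalone torchmetrics package. A minimal sketch of that swap (the old import is shown commented out):

```python
# Before this PR (PL < 1.3), metric base classes lived inside PL itself:
# from pytorch_lightning.metrics import Metric

# After the upgrade to PL 1.3.8, they come from the standalone torchmetrics package:
from torchmetrics import Metric
```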
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
@@ -130,7 +130,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
logging.info(f"Len encoder train dataloader {len(self.encoder_module.train_dataloader())}")
logging.info(f"Len total train dataloader {len(self.train_dataloader())}")

def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]:
def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]: # type: ignore
"""
The train dataloaders
"""
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/encoders.py
@@ -71,7 +71,7 @@ def get_encoder_output_dim(pl_module: Union[pl.LightningModule, torch.nn.Module]
if dm is not None:
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
dataloader = dm.train_dataloader()
dataloader = dataloader[SSLDataModuleType.LINEAR_HEAD] if isinstance(dataloader, dict) else dataloader
dataloader = dataloader[SSLDataModuleType.LINEAR_HEAD] if isinstance(dataloader, dict) else dataloader # type: ignore
batch = iter(dataloader).next() # type: ignore
x, _ = SSLOnlineEvaluatorInnerEye.to_device(batch, device)
else:
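For context on the line being type-ignored above, a small sketch of the pattern, assuming a datamodule instance (called `data_module` here purely for illustration) whose `train_dataloader()` may return either a single DataLoader or a dict of DataLoaders keyed by `SSLDataModuleType`:

```python
from InnerEye.ML.SSL.utils import SSLDataModuleType

dataloader = data_module.train_dataloader()  # data_module is assumed to be set up already
if isinstance(dataloader, dict):
    # The combined SSL datamodule returns one loader per purpose; pick the linear-head one.
    dataloader = dataloader[SSLDataModuleType.LINEAR_HEAD]
batch = next(iter(dataloader))  # a single batch is enough to probe the encoder output size
```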
3 changes: 2 additions & 1 deletion InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -157,7 +157,8 @@ def create_model(self) -> LightningModule:
batch_size=self.data_module.batch_size,
learning_rate=self.l_rate,
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
warmup_epochs=10)
warmup_epochs=10,
max_epochs=self.num_epochs)
else:
raise ValueError(
f"Unknown value for ssl_training_type, should be {SSLTrainingType.SimCLR.value} or "
29 changes: 6 additions & 23 deletions InnerEye/ML/SSL/lightning_modules/byol/byol_module.py
@@ -9,14 +9,14 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch import Tensor as T
from torch.optim import Adam

from InnerEye.ML.SSL.lightning_modules.byol.byol_models import SiameseArm
from InnerEye.ML.SSL.lightning_modules.byol.byol_moving_average import ByolMovingAverageWeightUpdate
from InnerEye.ML.SSL.utils import SSLDataModuleType
from pytorch_lightning import Trainer

SingleBatchType = Tuple[List, T]
BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType]
@@ -33,6 +33,7 @@ def __init__(self,
batch_size: int,
encoder_name: str,
warmup_epochs: int,
max_epochs: int,
use_7x7_first_conv_in_resnet: bool = True,
weight_decay: float = 1e-6,
**kwargs: Any) -> None:
@@ -59,6 +60,7 @@ def __init__(self,

def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
# Add callback for user automatically since it's key to BYOL weight update
assert isinstance(self.trainer, Trainer)
self.weight_callback.on_before_zero_grad(self.trainer, self)

def forward(self, x: T) -> T: # type: ignore
@@ -111,31 +113,12 @@ def setup(self, *args: Any, **kwargs: Any) -> None:
self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size # type: ignore

def configure_optimizers(self) -> Any:
"""
Configures the optimizer to use for training: Adam optimizer with Lars scheduling, excluding certain parameters
(batch norm and bias of convolution) from weight decay. Apply Linear Cosine Annealing schedule of learning
rate with warm-up.
"""
# TRICK 1 (Use lars + filter weights)
# exclude certain parameters
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
weight_decay=self.hparams.weight_decay) # type: ignore
optimizer = LARSWrapper(Adam(parameters, lr=self.hparams.learning_rate)) # type: ignore

# Trick 2 (after each step)
self.hparams.warmup_epochs = self.hparams.warmup_epochs * self.train_iters_per_epoch # type: ignore
max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch

linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=self.hparams.warmup_epochs, # type: ignore
max_epochs=max_epochs,
warmup_start_lr=0,
eta_min=self.min_learning_rate,
)

scheduler = {'scheduler': linear_warmup_cosine_decay, 'interval': 'step', 'frequency': 1}

optimizer = Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
scheduler = LinearWarmupCosineAnnealingLR(
optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs) # type: ignore
return [optimizer], [scheduler]

def exclude_from_wt_decay(self,
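The rewritten `configure_optimizers` drops the LARS wrapper and the per-step warm-up arithmetic in favour of plain Adam plus an epoch-based `LinearWarmupCosineAnnealingLR`. A minimal sketch of the new shape, using the hyperparameter names from the diff and `self.parameters()` in place of the repository's `exclude_from_wt_decay` helper:

```python
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import Adam


def configure_optimizers(self):
    # Weight decay is now passed to Adam directly instead of being handled by a LARS wrapper.
    optimizer = Adam(self.parameters(),
                     lr=self.hparams.learning_rate,
                     weight_decay=self.hparams.weight_decay)
    # Warm up for a few epochs, then cosine-anneal; both counts are in epochs, not steps.
    scheduler = LinearWarmupCosineAnnealingLR(optimizer,
                                              warmup_epochs=self.hparams.warmup_epochs,
                                              max_epochs=self.hparams.max_epochs)
    return [optimizer], [scheduler]
```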
4 changes: 2 additions & 2 deletions InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
@@ -5,7 +5,7 @@
from typing import Any, List, Optional

import torch
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric
from pl_bolts.models.self_supervised import SSLEvaluator
from torch.nn import functional as F

@@ -86,7 +86,7 @@ def training_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) ->

def validation_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> None: # type: ignore
loss = self.shared_step(batch, is_training=False)
self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=False)
self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=True)
for metric in self.val_metrics:
self.log(f"val/{metric.name}", metric, on_epoch=True, on_step=False)

4 changes: 2 additions & 2 deletions InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
@@ -9,7 +9,7 @@
import torch
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric
from torch import Tensor as T
from torch.nn import functional as F

@@ -126,9 +126,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
if ids_linear_head not in self.visited_ids:
self.visited_ids.add(ids_linear_head)
loss = self.shared_step(batch, pl_module, is_training=True)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()

# log metrics
pl_module.log('ssl_online_evaluator/train/loss', loss)
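The only behavioural change in this hunk is the position of `zero_grad`: gradients are now cleared after the optimizer step instead of before the backward pass. A standalone sketch of the same ordering on a toy model (all names here are illustrative):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(3):
    batch = torch.randn(8, 4)
    loss = model(batch).pow(2).mean()
    loss.backward()        # accumulate gradients for this batch
    optimizer.step()       # apply the update
    optimizer.zero_grad()  # clear gradients afterwards, leaving .grad buffers empty between batches
```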
7 changes: 4 additions & 3 deletions InnerEye/ML/lightning_base.py
@@ -9,7 +9,7 @@

import param
import torch
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_only
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
@@ -257,8 +257,8 @@ def use_sync_dist(self) -> bool:
"""
Returns True if metric logging should use sync_dist=True. This is read off from the use_ddp flag of the trainer.
"""
# For PL from version 1.2.0 on: self.trainer.accelerator_connector.use_ddp
return self.trainer.use_ddp
assert isinstance(self.trainer, Trainer)
return self.trainer.accelerator_connector.use_ddp

def on_train_epoch_start(self) -> None:
self.train_timers.reset()
@@ -497,6 +497,7 @@ def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
:param is_training: If True, the logged metric will be called "train/Loss". If False, the metric will
be called "val/Loss"
"""
assert isinstance(self.trainer, Trainer)
self.log_on_epoch(MetricType.LOSS, loss, is_training)
if is_training:
learning_rate = self.trainer.lr_schedulers[0]['scheduler'].get_last_lr()[0]
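In PL 1.3 the `use_ddp` flag lives on the trainer's accelerator connector rather than on the trainer itself. A sketch of how such a flag can feed `sync_dist` when logging, assuming a module with a `log_on_epoch`-style helper (class and method bodies here are illustrative, not the repository's):

```python
from pytorch_lightning import LightningModule, Trainer


class SyncAwareModule(LightningModule):
    @property
    def use_sync_dist(self) -> bool:
        # From PL 1.2 onwards the DDP flag is exposed via the accelerator connector.
        assert isinstance(self.trainer, Trainer)
        return self.trainer.accelerator_connector.use_ddp

    def log_on_epoch(self, name: str, value: float) -> None:
        # Only pay the cross-process synchronisation cost when actually running under DDP.
        self.log(name, value, on_step=False, on_epoch=True, sync_dist=self.use_sync_dist)
```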
8 changes: 4 additions & 4 deletions InnerEye/ML/lightning_metrics.py
@@ -8,8 +8,8 @@
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import metrics
from pytorch_lightning.metrics import Metric
import torchmetrics as metrics
from torchmetrics import Metric
from pytorch_lightning.metrics.functional import accuracy, auc, auroc, precision_recall_curve, roc
from torch.nn import ModuleList

@@ -68,7 +68,7 @@ def has_predictions(self) -> bool:
Returns True if the present object stores at least 1 prediction (self.update has been called at least once),
or False if no predictions are stored.
"""
return len(self.y_pred) > 0 # type: ignore
return self.n_obs > 0 # type: ignore


class Accuracy05(metrics.Accuracy):
@@ -82,7 +82,7 @@ def has_predictions(self) -> bool:
Returns True if the present object stores at least 1 prediction (self.update has been called at least once),
or False if no predictions are stored.
"""
return self.total > 0 # type: ignore
return (self.total) or (self.tp + self.fp + self.tn + self.fn) > 0 # type: ignore


class AverageWithoutNan(Metric):
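The `has_predictions` checks above now read torchmetrics' internal state attributes (`n_obs`, `total`, or the confusion-matrix counts). As a rough sketch of the mechanism, a custom torchmetrics `Metric` registers such states with `add_state`, and they stay at their defaults until `update()` is called; the class and state names below are illustrative, not the repository's:

```python
import torch
from torchmetrics import Metric


class MeanWithCount(Metric):
    """Toy metric that records how many values it has seen."""

    def __init__(self) -> None:
        super().__init__()
        # States registered via add_state are reset together and reduced across processes.
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, values: torch.Tensor) -> None:  # type: ignore
        self.total += values.sum()
        self.n_obs += values.numel()

    def compute(self) -> torch.Tensor:
        return self.total / self.n_obs

    @property
    def has_predictions(self) -> bool:
        # Same idea as the checks in the diff: nothing to report until update() has run.
        return bool(self.n_obs > 0)
```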
10 changes: 6 additions & 4 deletions InnerEye/ML/lightning_models.py
Expand Up @@ -28,6 +28,7 @@
from InnerEye.ML.utils.dataset_util import DatasetExample, store_and_upload_example
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss, get_masked_model_outputs_and_labels
from pytorch_lightning import Trainer

SUBJECT_OUTPUT_PER_RANK_PREFIX = f"{SUBJECT_METRICS_FILE_NAME}.rank"

@@ -153,8 +154,8 @@ def compute_metrics(self, cropped_sample: CroppedSample, segmentation: torch.Ten
self.log_on_epoch(name=MetricType.SUBJECT_COUNT,
value=num_subjects,
is_training=is_training,
reduce_fx=sum,
sync_dist_op=None)
reduce_fx=torch.sum,
sync_dist_op="sum")

def training_or_validation_epoch_end(self, is_training: bool) -> None:
"""
@@ -260,6 +261,7 @@ def on_train_start(self) -> None:
"""
# These loggers store the per-subject model outputs. They cannot be initialized in the constructor because
# the trainer object will not yet be set, and we need to get the rank from there.
assert isinstance(self.trainer, Trainer)
fixed_logger_columns = {LoggingColumns.CrossValidationSplitIndex.value: self.cross_validation_split_index}
subject_output_file = get_subject_output_file_per_rank(self.trainer.global_rank)
self.train_subject_outputs_logger = DataframeLogger(self.train_metrics_folder / subject_output_file,
@@ -334,7 +336,7 @@ def compute_and_log_metrics(self,
zip(_subject_ids, [prediction_target] * len(_subject_ids), _posteriors.tolist(), _labels.tolist()))
# Write a full breakdown of per-subject predictions and labels to a file. These files are local to the current
# rank in distributed training, and will be aggregated after training.
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore
data_split = ModelExecutionMode.TRAIN if is_training else ModelExecutionMode.VAL
for subject, prediction_target, model_output, label in per_subject_outputs:
logger.add_record({
@@ -361,7 +363,7 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None:
# Hence, only log if anything has been accumulated.
self.log(name=prefix + metric.name + target_suffix, value=metric.compute())
metric.reset()
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore
logger.flush()
super().training_or_validation_epoch_end(is_training)

27 changes: 8 additions & 19 deletions InnerEye/ML/model_training.py
@@ -96,10 +96,11 @@ def __init__(self, container: LightningContainer):
filename=RECOVERY_CHECKPOINT_FILE_NAME + "_{epoch}",
period=container.recovery_checkpoint_save_interval,
save_top_k=container.recovery_checkpoints_save_last_k,
mode="max")
mode="max",
save_last=False)

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any) -> None:
pl_module.log(name="epoch", value=trainer.current_epoch)
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unused: bool = None) -> None:
pl_module.log(name="epoch", value=trainer.current_epoch) # type: ignore


def create_lightning_trainer(container: LightningContainer,
@@ -121,11 +122,7 @@ def create_lightning_trainer(container: LightningContainer,
# models, this still appears to be the best way of choosing them because validation loss on the relatively small
# training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
# not for the HeadAndNeck model.
best_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
# filename=BEST_CHECKPOINT_FILE_NAME,
# monitor=f"{VALIDATION_PREFIX}{MetricType.LOSS.value}",
# save_top_k=1,
save_last=True)
last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0)

# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
@@ -137,13 +134,6 @@
# Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory).
# For unit tests, only "ddp_spawn" works
accelerator = "ddp" if effective_num_gpus > 1 else None
if effective_num_gpus > 1:
# Initialize the DDP plugin with find_unused_parameters=False by default. If True (default), it prints out
# lengthy warnings about the performance impact of find_unused_parameters
plugins = [InnerEyeDDPPlugin(num_nodes=num_nodes, sync_batchnorm=True,
find_unused_parameters=container.pl_find_unused_parameters)]
else:
plugins = []
logging.info(f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'")
tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
loggers = [tensorboard_logger, AzureMLLogger()]
Expand All @@ -167,7 +157,7 @@ def create_lightning_trainer(container: LightningContainer,
benchmark = True
# If the users provides additional callbacks via get_trainer_arguments (for custom
# containers
callbacks = [best_checkpoint_callback, recovery_checkpoint_callback]
callbacks = [last_checkpoint_callback, recovery_checkpoint_callback]
if "callbacks" in kwargs:
callbacks.append(kwargs.pop("callbacks")) # type: ignore
is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
@@ -194,7 +184,6 @@
sync_batchnorm=True,
terminate_on_nan=container.detect_anomaly,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
plugins=plugins,
**kwargs)
return trainer, storing_logger

@@ -283,8 +272,8 @@ def model_train(checkpoint_handler: CheckpointHandler,
# Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
# Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, container.outputs_folder)
upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, container.outputs_folder)
upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, container.outputs_folder) # type: ignore
upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, container.outputs_folder) # type: ignore
# DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
# We can now use the global_rank of the Lightining model, rather than environment variables, because DDP has set
# all necessary properties.
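The checkpointing logic is now split into a "last" callback (`save_last=True`, `save_top_k=0`) and a recovery callback that keeps the most recent k checkpoints. A rough sketch of wiring such callbacks into a PL 1.3 `Trainer`, with hypothetical paths and intervals; the `monitor="epoch"` part is an assumption that relies on the epoch number being logged, as the callback above does:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Always keep a "last.ckpt", without ranking checkpoints by any metric.
last_checkpoint = ModelCheckpoint(dirpath="outputs/checkpoints",
                                  save_last=True,
                                  save_top_k=0)

# Keep the two most recent recovery checkpoints, ranked by the logged "epoch" value.
recovery_checkpoint = ModelCheckpoint(dirpath="outputs/checkpoints",
                                      filename="recovery_{epoch}",
                                      monitor="epoch",
                                      period=10,
                                      save_top_k=2,
                                      mode="max",
                                      save_last=False)

trainer = Trainer(max_epochs=100,
                  callbacks=[last_checkpoint, recovery_checkpoint])
```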
@@ -1,3 +1,3 @@
subject_count,loss,learning_rate,Dice/AverageAcrossStructures,Dice/spinalcord,Dice/lung_r,Dice/lung_l,VoxelCount/spinalcord,VoxelCount/lung_r,VoxelCount/lung_l,epoch,cross_validation_split_index
2.000000,0.718717,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,126273.000000,0,-1
2.000000,0.775691,0.000090,0.000000,0.000000,0.000000,0.000000,0.000000,84030.000000,0.000000,1,-1
2.000000,0.718559,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,98256.000000,0,-1
2.000000,0.792988,0.000090,0.000000,0.000000,0.000000,0.000000,0.000000,43307.000000,13992.500000,1,-1
@@ -1,3 +1,3 @@
subject_count,loss,Dice/AverageAcrossStructures,Dice/spinalcord,Dice/lung_r,Dice/lung_l,VoxelCount/spinalcord,VoxelCount/lung_r,VoxelCount/lung_l,epoch,cross_validation_split_index
2.000000,0.716739,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,84282.000000,0,-1
2.000000,0.716731,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,84282.000000,1,-1
2.000000,0.715468,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,89502.000000,0,-1
2.000000,0.715476,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,89502.000000,1,-1
@@ -1 +1 @@
0.011612942442297935
0.012701375409960747
@@ -1,3 +1,3 @@
subject_count,loss,learning_rate,Dice/AverageAcrossStructures,Dice/spinalcord,Dice/lung_r,Dice/lung_l,VoxelCount/spinalcord,VoxelCount/lung_r,VoxelCount/lung_l,epoch,cross_validation_split_index
4.000000,0.716913,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,99342.000000,0,-1
4.000000,0.773825,0.000090,0.000000,0.000000,0.000000,0.000000,181.250000,83803.250000,122.250000,1,-1
4.000000,0.753852,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,18609.000000,61179.000000,0,-1
4.000000,0.773389,0.000090,0.000000,0.000000,0.000000,0.000000,0.000000,33453.000000,24258.500000,1,-1