Update PL to 1.3.8 (#531)
* update pl

* fix one test

* our fix not needed anymore

* fix yet another test

* add new torchmetrics

* fix checkpoints

* fix some test

* fix one test more

* attempt to fix test

* Update byol code to match new pl bolts

* needed to update

* back to how it was

* update

* update

* changelog

* update regression metrics

* skip test on wind

* flake8

* forgot to update this

* mypy

* remove comment

* try to see if problem comes from sync dist flag

* few fixes

* Update expected number of subjects

* correct more

* flake8

Co-authored-by: Anton Schwaighofer <[email protected]>
melanibe and ant0nsc committed Jul 13, 2021
1 parent 8eae655 commit 30d515b
Showing 24 changed files with 62 additions and 87 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ module on test data with partial ground truth files. (Also [522](https://github.
jobs that run in AzureML.

### Changed
+ - ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics, and pl-bolts, and changed the relevant metrics and SSL code APIs.
- ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) Renamed command line option 'perform_training_set_inference' to 'inference_on_train_set'. Replaced command line option 'perform_validation_and_test_set_inference' with the pair of options 'inference_on_val_set' and 'inference_on_test_set'.
- ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
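For readers making the same migration, the most mechanical part of this upgrade is the metrics import move, which recurs in several hunks below; a minimal before/after sketch:

```python
# Metrics moved out of PyTorch Lightning into the standalone torchmetrics package.
# Old imports (PL <= 1.2):
#   from pytorch_lightning import metrics
#   from pytorch_lightning.metrics import Metric
# New imports (PL 1.3.x with torchmetrics installed):
import torchmetrics as metrics
from torchmetrics import Metric
```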
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
@@ -130,7 +130,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
logging.info(f"Len encoder train dataloader {len(self.encoder_module.train_dataloader())}")
logging.info(f"Len total train dataloader {len(self.train_dataloader())}")

- def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]:
+ def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]:  # type: ignore
"""
The train dataloaders
"""
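The `# type: ignore` added above is presumably needed because the PL 1.3 stubs give `train_dataloader` a narrower return annotation than this dict-valued override; a minimal sketch of the pattern (class and key names are illustrative, not the repo's):

```python
from typing import Dict

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class TwoStreamDataModule(LightningDataModule):  # illustrative name
    def __init__(self, encoder_loader: DataLoader, linear_head_loader: DataLoader) -> None:
        super().__init__()
        self.encoder_loader = encoder_loader
        self.linear_head_loader = linear_head_loader

    def train_dataloader(self) -> Dict[str, DataLoader]:  # type: ignore
        # Lightning can combine a dict of train dataloaders; each training batch
        # then arrives as a dict with the same keys.
        return {"encoder": self.encoder_loader, "linear_head": self.linear_head_loader}
```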
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/encoders.py
@@ -71,7 +71,7 @@ def get_encoder_output_dim(pl_module: Union[pl.LightningModule, torch.nn.Module]
if dm is not None:
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
dataloader = dm.train_dataloader()
- dataloader = dataloader[SSLDataModuleType.LINEAR_HEAD] if isinstance(dataloader, dict) else dataloader
+ dataloader = dataloader[SSLDataModuleType.LINEAR_HEAD] if isinstance(dataloader, dict) else dataloader  # type: ignore
batch = iter(dataloader).next() # type: ignore
x, _ = SSLOnlineEvaluatorInnerEye.to_device(batch, device)
else:
3 changes: 2 additions & 1 deletion InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -157,7 +157,8 @@ def create_model(self) -> LightningModule:
batch_size=self.data_module.batch_size,
learning_rate=self.l_rate,
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
- warmup_epochs=10)
+ warmup_epochs=10,
+ max_epochs=self.num_epochs)
else:
raise ValueError(
f"Unknown value for ssl_training_type, should be {SSLTrainingType.SimCLR.value} or "
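The call site above now fixes the scheduler length at construction time instead of reading it from the trainer later. A hedged sketch of the resulting constructor call (`BYOLInnerEye` and all argument values are assumptions for illustration, mirroring the parameters visible in the next file):

```python
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye  # assumed class name

byol = BYOLInnerEye(num_samples=10000,
                    learning_rate=1e-3,
                    batch_size=32,
                    encoder_name="resnet50",
                    warmup_epochs=10,
                    max_epochs=100,  # new argument: schedule length, previously trainer.max_epochs
                    use_7x7_first_conv_in_resnet=True)
```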
29 changes: 6 additions & 23 deletions InnerEye/ML/SSL/lightning_modules/byol/byol_module.py
@@ -9,14 +9,14 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
- from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch import Tensor as T
from torch.optim import Adam

from InnerEye.ML.SSL.lightning_modules.byol.byol_models import SiameseArm
from InnerEye.ML.SSL.lightning_modules.byol.byol_moving_average import ByolMovingAverageWeightUpdate
from InnerEye.ML.SSL.utils import SSLDataModuleType
+ from pytorch_lightning import Trainer

SingleBatchType = Tuple[List, T]
BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType]
@@ -33,6 +33,7 @@ def __init__(self,
batch_size: int,
encoder_name: str,
warmup_epochs: int,
+ max_epochs: int,
use_7x7_first_conv_in_resnet: bool = True,
weight_decay: float = 1e-6,
**kwargs: Any) -> None:
@@ -59,6 +60,7 @@ def __init__(self,

def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
# Add callback for user automatically since it's key to BYOL weight update
+ assert isinstance(self.trainer, Trainer)
self.weight_callback.on_before_zero_grad(self.trainer, self)

def forward(self, x: T) -> T: # type: ignore
@@ -111,31 +113,12 @@ def setup(self, *args: Any, **kwargs: Any) -> None:
self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size # type: ignore

def configure_optimizers(self) -> Any:
"""
Configures the optimizer to use for training: Adam optimizer with Lars scheduling, excluding certain parameters
(batch norm and bias of convolution) from weight decay. Apply Linear Cosine Annealing schedule of learning
rate with warm-up.
"""
# TRICK 1 (Use lars + filter weights)
# exclude certain parameters
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
weight_decay=self.hparams.weight_decay) # type: ignore
optimizer = LARSWrapper(Adam(parameters, lr=self.hparams.learning_rate)) # type: ignore

# Trick 2 (after each step)
self.hparams.warmup_epochs = self.hparams.warmup_epochs * self.train_iters_per_epoch # type: ignore
max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch

linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=self.hparams.warmup_epochs, # type: ignore
max_epochs=max_epochs,
warmup_start_lr=0,
eta_min=self.min_learning_rate,
)

scheduler = {'scheduler': linear_warmup_cosine_decay, 'interval': 'step', 'frequency': 1}

optimizer = Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
scheduler = LinearWarmupCosineAnnealingLR(
optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs) # type: ignore
return [optimizer], [scheduler]

def exclude_from_wt_decay(self,
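Because `configure_optimizers` now returns the scheduler bare, rather than wrapped in `{'scheduler': ..., 'interval': 'step'}`, Lightning steps it once per epoch, and warm-up and annealing lengths are expressed directly in epochs. A runnable sketch of the resulting schedule, with assumed hyperparameters:

```python
import torch
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import Adam

model = torch.nn.Linear(10, 2)  # stand-in for the online network
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
# Linear warm-up for 10 epochs, then cosine annealing towards eta_min until epoch 100.
scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=100)
for epoch in range(100):
    # ... one epoch of forward/backward/optimizer.step() calls would go here ...
    scheduler.step()  # stepped once per epoch, matching Lightning's default "epoch" interval
```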
4 changes: 2 additions & 2 deletions InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
@@ -5,7 +5,7 @@
from typing import Any, List, Optional

import torch
- from pytorch_lightning.metrics import Metric
+ from torchmetrics import Metric
from pl_bolts.models.self_supervised import SSLEvaluator
from torch.nn import functional as F

@@ -86,7 +86,7 @@ def training_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) ->

def validation_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> None: # type: ignore
loss = self.shared_step(batch, is_training=False)
- self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=False)
+ self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=True)
for metric in self.val_metrics:
self.log(f"val/{metric.name}", metric, on_epoch=True, on_step=False)

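Flipping `sync_dist` to `True` turns the epoch-level `val/loss` into a value reduced across DDP processes rather than a per-rank number (one of the commit messages above suggests this flag was investigated deliberately). A minimal sketch of the behaviour:

```python
import torch
from pytorch_lightning import LightningModule


class TinyModule(LightningModule):  # illustrative
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # on_epoch=True accumulates a mean over validation steps; sync_dist=True
        # additionally reduces that mean across DDP ranks, so every process logs
        # the same value instead of its local one.
        self.log("val/loss", loss, on_step=False, on_epoch=True, sync_dist=True)
```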
4 changes: 2 additions & 2 deletions InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
@@ -9,7 +9,7 @@
import torch
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
- from pytorch_lightning.metrics import Metric
+ from torchmetrics import Metric
from torch import Tensor as T
from torch.nn import functional as F

@@ -126,9 +126,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
if ids_linear_head not in self.visited_ids:
self.visited_ids.add(ids_linear_head)
loss = self.shared_step(batch, pl_module, is_training=True)
- self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
+ self.optimizer.zero_grad()

# log metrics
pl_module.log('ssl_online_evaluator/train/loss', loss)
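The reorder above moves `zero_grad` from before the backward pass to after the optimizer step. Per batch the two orderings are equivalent once gradients start out zeroed; clearing immediately after the update means no stale gradients are left behind between invocations of the callback. A minimal sketch of the new ordering (`compute_loss` is a hypothetical stand-in for `shared_step`):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())

def compute_loss() -> torch.Tensor:  # hypothetical stand-in
    return model(torch.randn(8, 4)).pow(2).mean()

loss = compute_loss()
loss.backward()        # accumulate gradients for this batch
optimizer.step()       # apply the update
optimizer.zero_grad()  # clear straight away, leaving no stale gradients behind
```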
7 changes: 4 additions & 3 deletions InnerEye/ML/lightning_base.py
@@ -9,7 +9,7 @@

import param
import torch
- from pytorch_lightning import LightningDataModule, LightningModule
+ from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_only
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
@@ -257,8 +257,8 @@ def use_sync_dist(self) -> bool:
"""
Returns True if metric logging should use sync_dist=True. This is read off from the use_ddp flag of the trainer.
"""
- # For PL from version 1.2.0 on: self.trainer.accelerator_connector.use_ddp
- return self.trainer.use_ddp
+ assert isinstance(self.trainer, Trainer)
+ return self.trainer.accelerator_connector.use_ddp

def on_train_epoch_start(self) -> None:
self.train_timers.reset()
@@ -497,6 +497,7 @@ def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
:param is_training: If True, the logged metric will be called "train/Loss". If False, the metric will
be called "val/Loss"
"""
+ assert isinstance(self.trainer, Trainer)
self.log_on_epoch(MetricType.LOSS, loss, is_training)
if is_training:
learning_rate = self.trainer.lr_schedulers[0]['scheduler'].get_last_lr()[0]
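The removed comment had anticipated exactly this change: from PL 1.2 on, the DDP flag lives on the accelerator connector rather than on the trainer itself, and the new `assert` also narrows the `Optional[Trainer]` type that newer PL stubs give `self.trainer`. A hedged sketch of the PL 1.3.x API (assumes a 2-GPU machine):

```python
from pytorch_lightning import Trainer

trainer = Trainer(gpus=2, accelerator="ddp")
# Replaces the removed trainer.use_ddp shortcut.
ddp_active = trainer.accelerator_connector.use_ddp
```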
8 changes: 4 additions & 4 deletions InnerEye/ML/lightning_metrics.py
@@ -8,8 +8,8 @@
import numpy as np
import torch
import torch.nn.functional as F
- from pytorch_lightning import metrics
- from pytorch_lightning.metrics import Metric
+ import torchmetrics as metrics
+ from torchmetrics import Metric
from pytorch_lightning.metrics.functional import accuracy, auc, auroc, precision_recall_curve, roc
from torch.nn import ModuleList

@@ -68,7 +68,7 @@ def has_predictions(self) -> bool:
Returns True if the present object stores at least 1 prediction (self.update has been called at least once),
or False if no predictions are stored.
"""
- return len(self.y_pred) > 0  # type: ignore
+ return self.n_obs > 0  # type: ignore


class Accuracy05(metrics.Accuracy):
@@ -82,7 +82,7 @@ def has_predictions(self) -> bool:
Returns True if the present object stores at least 1 prediction (self.update has been called at least once),
or False if no predictions are stored.
"""
- return self.total > 0  # type: ignore
+ return (self.total) or (self.tp + self.fp + self.tn + self.fn) > 0  # type: ignore


class AverageWithoutNan(Metric):
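Both `has_predictions` rewrites follow from how torchmetrics stores reduction state: attributes registered via `add_state` (such as `n_obs` for the regression metrics, or `total` and the tp/fp/tn/fn counters for `Accuracy`, depending on configuration) replace any cache of raw predictions. A sketch of that state-based pattern with an illustrative custom metric:

```python
import torch
from torchmetrics import Metric


class MeanAbsoluteErrorSketch(Metric):  # illustrative custom metric
    def __init__(self) -> None:
        super().__init__()
        self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.sum_abs_error += (preds - target).abs().sum()
        self.n_obs += target.numel()

    def compute(self) -> torch.Tensor:
        return self.sum_abs_error / self.n_obs


metric = MeanAbsoluteErrorSketch()
print(bool(metric.n_obs > 0))   # False: update() never called, no predictions seen
metric.update(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 2.0]))
print(bool(metric.n_obs > 0))   # True
```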
10 changes: 6 additions & 4 deletions InnerEye/ML/lightning_models.py
@@ -28,6 +28,7 @@
from InnerEye.ML.utils.dataset_util import DatasetExample, store_and_upload_example
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss, get_masked_model_outputs_and_labels
+ from pytorch_lightning import Trainer

SUBJECT_OUTPUT_PER_RANK_PREFIX = f"{SUBJECT_METRICS_FILE_NAME}.rank"

@@ -153,8 +154,8 @@ def compute_metrics(self, cropped_sample: CroppedSample, segmentation: torch.Ten
self.log_on_epoch(name=MetricType.SUBJECT_COUNT,
value=num_subjects,
is_training=is_training,
- reduce_fx=sum,
- sync_dist_op=None)
+ reduce_fx=torch.sum,
+ sync_dist_op="sum")

def training_or_validation_epoch_end(self, is_training: bool) -> None:
"""
@@ -260,6 +261,7 @@ def on_train_start(self) -> None:
"""
# These loggers store the per-subject model outputs. They cannot be initialized in the constructor because
# the trainer object will not yet be set, and we need to get the rank from there.
+ assert isinstance(self.trainer, Trainer)
fixed_logger_columns = {LoggingColumns.CrossValidationSplitIndex.value: self.cross_validation_split_index}
subject_output_file = get_subject_output_file_per_rank(self.trainer.global_rank)
self.train_subject_outputs_logger = DataframeLogger(self.train_metrics_folder / subject_output_file,
@@ -334,7 +336,7 @@ def compute_and_log_metrics(self,
zip(_subject_ids, [prediction_target] * len(_subject_ids), _posteriors.tolist(), _labels.tolist()))
# Write a full breakdown of per-subject predictions and labels to a file. These files are local to the current
# rank in distributed training, and will be aggregated after training.
- logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger
+ logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger  # type: ignore
data_split = ModelExecutionMode.TRAIN if is_training else ModelExecutionMode.VAL
for subject, prediction_target, model_output, label in per_subject_outputs:
logger.add_record({
@@ -361,7 +363,7 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None:
# Hence, only log if anything has been accumulated.
self.log(name=prefix + metric.name + target_suffix, value=metric.compute())
metric.reset()
- logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger
+ logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger  # type: ignore
logger.flush()
super().training_or_validation_epoch_end(is_training)

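The subject-count fix swaps the Python builtin `sum` for `torch.sum` as the within-epoch reduction and names the cross-process reduction explicitly instead of passing `None`. A sketch of the logging call in isolation (batch layout assumed):

```python
import torch
from pytorch_lightning import LightningModule


class CountingModule(LightningModule):  # illustrative
    def training_step(self, batch, batch_idx):
        num_subjects = float(batch.shape[0])  # assumed: batch is a plain tensor
        # reduce_fx=torch.sum: add per-step counts together over the epoch;
        # sync_dist_op="sum": add across DDP processes too, instead of averaging.
        self.log("train/SubjectCount", num_subjects,
                 on_step=False, on_epoch=True,
                 reduce_fx=torch.sum, sync_dist=True, sync_dist_op="sum")
        # ... compute and return the training loss here ...
```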
27 changes: 8 additions & 19 deletions InnerEye/ML/model_training.py
@@ -96,10 +96,11 @@ def __init__(self, container: LightningContainer):
filename=RECOVERY_CHECKPOINT_FILE_NAME + "_{epoch}",
period=container.recovery_checkpoint_save_interval,
save_top_k=container.recovery_checkpoints_save_last_k,
mode="max")
mode="max",
save_last=False)

- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any) -> None:
- pl_module.log(name="epoch", value=trainer.current_epoch)
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unused: bool = None) -> None:
+ pl_module.log(name="epoch", value=trainer.current_epoch)  # type: ignore


def create_lightning_trainer(container: LightningContainer,
@@ -121,11 +122,7 @@ def create_lightning_trainer(container: LightningContainer,
# models, this still appears to be the best way of choosing them because validation loss on the relatively small
# training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
# not for the HeadAndNeck model.
- best_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
- # filename=BEST_CHECKPOINT_FILE_NAME,
- # monitor=f"{VALIDATION_PREFIX}{MetricType.LOSS.value}",
- # save_top_k=1,
- save_last=True)
+ last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0)

# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
@@ -137,13 +134,6 @@
# Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory).
# For unit tests, only "ddp_spawn" works
accelerator = "ddp" if effective_num_gpus > 1 else None
- if effective_num_gpus > 1:
- # Initialize the DDP plugin with find_unused_parameters=False by default. If True (default), it prints out
- # lengthy warnings about the performance impact of find_unused_parameters
- plugins = [InnerEyeDDPPlugin(num_nodes=num_nodes, sync_batchnorm=True,
- find_unused_parameters=container.pl_find_unused_parameters)]
- else:
- plugins = []
logging.info(f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'")
tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
loggers = [tensorboard_logger, AzureMLLogger()]
@@ -167,7 +157,7 @@ def create_lightning_trainer(container: LightningContainer,
benchmark = True
# If the users provides additional callbacks via get_trainer_arguments (for custom
# containers
- callbacks = [best_checkpoint_callback, recovery_checkpoint_callback]
+ callbacks = [last_checkpoint_callback, recovery_checkpoint_callback]
if "callbacks" in kwargs:
callbacks.append(kwargs.pop("callbacks")) # type: ignore
is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
@@ -194,7 +184,6 @@ def create_lightning_trainer(container: LightningContainer,
sync_batchnorm=True,
terminate_on_nan=container.detect_anomaly,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
- plugins=plugins,
**kwargs)
return trainer, storing_logger

@@ -283,8 +272,8 @@ def model_train(checkpoint_handler: CheckpointHandler,
# Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
# Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
- upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, container.outputs_folder)
- upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, container.outputs_folder)
+ upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, container.outputs_folder)  # type: ignore
+ upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, container.outputs_folder)  # type: ignore
# DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
# We can now use the global_rank of the Lightining model, rather than environment variables, because DDP has set
# all necessary properties.
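Net effect of the checkpoint changes in this file: the former "best checkpoint" callback (whose monitoring arguments were already commented out) becomes a pure "last" callback with top-k tracking disabled, and the recovery callback stops writing a duplicate `last.ckpt`. A hedged sketch of the surviving pair (paths and the interval are illustrative; `monitor="epoch"` is an assumption that matches the `on_train_epoch_end` hook above, which logs the epoch number for ranking):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Always refresh "last.ckpt"; keep no ranked "best" checkpoints alongside it.
last_checkpoint_callback = ModelCheckpoint(dirpath="outputs/checkpoints",
                                           save_last=True,
                                           save_top_k=0)

# Periodic recovery checkpoints: "{epoch}" renders as e.g. "recovery_epoch=1.ckpt".
recovery_checkpoint_callback = ModelCheckpoint(dirpath="outputs/checkpoints",
                                               filename="recovery_{epoch}",
                                               monitor="epoch",   # assumed monitored quantity
                                               period=10,
                                               save_top_k=1,
                                               mode="max",
                                               save_last=False)
```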
Changed file (name not shown in this rendering):
@@ -1,3 +1,3 @@
subject_count,loss,learning_rate,Dice/AverageAcrossStructures,Dice/spinalcord,Dice/lung_r,Dice/lung_l,VoxelCount/spinalcord,VoxelCount/lung_r,VoxelCount/lung_l,epoch,cross_validation_split_index
- 2.000000,0.718717,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,126273.000000,0,-1
- 2.000000,0.775691,0.000090,0.000000,0.000000,0.000000,0.000000,0.000000,84030.000000,0.000000,1,-1
+ 2.000000,0.718559,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,98256.000000,0,-1
+ 2.000000,0.792988,0.000090,0.000000,0.000000,0.000000,0.000000,0.000000,43307.000000,13992.500000,1,-1
Changed file (name not shown in this rendering):
@@ -1,3 +1,3 @@
subject_count,loss,Dice/AverageAcrossStructures,Dice/spinalcord,Dice/lung_r,Dice/lung_l,VoxelCount/spinalcord,VoxelCount/lung_r,VoxelCount/lung_l,epoch,cross_validation_split_index
- 2.000000,0.716739,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,84282.000000,0,-1
- 2.000000,0.716731,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,84282.000000,1,-1
+ 2.000000,0.715468,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,89502.000000,0,-1
+ 2.000000,0.715476,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,89502.000000,1,-1
Changed file (name not shown in this rendering):
@@ -1 +1 @@
- 0.011612942442297935
+ 0.012701375409960747
Changed file (name not shown in this rendering):
@@ -1,3 +1,3 @@
subject_count,loss,learning_rate,Dice/AverageAcrossStructures,Dice/spinalcord,Dice/lung_r,Dice/lung_l,VoxelCount/spinalcord,VoxelCount/lung_r,VoxelCount/lung_l,epoch,cross_validation_split_index
- 4.000000,0.716913,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,99342.000000,0,-1
- 4.000000,0.773825,0.000090,0.000000,0.000000,0.000000,0.000000,181.250000,83803.250000,122.250000,1,-1
+ 4.000000,0.753852,0.000100,0.000000,0.000000,0.000000,0.000000,0.000000,18609.000000,61179.000000,0,-1
+ 4.000000,0.773389,0.000090,0.000000,0.000000,0.000000,0.000000,0.000000,33453.000000,24258.500000,1,-1
