Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Update PL to 1.3.8 #531

Merged
merged 27 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
mypy
  • Loading branch information
melanibe committed Jul 8, 2021
commit 0bb4a9ab8e2e0420c4b5379c47bab8fd785d3b5d
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
logging.info(f"Len encoder train dataloader {len(self.encoder_module.train_dataloader())}")
logging.info(f"Len total train dataloader {len(self.train_dataloader())}")

def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]:
def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]: # type: ignore
"""
The train dataloaders
"""
Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_encoder_output_dim(pl_module: Union[pl.LightningModule, torch.nn.Module]
if dm is not None:
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
dataloader = dm.train_dataloader()
dataloader = dataloader[SSLDataModuleType.LINEAR_HEAD] if isinstance(dataloader, dict) else dataloader
dataloader = dataloader[SSLDataModuleType.LINEAR_HEAD] if isinstance(dataloader, dict) else dataloader # type: ignore
batch = iter(dataloader).next() # type: ignore
x, _ = SSLOnlineEvaluatorInnerEye.to_device(batch, device)
else:
Expand Down
9 changes: 5 additions & 4 deletions InnerEye/ML/SSL/lightning_modules/byol/byol_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from InnerEye.ML.SSL.lightning_modules.byol.byol_models import SiameseArm
from InnerEye.ML.SSL.lightning_modules.byol.byol_moving_average import ByolMovingAverageWeightUpdate
from InnerEye.ML.SSL.utils import SSLDataModuleType
from pytorch_lightning import Trainer

SingleBatchType = Tuple[List, T]
BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType]
Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(self,

def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
# Add callback for user automatically since it's key to BYOL weight update
assert isinstance(self.trainer, Trainer)
self.weight_callback.on_before_zero_grad(self.trainer, self)

def forward(self, x: T) -> T: # type: ignore
Expand Down Expand Up @@ -110,14 +112,13 @@ def setup(self, *args: Any, **kwargs: Any) -> None:
global_batch_size = self.trainer.world_size * self.hparams.batch_size # type: ignore
self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size # type: ignore

def configure_optimizers(self):
def configure_optimizers(self) -> Any:
# exclude certain parameters
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
weight_decay=self.hparams.weight_decay) # type: ignore
optimizer = Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
optimizer = Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
scheduler = LinearWarmupCosineAnnealingLR(
optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs
)
optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs) # type: ignore
return [optimizer], [scheduler]

def exclude_from_wt_decay(self,
Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, List, Optional

import torch
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric
from pl_bolts.models.self_supervised import SSLEvaluator
from torch.nn import functional as F

Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric
from torch import Tensor as T
from torch.nn import functional as F

Expand Down
7 changes: 4 additions & 3 deletions InnerEye/ML/lightning_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import param
import torch
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_only
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
Expand Down Expand Up @@ -257,8 +257,8 @@ def use_sync_dist(self) -> bool:
"""
Returns True if metric logging should use sync_dist=True. This is read off from the use_ddp flag of the trainer.
"""
# For PL from version 1.2.0 on: self.trainer.accelerator_connector.use_ddp
return self.trainer.use_ddp
assert isinstance(self.trainer, Trainer)
return self.trainer.accelerator_connector.use_ddp

def on_train_epoch_start(self) -> None:
self.train_timers.reset()
Expand Down Expand Up @@ -497,6 +497,7 @@ def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
:param is_training: If True, the logged metric will be called "train/Loss". If False, the metric will
be called "val/Loss"
"""
assert isinstance(self.trainer, Trainer)
self.log_on_epoch(MetricType.LOSS, loss, is_training)
if is_training:
learning_rate = self.trainer.lr_schedulers[0]['scheduler'].get_last_lr()[0]
Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/lightning_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.nn.functional as F
import torchmetrics as metrics
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric
from pytorch_lightning.metrics.functional import accuracy, auc, auroc, precision_recall_curve, roc
from torch.nn import ModuleList

Expand Down
6 changes: 4 additions & 2 deletions InnerEye/ML/lightning_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from InnerEye.ML.utils import image_util, metrics_util, model_util
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss, get_masked_model_outputs_and_labels
from pytorch_lightning import Trainer

SUBJECT_OUTPUT_PER_RANK_PREFIX = f"{SUBJECT_METRICS_FILE_NAME}.rank"

Expand Down Expand Up @@ -249,6 +250,7 @@ def on_train_start(self) -> None:
"""
# These loggers store the per-subject model outputs. They cannot be initialized in the constructor because
# the trainer object will not yet be set, and we need to get the rank from there.
assert isinstance(self.trainer, Trainer)
fixed_logger_columns = {LoggingColumns.CrossValidationSplitIndex.value: self.cross_validation_split_index}
subject_output_file = get_subject_output_file_per_rank(self.trainer.global_rank)
self.train_subject_outputs_logger = DataframeLogger(self.train_metrics_folder / subject_output_file,
Expand Down Expand Up @@ -323,7 +325,7 @@ def compute_and_log_metrics(self,
zip(_subject_ids, [prediction_target] * len(_subject_ids), _posteriors.tolist(), _labels.tolist()))
# Write a full breakdown of per-subject predictions and labels to a file. These files are local to the current
# rank in distributed training, and will be aggregated after training.
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore
data_split = ModelExecutionMode.TRAIN if is_training else ModelExecutionMode.VAL
for subject, prediction_target, model_output, label in per_subject_outputs:
logger.add_record({
Expand All @@ -350,7 +352,7 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None:
# Hence, only log if anything has been accumulated.
self.log(name=prefix + metric.name + target_suffix, value=metric.compute())
metric.reset()
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore
logger.flush()
super().training_or_validation_epoch_end(is_training)

Expand Down
8 changes: 4 additions & 4 deletions InnerEye/ML/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def __init__(self, container: LightningContainer):
mode="max",
save_last=False)

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any) -> None:
pl_module.log(name="epoch", value=trainer.current_epoch)
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unused: bool=None) -> None:
pl_module.log(name="epoch", value=trainer.current_epoch) # type: ignore


def create_lightning_trainer(container: LightningContainer,
Expand Down Expand Up @@ -272,8 +272,8 @@ def model_train(checkpoint_handler: CheckpointHandler,
# Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
# Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, container.outputs_folder)
upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, container.outputs_folder)
upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, container.outputs_folder) # type: ignore
upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, container.outputs_folder) # type: ignore
# DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
# We can now use the global_rank of the Lightining model, rather than environment variables, because DDP has set
# all necessary properties.
Expand Down
4 changes: 2 additions & 2 deletions Tests/ML/utils/test_model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def create_model_and_store_checkpoint(config: ModelConfigBase, checkpoint_path:
trainer.model = model
# Before saving, the values for epoch and step are incremented. Save them here in such a way that we can assert
# easily later.
trainer.current_epoch = FIXED_EPOCH - 1
trainer.global_step = FIXED_GLOBAL_STEP - 1
trainer.current_epoch = FIXED_EPOCH - 1 # type: ignore
trainer.global_step = FIXED_GLOBAL_STEP - 1 # type: ignore
# In PL, it is the Trainer's responsibility to save the model. Checkpoint handling refers back to the trainer
# to get a save_func. Mimicking that here.
trainer.save_checkpoint(checkpoint_path, weights_only=weights_only)
Expand Down