From cf511288d94152d7dba27a0ba3fc6c4ab222d9a2 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 22 Sep 2021 19:55:59 +0100 Subject: [PATCH 01/41] instructions --- InnerEye/Common/fixed_paths.py | 3 ++- docs/environment.md | 23 +++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/InnerEye/Common/fixed_paths.py b/InnerEye/Common/fixed_paths.py index cc0b5c922..dc0ee2a56 100755 --- a/InnerEye/Common/fixed_paths.py +++ b/InnerEye/Common/fixed_paths.py @@ -107,7 +107,8 @@ def add_submodules_to_path() -> None: innereye_root = repository_root_directory() folders_to_add = [(innereye_root, "InnerEye"), (innereye_root / "fastMRI", "fastmri"), - (innereye_root / "hi-ml" / "src", "health")] + (innereye_root / "hi-ml" / "hi-ml-azure" / "src", "health"), + (innereye_root / "hi-ml" / "hi-ml" / "src", "health")] for (folder, subfolder_that_must_exist) in folders_to_add: if (folder / subfolder_that_must_exist).is_dir(): folder_str = str(folder) diff --git a/docs/environment.md b/docs/environment.md index 397518357..28ca6faa5 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -197,18 +197,17 @@ as a submodule, rather than a package from pypi. Any change to the package will and that costs 20min per run. * In the repository root, run `git submodule add https://github.com/microsoft/hi-ml` -* In PyCharm's project browser, mark the folder `hi-ml/src` as Sources Root -* Remove the entry for the `hi-ml` package from `environment.yml` -* Modify the start of `InnerEye/ML/runner.py` to look like this: -```python -print(f"Starting InnerEye runner at {sys.argv[0]}") -innereye_root = Path(__file__).absolute().parent.parent.parent -if (innereye_root / "InnerEye").is_dir(): - innereye_root_str = str(innereye_root) - if innereye_root_str not in sys.path: - print(f"Adding InnerEye folder to sys.path: {innereye_root_str}") - sys.path.insert(0, innereye_root_str) - sys.path.append(str(innereye_root / "hi-ml" / "src")) +* In PyCharm's project browser, mark the folders `hi-ml/hi-ml/src` and `hi-ml/hi-ml-azure/src` as Sources Root +* Remove the entry for the `hi-ml` and `hi-ml-azure` packages from `environment.yml` +* There is already code in `InnerEye.Common.fixed_paths.add_submodules_to_path` that will pick up the submodules and + add them to `sys.path`. + +Once you are done testing your changes, remove the entry for `hi-ml` from `.gitmodules` and execute these steps +from the repository root: +```shell +git submodule deinit -f hi-ml +rmdir hi-ml +rm -rf .git/modules/hi-ml ``` Alternatively, you can consume a developer version of `hi-ml` from `test.pypi`: From b6201b53183e1d946b70f9efc2cda790aac76787 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 13 Oct 2021 17:05:07 +0100 Subject: [PATCH 02/41] Moving loading time out into a callback --- .../SSL/lightning_containers/ssl_container.py | 2 +- InnerEye/ML/deep_learning_config.py | 1 - InnerEye/ML/lightning_base.py | 256 ++++++++++-------- InnerEye/ML/lightning_models.py | 2 - InnerEye/ML/model_training.py | 41 +-- 5 files changed, 175 insertions(+), 127 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py index d3f934042..a103fc630 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py @@ -100,7 +100,7 @@ class SSLContainer(LightningContainer): def setup(self) -> None: from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer - self.total_num_gpus = self.num_gpus_per_node * self.num_nodes + self.total_num_gpus = self.num_gpus_per_node() * self.num_nodes self._load_config() # If you're using the same data for training and linear head, allow the user to specify the dataset only # once. Or if you are doing just finetuning of linear head, the user should be able to specify dataset via diff --git a/InnerEye/ML/deep_learning_config.py b/InnerEye/ML/deep_learning_config.py index 4a8161045..b7ceb3ce1 100644 --- a/InnerEye/ML/deep_learning_config.py +++ b/InnerEye/ML/deep_learning_config.py @@ -602,7 +602,6 @@ def use_gpu(self) -> bool: from InnerEye.ML.utils.ml_util import is_gpu_available return is_gpu_available() - @property def num_gpus_per_node(self) -> int: """ Computes the number of gpus to use for each node: either the number of gpus available on the device diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 878379c09..33c7eee88 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -9,7 +9,7 @@ import param import torch -from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities import rank_zero_only from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler @@ -20,6 +20,7 @@ from InnerEye.Common.type_annotations import DictStrFloat from InnerEye.ML.common import ModelExecutionMode from InnerEye.ML.config import SegmentationModelBase +from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths from InnerEye.ML.deep_learning_config import DatasetParams, DeepLearningConfig, OutputParams, TrainerParams, \ WorkflowParams from InnerEye.ML.lightning_container import LightningContainer @@ -28,13 +29,13 @@ from InnerEye.ML.metrics_dict import DataframeLogger from InnerEye.ML.model_config_base import ModelConfigBase from InnerEye.ML.utils import model_util +from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER from InnerEye.ML.utils.device_aware_module import DeviceAwareModule from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp from InnerEye.ML.utils.ml_util import RandomStateSnapshot, set_random_seed, validate_dataset_paths from InnerEye.ML.utils.model_util import generate_and_print_model_summary from InnerEye.ML.visualizers.patch_sampling import visualize_random_crops_for_dataset -from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER -from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths + class TrainAndValDataLightning(LightningDataModule): """ @@ -200,6 +201,152 @@ def load_checkpoint_and_modify(self, path_to_checkpoint: Path) -> Dict[str, Any] return self.config.load_checkpoint_and_modify(path_to_checkpoint=path_to_checkpoint) +class BatchTimeCallback(Callback): + """ + This class provides tools to measure batch loading time and other diagnostic information. + """ + + def __init__(self) -> None: + # Timers for monitoring data loading time + self.train_timers = EpochTimers() + self.val_timers = EpochTimers() + + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self.module = pl_module + self.logger = trainer.logger + + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self.train_timers.reset() + + def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + Stores the state of all random number generators, and resets them all to a fixed seed. This is done to ensure + that any randomization when loading validation data is consistent during training. In particular, this ensures + that drawing random patches for segmentation model training is giving a validation set that does not fluctuate. + """ + self.val_timers.reset() + # In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training + # is done for this epoch, even though the on_training_epoch hook has not yet been called. + self.train_timers.epoch_end() + + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + This is a hook called at the end of a training or validation epoch. In here, we can still write + metrics to a logger. + """ + # In validation epochs, mark that it has been completed. Training epochs are marked completed already + # at the start of the validation epoch. + self.val_timers.epoch_end() + # Write all IO stats here, so that the order on the console is Train start, train end, val start, val end. + self.write_and_log_epoch_time(is_training=True) + self.write_and_log_epoch_time(is_training=False) + + def on_train_batch_start(self, + trainer: Trainer, + pl_module: LightningModule, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + self.batch_start(batch_idx=batch_idx, is_training=True) + + def on_validation_batch_start(self, + trainer: Trainer, + pl_module: LightningModule, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + self.batch_start(batch_idx=batch_idx, is_training=False) + + def on_train_batch_end(self, + trainer: Trainer, + pl_module: LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + self.batch_end(is_training=True) + + def on_validation_batch_end(self, + trainer: Trainer, + pl_module: LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + self.batch_end(is_training=False) + + def write_and_log_epoch_time(self, is_training: bool) -> None: + """ + Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the + time per epoch. + :param is_training: If True, show and log the data for the training epoch. If False, use the data for the + validation epoch. + """ + timers = self.get_timers(is_training=is_training) + epoch_time_seconds = timers.total_epoch_time + status = "training" if is_training else "validation" + logging.info(f"Epoch {self.module.current_epoch} {status} took {epoch_time_seconds:0.2f}sec, of which waiting " + f"for data took {timers.total_load_time:0.2f} sec total.") + if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch: + logging.warning("The dataloaders were not fast enough to always supply the next batch in less than " + f"{MAX_ITEM_LOAD_TIME_SEC}sec.") + logging.warning( + f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load " + f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.") + # This metric is only written at rank zero, and hence must not be synchronized across workers. If attempted, + # training will get stuck. + self.log_metric(MetricType.SECONDS_PER_EPOCH, value=epoch_time_seconds, is_training=is_training) + + def log_metric(self, metric_type: MetricType, value: float, is_training: bool) -> None: + # Metrics are only written at rank 0, and hence must not be synchronized. Trying to synchronize will + # block training. + prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX + self.module.log(name=prefix + metric_type.value, value=value, + on_step=False, on_epoch=True, sync_dist=False) + + @rank_zero_only + def batch_start(self, batch_idx: int, is_training: bool) -> None: + """ + Shared code to keep track of IO-related metrics when loading a minibatch. This is only done on rank zero. + :param batch_idx: The index of the current minibatch. + :param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from + `on_validation_batch_start`. + :return: + """ + timers = self.get_timers(is_training=is_training) + epoch = self.module.current_epoch + message_prefix = f"Epoch {epoch} {'training' if is_training else 'validation'}" + timers.batch_start(batch_index=batch_idx, epoch=epoch, message_prefix=message_prefix) + + @rank_zero_only + def batch_end(self, is_training: bool) -> None: + """ + Shared code to keep track of IO-related metrics when loading a minibatch. + :param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from + `on_validation_batch_end`. + """ + timers = self.get_timers(is_training=is_training) + batch_time = timers.batch_end() + self.log_metric(MetricType.SECONDS_PER_BATCH, value=batch_time, is_training=is_training) + + def get_timers(self, is_training: bool) -> EpochTimers: + """ + Gets the object that holds all IO-related metrics and timers, for either the validation or the training epoch. + """ + return self.train_timers if is_training else self.val_timers + + def reset_timers(self) -> None: + """ + Resets all timers and counters for IO-related metrics, for both the validation and the training epoch. + """ + self.train_timers.reset() + self.val_timers.reset() + + class InnerEyeLightning(LightningModule): """ The base class for all InnerEye models for training in PyTorch Lightning. The base class handles all shared @@ -220,9 +367,6 @@ def __init__(self, config: DeepLearningConfig, *args: Any, **kwargs: Any) -> Non self.l_rate_scheduler: Optional[_LRScheduler] = None self.cross_validation_split_index = config.cross_validation_split_index self.effective_random_seed = config.get_effective_random_seed() - # Timers for monitoring data loading time - self.train_timers = EpochTimers() - self.val_timers = EpochTimers() # This should be re-assigned on the outside, to a logger that is hooked up with the Trainer object. self.storing_logger = StoringLogger() # This will be initialized correctly in epoch_start @@ -260,9 +404,6 @@ def use_sync_dist(self) -> bool: assert isinstance(self.trainer, Trainer) return self.trainer.accelerator_connector.use_ddp - def on_train_epoch_start(self) -> None: - self.train_timers.reset() - def training_epoch_end(self, outputs: List[Any]) -> None: # Write out all the metrics that have been accumulated in the StoringLogger in the previous epoch. # Metrics for the very last epoch are written in on_train_end @@ -275,10 +416,6 @@ def on_validation_epoch_start(self) -> None: that any randomization when loading validation data is consistent during training. In particular, this ensures that drawing random patches for segmentation model training is giving a validation set that does not fluctuate. """ - self.val_timers.reset() - # In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training - # is done for this epoch, even though the on_training_epoch hook has not yet been called. - self.train_timers.epoch_end() # Store the random number generator state, so that the next training epoch starts from here. self.random_state = RandomStateSnapshot.snapshot_random_state() # reset the random state for validation, so that we get consistent behaviour when drawing random patches @@ -286,9 +423,6 @@ def on_validation_epoch_start(self) -> None: seed = self.effective_random_seed set_random_seed(seed, "Validation") - def on_validation_epoch_end(self) -> None: - self.val_timers.epoch_end() - def validation_epoch_end(self, outputs: List[Any]) -> None: """ Resets the random number generator state to what it was before the current validation epoch started. @@ -319,45 +453,6 @@ def read_epoch_results_from_logger_and_store(self, epoch: int) -> None: metrics = self.storing_logger.extract_by_prefix(epoch, prefix) self.store_epoch_results(metrics, epoch, is_training) - @rank_zero_only - def training_or_validation_epoch_end(self, is_training: bool) -> None: - """ - This is a hook called at the end of a training or validation epoch. In here, we can still write - metrics to a logger. - :param is_training: If True, this is called at the end of a training epoch. If False, this is at the - end of a validation epoch. - """ - if not is_training: - # In validation epochs, mark that it has been completed. Training epochs are marked completed already - # at the start of the validation epoch. - self.val_timers.epoch_end() - # Write all IO stats here, so that the order on the console is Train start, train end, val start, val end. - self.write_and_log_epoch_time(is_training=True) - self.write_and_log_epoch_time(is_training=False) - - def write_and_log_epoch_time(self, is_training: bool) -> None: - """ - Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the - time per epoch. - :param is_training: If True, show and log the data for the training epoch. If False, use the data for the - validation epoch. - """ - timers = self.get_timers(is_training=is_training) - epoch_time_seconds = timers.total_epoch_time - status = "training" if is_training else "validation" - logging.info(f"Epoch {self.current_epoch} {status} took {epoch_time_seconds:0.2f}sec, of which waiting for " - f"data took {timers.total_load_time:0.2f} sec total.") - if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch: - logging.warning("The dataloaders were not fast enough to always supply the next batch in less than " - f"{MAX_ITEM_LOAD_TIME_SEC}sec.") - logging.warning( - f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load " - f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.") - # This metric is only written at rank zero, and hence must no be synchronized across workers. If attempted, - # training will get stuck. - self.log_on_epoch(MetricType.SECONDS_PER_EPOCH, epoch_time_seconds, is_training=is_training, - sync_dist_override=False) - def log_on_epoch(self, name: Union[MetricType, str], value: Any, @@ -414,18 +509,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.checkpoint_loading_message = f"Loading checkpoint that was created at ({', '.join(present_keys)})" logging.info(self.checkpoint_loading_message) - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - self.batch_start(batch_idx=batch_idx, is_training=True) - - def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - self.batch_start(batch_idx=batch_idx, is_training=False) - - def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - self.batch_end(is_training=True) - - def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - self.batch_end(is_training=False) - def training_step(self, # type: ignore sample: Dict[str, Any], batch_index: int) -> Any: @@ -450,45 +533,6 @@ def training_or_validation_step(self, """ raise NotImplementedError("This method must be overwritten in a derived class.") - @rank_zero_only - def batch_start(self, batch_idx: int, is_training: bool) -> None: - """ - Shared code to keep track of IO-related metrics when loading a minibatch. This is only done on rank zero. - :param batch_idx: The index of the current minibatch. - :param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from - `on_validation_batch_start`. - :return: - """ - timers = self.get_timers(is_training=is_training) - message_prefix = f"Epoch {self.current_epoch} {'training' if is_training else 'validation'}" - timers.batch_start(batch_index=batch_idx, epoch=self.current_epoch, message_prefix=message_prefix) - - @rank_zero_only - def batch_end(self, is_training: bool) -> None: - """ - Shared code to keep track of IO-related metrics when loading a minibatch. - :param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from - `on_validation_batch_end`. - """ - timers = self.get_timers(is_training=is_training) - batch_time = timers.batch_end() - # This metric is only written at rank 0, and hence must not be synchronized. Trying to synchronize will - # block training. - self.log_on_epoch(MetricType.SECONDS_PER_BATCH, batch_time, is_training=is_training, sync_dist_override=False) - - def get_timers(self, is_training: bool) -> EpochTimers: - """ - Gets the object that holds all IO-related metrics and timers, for either the validation or the training epoch. - """ - return self.train_timers if is_training else self.val_timers - - def reset_timers(self) -> None: - """ - Resets all timers and counters for IO-related metrics, for both the validation and the training epoch. - """ - self.train_timers.reset() - self.val_timers.reset() - def write_loss(self, is_training: bool, loss: torch.Tensor) -> None: """ Writes the given loss value to Lightning, labelled either "val/loss" or "train/loss". diff --git a/InnerEye/ML/lightning_models.py b/InnerEye/ML/lightning_models.py index f8ed4fd97..5419c1efe 100644 --- a/InnerEye/ML/lightning_models.py +++ b/InnerEye/ML/lightning_models.py @@ -165,7 +165,6 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None: for name, value in voxel_count.compute_all(): self.log(name, value) voxel_count.reset() - super().training_or_validation_epoch_end(is_training=is_training) def get_subject_output_file_per_rank(rank: int) -> str: @@ -292,7 +291,6 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None: metric.reset() logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore logger.flush() - super().training_or_validation_epoch_end(is_training) def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: # type: ignore """ diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 287f4a0c1..0b64dcc62 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Optional, Tuple, TypeVar from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.plugins import DDPPlugin @@ -19,7 +19,7 @@ from InnerEye.Common.resource_monitor import ResourceMonitor from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, create_best_checkpoint from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER -from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning +from InnerEye.ML.lightning_base import BatchTimeCallback, InnerEyeContainer, InnerEyeLightning from InnerEye.ML.lightning_container import LightningContainer from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \ @@ -113,18 +113,7 @@ def create_lightning_trainer(container: LightningContainer, :param kwargs: Any additional keyowrd arguments will be passed to the constructor of Trainer. :return: A tuple [Trainer object, diagnostic logger] """ - # For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation - # models, this still appears to be the best way of choosing them because validation loss on the relatively small - # training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but - # not for the HeadAndNeck model. - last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0) - - # Recovery checkpoints: {epoch} will turn into a string like "epoch=1" - # Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last - # recovery_checkpoints_save_last_k. - recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container) - - num_gpus = container.num_gpus_per_node + num_gpus = container.num_gpus_per_node() effective_num_gpus = num_gpus * num_nodes # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory). if effective_num_gpus > 1: @@ -157,9 +146,26 @@ def create_lightning_trainer(container: LightningContainer, else: deterministic = False benchmark = True - # If the users provides additional callbacks via get_trainer_arguments (for custom - # containers - callbacks = [last_checkpoint_callback, recovery_checkpoint_callback] + + # For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation + # models, this still appears to be the best way of choosing them because validation loss on the relatively small + # training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but + # not for the HeadAndNeck model. + last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0) + # Recovery checkpoints: {epoch} will turn into a string like "epoch=1" + # Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last + # recovery_checkpoints_save_last_k. + recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container) + callbacks = [ + last_checkpoint_callback, + recovery_checkpoint_callback, + BatchTimeCallback() + ] + # TODO: Add a flag for that. + if num_gpus > 0: + logging.info("Adding monitoring for GPU utilization") + callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True)) + # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers if "callbacks" in kwargs: callbacks.append(kwargs.pop("callbacks")) # type: ignore is_azureml_run = not is_offline_run_context(RUN_CONTEXT) @@ -186,6 +192,7 @@ def create_lightning_trainer(container: LightningContainer, precision=precision, sync_batchnorm=True, terminate_on_nan=container.detect_anomaly, + profiler="simple", resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None, **kwargs) return trainer, storing_logger From 7b5997a01aef5461ef7f97ed33e0dc83328d948e Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 13 Oct 2021 23:42:16 +0100 Subject: [PATCH 03/41] fixing timing callback --- InnerEye/Common/metrics_constants.py | 1 + InnerEye/ML/lightning_base.py | 96 +++++++++++++++++----------- InnerEye/ML/metrics.py | 58 +++++++++++------ Tests/ML/test_model_training.py | 11 +++- 4 files changed, 106 insertions(+), 60 deletions(-) diff --git a/InnerEye/Common/metrics_constants.py b/InnerEye/Common/metrics_constants.py index db7b16502..4961e98fe 100644 --- a/InnerEye/Common/metrics_constants.py +++ b/InnerEye/Common/metrics_constants.py @@ -102,6 +102,7 @@ class MetricType(Enum): # Common metrics SECONDS_PER_BATCH = "SecondsPerBatch" SECONDS_PER_EPOCH = "SecondsPerEpoch" + EXCESS_BATCH_LOADING_TIME = "TotalExcessLoadingTimeSeconds" SUBJECT_COUNT = "SubjectCount" LEARNING_RATE = "LearningRate" diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 33c7eee88..0781a1416 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -25,7 +25,7 @@ WorkflowParams from InnerEye.ML.lightning_container import LightningContainer from InnerEye.ML.lightning_loggers import StoringLogger -from InnerEye.ML.metrics import EpochTimers, MAX_ITEM_LOAD_TIME_SEC, store_epoch_metrics +from InnerEye.ML.metrics import EpochTimers, store_epoch_metrics from InnerEye.ML.metrics_dict import DataframeLogger from InnerEye.ML.model_config_base import ModelConfigBase from InnerEye.ML.utils import model_util @@ -203,7 +203,9 @@ def load_checkpoint_and_modify(self, path_to_checkpoint: Path) -> Dict[str, Any] class BatchTimeCallback(Callback): """ - This class provides tools to measure batch loading time and other diagnostic information. + This callback provides tools to measure batch loading time and other diagnostic information. + It prints alerts if the batch loading time is over a threshold for several epochs. + All logging will only happen on rank 0. """ def __init__(self) -> None: @@ -212,18 +214,16 @@ def __init__(self) -> None: self.val_timers = EpochTimers() def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + This is called at the start of training. It stores the model that is being trained, because it will be used + later to log values. + """ self.module = pl_module - self.logger = trainer.logger def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self.train_timers.reset() def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """ - Stores the state of all random number generators, and resets them all to a fixed seed. This is done to ensure - that any randomization when loading validation data is consistent during training. In particular, this ensures - that drawing random patches for segmentation model training is giving a validation set that does not fluctuate. - """ self.val_timers.reset() # In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training # is done for this epoch, even though the on_training_epoch hook has not yet been called. @@ -279,6 +279,35 @@ def on_validation_batch_end(self, ) -> None: self.batch_end(is_training=False) + def batch_start(self, batch_idx: int, is_training: bool) -> None: + """ + Shared code to keep track of minibatch loading times. This is only done on rank zero. + :param batch_idx: The index of the current minibatch. + :param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from + `on_validation_batch_start`. + """ + timers = self.get_timers(is_training=is_training) + epoch = self.module.current_epoch + message_prefix = f"Epoch {epoch} {'training' if is_training else 'validation'}" + timers.batch_start(batch_index=batch_idx, epoch=epoch, message_prefix=message_prefix) + + def batch_end(self, is_training: bool) -> None: + """ + Shared code to keep track of IO-related metrics when loading a minibatch. + :param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from + `on_validation_batch_end`. + """ + timers = self.get_timers(is_training=is_training) + batch_time = timers.batch_end() + self.log_metric(MetricType.SECONDS_PER_BATCH.value, + value=batch_time, + is_training=is_training) + self.log_metric(MetricType.SECONDS_PER_BATCH.value + " max", + value=batch_time, + is_training=is_training, + reduce_max=True) + + @rank_zero_only def write_and_log_epoch_time(self, is_training: bool) -> None: """ Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the @@ -293,45 +322,36 @@ def write_and_log_epoch_time(self, is_training: bool) -> None: f"for data took {timers.total_load_time:0.2f} sec total.") if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch: logging.warning("The dataloaders were not fast enough to always supply the next batch in less than " - f"{MAX_ITEM_LOAD_TIME_SEC}sec.") + f"{timers.max_item_load_time_seconds}sec.") logging.warning( f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load " f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.") # This metric is only written at rank zero, and hence must not be synchronized across workers. If attempted, # training will get stuck. - self.log_metric(MetricType.SECONDS_PER_EPOCH, value=epoch_time_seconds, is_training=is_training) - - def log_metric(self, metric_type: MetricType, value: float, is_training: bool) -> None: - # Metrics are only written at rank 0, and hence must not be synchronized. Trying to synchronize will - # block training. - prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX - self.module.log(name=prefix + metric_type.value, value=value, - on_step=False, on_epoch=True, sync_dist=False) - - @rank_zero_only - def batch_start(self, batch_idx: int, is_training: bool) -> None: - """ - Shared code to keep track of IO-related metrics when loading a minibatch. This is only done on rank zero. - :param batch_idx: The index of the current minibatch. - :param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from - `on_validation_batch_start`. - :return: - """ - timers = self.get_timers(is_training=is_training) - epoch = self.module.current_epoch - message_prefix = f"Epoch {epoch} {'training' if is_training else 'validation'}" - timers.batch_start(batch_index=batch_idx, epoch=epoch, message_prefix=message_prefix) + self.log_metric(MetricType.SECONDS_PER_EPOCH.value, + value=epoch_time_seconds, + is_training=is_training) + self.log_metric(MetricType.EXCESS_BATCH_LOADING_TIME.value, + value=timers.total_extra_load_time, + is_training=is_training) @rank_zero_only - def batch_end(self, is_training: bool) -> None: + def log_metric(self, name_suffix: str, value: float, is_training: bool, reduce_max: bool = False) -> None: """ - Shared code to keep track of IO-related metrics when loading a minibatch. - :param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from - `on_validation_batch_end`. + Write a metric given as a name/value pair to the currently trained module. The full name of the metric is + composed of a fixed prefix "timing/", followed by either "train/" or "val/", and then the given suffix. + :param name_suffix: The suffix for the logged metric name. + :param value: The value to log. + :param is_training: If True, use "train/" in the metric name, otherwise "val/" + :param reduce_max: If True, use torch.max as the aggregation function for the logged values. If False, use + torch.mean """ - timers = self.get_timers(is_training=is_training) - batch_time = timers.batch_end() - self.log_metric(MetricType.SECONDS_PER_BATCH, value=batch_time, is_training=is_training) + # Metrics are only written at rank 0, and hence must not be synchronized. Trying to synchronize will + # block training. + prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX + self.module.log(name="timing/" + prefix + name_suffix, value=value, + on_step=False, on_epoch=True, sync_dist=False, + reduce_fx=max if reduce_max else torch.mean) def get_timers(self, is_training: bool) -> EpochTimers: """ diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py index a87dde368..aa215a08e 100644 --- a/InnerEye/ML/metrics.py +++ b/InnerEye/ML/metrics.py @@ -32,10 +32,6 @@ from InnerEye.ML.utils.ml_util import check_size_matches from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels -MAX_ITEM_LOAD_TIME_SEC = 0.5 -MAX_LOAD_TIME_WARNINGS = 3 -MAX_LOAD_TIME_EPOCHS = 5 - @dataclass(frozen=True) class InferenceMetrics: @@ -81,20 +77,35 @@ def log_metrics(self, run_context: Run = None) -> None: }) -@dataclass class EpochTimers: """ Contains all information necessary to compute the IO metrics: Epoch times, batch times, loading times. """ - epoch_start_time: float = time.time() - epoch_end_time: float = time.time() - batch_start_time: float = time.time() - num_load_time_warnings: int = 0 - num_load_time_exceeded: int = 0 - total_extra_load_time: float = 0.0 - total_load_time: float = 0.0 - num_batches: int = 0 - load_time_warning_epochs: Set[int] = field(default_factory=set) + + def __init__(self, + max_item_load_time_seconds: float = 0.5, + max_load_time_warnings: int = 3, + max_load_time_epochs: int = 5 + ) -> None: + """ + Creates a new instance of the class. + :param max_item_load_time_seconds: The maximum expected loading time for a minibatch (given in seconds). + If the loading time exceeds this threshold, a warning is printed. + :param max_load_time_warnings: The maximum number of warnings that will be printed per epoch. + :param max_load_time_epochs: The maximum number of epochs where warnings about the loading time are printed. + """ + self.max_item_load_time_seconds = max_item_load_time_seconds + self.max_load_time_warnings = max_load_time_warnings + self.max_load_time_epochs = max_load_time_epochs + self.epoch_start_time: float = time.time() + self.epoch_end_time: float = time.time() + self.batch_start_time: float = time.time() + self.num_load_time_warnings: int = 0 + self.num_load_time_exceeded: int = 0 + self.total_extra_load_time: float = 0.0 + self.total_load_time: float = 0.0 + self.num_batches: int = 0 + self.load_time_warning_epochs: Set[int] = set() def reset(self) -> None: """ @@ -128,15 +139,20 @@ def total_epoch_time(self) -> float: def should_warn_in_this_epoch(self) -> bool: """ Returns True if warnings about loading time should be printed in the present epoch. Returns False if - this warning has been printed already in more than MAX_LOAD_TIME_EPOCHS epochs. + this warning has been printed already in more than self.max_load_time_epochs epochs. :return: """ - return len(self.load_time_warning_epochs) <= MAX_LOAD_TIME_EPOCHS + return len(self.load_time_warning_epochs) <= self.max_load_time_epochs def batch_start(self, batch_index: int, epoch: int, message_prefix: str) -> float: """ - Called when a minibatch of data has been loaded. This computes the time it took to load the minibatch, - and adds it to the internal bookkeeping. + Called when a minibatch of data has been loaded. This computes the time it took to load the minibatch + (computed between now and the end of the previous minibatch) + and adds it to the internal bookkeeping. If the minibatch loading time exceeds a threshold, then warnings + are printed (unless too many warnings have been printed already) + :param message_prefix: A prefix string that is added to all diagnostic output. + :param epoch: The index of the current epoch. + :param batch_index: The index of the current minibatch. :return: The time it took to load the minibatch, in seconds. """ item_finish_time = time.time() @@ -146,15 +162,15 @@ def batch_start(self, batch_index: int, epoch: int, message_prefix: str) -> floa # are spawned. Later, the load time should be zero. if batch_index == 0: logging.info(f"{message_prefix}: Loaded the first minibatch of data in {item_load_time:0.2f} sec.") - elif item_load_time > MAX_ITEM_LOAD_TIME_SEC: + elif item_load_time > self.max_item_load_time_seconds: self.load_time_warning_epochs.add(epoch) self.num_load_time_exceeded += 1 self.total_extra_load_time += item_load_time - if self.num_load_time_warnings < MAX_LOAD_TIME_WARNINGS and self.should_warn_in_this_epoch: + if self.num_load_time_warnings < self.max_load_time_warnings and self.should_warn_in_this_epoch: logging.warning(f"{message_prefix}: Loading minibatch {batch_index} took {item_load_time:0.2f} sec. " "This can mean that there are not enough data loader worker processes, or that there " "is a performance problem in loading. This warning will be printed at most " - f"{MAX_LOAD_TIME_WARNINGS} times in at most {MAX_LOAD_TIME_EPOCHS} epochs.") + f"{self.max_load_time_warnings} times in at most {self.max_load_time_epochs} epochs.") self.num_load_time_warnings += 1 return item_load_time diff --git a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index 413de42e0..16c945f95 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -17,7 +17,7 @@ from InnerEye.Common import fixed_paths from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, is_windows, logging_to_stdout from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path -from InnerEye.Common.metrics_constants import MetricType, TrackedMetrics, VALIDATION_PREFIX +from InnerEye.Common.metrics_constants import MetricType, TRAIN_PREFIX, TrackedMetrics, VALIDATION_PREFIX from InnerEye.Common.output_directories import OutputFolderForTests from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_SUFFIX, DATASET_CSV_FILE_NAME, \ ModelExecutionMode, \ @@ -114,6 +114,15 @@ def _mean_list(lists: List[List[float]]) -> List[float]: model_training_result, _ = model_train_unittest(train_config, dirs=output_dirs) assert isinstance(model_training_result, StoringLogger) + for epoch, epoch_results in model_training_result.results.items(): + for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]: + for metric_type in [MetricType.SECONDS_PER_EPOCH.value, + MetricType.SECONDS_PER_BATCH.value, + MetricType.EXCESS_BATCH_LOADING_TIME.value, + MetricType.SECONDS_PER_BATCH.value + " max"]: + expected = "timing/" + prefix + metric_type + assert expected in epoch_results, f"Expected {expected} in results for epoch {epoch}" + assert epoch_results[expected] > 0.0, "Time should be > 0" actual_train_losses = model_training_result.get_train_metric(MetricType.LOSS.value) actual_val_losses = model_training_result.get_val_metric(MetricType.LOSS.value) From 9d55e20621a1382bf6fd2a764e60bd89dd90eed4 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 13 Oct 2021 23:46:03 +0100 Subject: [PATCH 04/41] docu --- InnerEye/ML/lightning_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 0781a1416..2b41fb6da 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -206,6 +206,11 @@ class BatchTimeCallback(Callback): This callback provides tools to measure batch loading time and other diagnostic information. It prints alerts if the batch loading time is over a threshold for several epochs. All logging will only happen on rank 0. + + The loading time for a minibatch is estimated by the difference between the start time of a minibatch and the + end time of the previous minibatch. It will consequently also include other operations that happen between the + end of a batch and the start of the next one. For example, computationally expensive callbacks could also + drive up this time. """ def __init__(self) -> None: From 4254ff2f459c95b4e97f9c2baaa7b4a3e43c82ca Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 13 Oct 2021 23:47:19 +0100 Subject: [PATCH 05/41] docu --- InnerEye/ML/lightning_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 2b41fb6da..8b49059ca 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -360,13 +360,13 @@ def log_metric(self, name_suffix: str, value: float, is_training: bool, reduce_m def get_timers(self, is_training: bool) -> EpochTimers: """ - Gets the object that holds all IO-related metrics and timers, for either the validation or the training epoch. + Gets the object that holds all metrics and timers, for either the validation or the training epoch. """ return self.train_timers if is_training else self.val_timers def reset_timers(self) -> None: """ - Resets all timers and counters for IO-related metrics, for both the validation and the training epoch. + Resets all timers and counters, for both the validation and the training epoch. """ self.train_timers.reset() self.val_timers.reset() From 0a1fc2600bda9ed15d34351712528b43ced008c8 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 14 Oct 2021 15:51:49 +0100 Subject: [PATCH 06/41] progress bar --- InnerEye/ML/lightning_base.py | 2 +- InnerEye/ML/lightning_loggers.py | 112 ++++++++++++++++++++++++++++++- InnerEye/ML/model_training.py | 20 ++++-- 3 files changed, 125 insertions(+), 9 deletions(-) diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 8b49059ca..9329abaad 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -304,7 +304,7 @@ def batch_end(self, is_training: bool) -> None: """ timers = self.get_timers(is_training=is_training) batch_time = timers.batch_end() - self.log_metric(MetricType.SECONDS_PER_BATCH.value, + self.log_metric(MetricType.SECONDS_PER_BATCH.value + " avg", value=batch_time, is_training=is_training) self.log_metric(MetricType.SECONDS_PER_BATCH.value + " max", diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index 073bd3065..1ddbd1af2 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -2,8 +2,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import logging +import math +import sys +import time from typing import Any, Dict, Iterable, List, Optional +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ProgressBarBase from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.utilities import rank_zero_only @@ -152,9 +158,12 @@ def __init__(self) -> None: @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + is_epoch_metric = "epoch" in metrics if self.is_azureml_run: for key, value in metrics.items(): - RUN_CONTEXT.log(key, value) + # Log all epoch-level metrics without the step information + # All step-level metrics with step + RUN_CONTEXT.log(key, value, step=None if is_epoch_metric else step) @rank_zero_only def log_hyperparams(self, params: Any) -> None: @@ -168,3 +177,104 @@ def name(self) -> Any: def version(self) -> int: return 0 + + +class AzureMLProgressBar(ProgressBarBase): + """ + A PL progress bar that works better in AzureML. It prints timestamps for each message, and works well with a setup + where there is no direct access to the console. + """ + def __init__(self, refresh_rate: Optional[int] = 1): + super().__init__() + self._refresh_rate = refresh_rate or 1 + self._enabled = True + self.stage = "" + self.stage_start_time = 0.0 + self.max_batch_count = 0 + + @property + def refresh_rate(self) -> int: + return self._refresh_rate + + @property + def is_enabled(self) -> bool: + return self._enabled and self.refresh_rate > 0 + + @property + def is_disabled(self) -> bool: + return not self.is_enabled + + def disable(self) -> None: + self._enabled = False + + def enable(self) -> None: + self._enabled = True + + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self.module = pl_module + + def on_train_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) + self.start_stage("Training", self.total_train_batches) + + def on_validation_epoch_start(self, trainer, pl_module): + super().on_validation_epoch_start(trainer, pl_module) + self.start_stage("Validation", self.total_val_batches) + + def on_test_epoch_start(self, trainer, pl_module): + super().on_test_epoch_start(trainer, pl_module) + self.start_stage("Testing", self.total_test_batches) + + def on_predict_epoch_start(self, trainer, pl_module): + super().on_predict_epoch_start(trainer, pl_module) + self.start_stage("Prediction", self.total_predict_batches) + + def start_stage(self, stage: str, max_batch_count: int): + self.stage = stage + self.max_batch_count = max_batch_count + self.stage_start_time = time.time() + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + self.update_progress(batches_processed = self.train_batch_idx) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + self.update_progress(batches_processed = self.val_batch_idx) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + self.update_progress(batches_processed = self.test_batch_idx) + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + self.update_progress(batches_processed = self.predict_batch_idx) + + def update_progress(self, batches_processed: int): + should_update = self.is_enabled and \ + (batches_processed % self.refresh_rate == 0 or batches_processed == self.max_batch_count) + if not should_update: + return + prefix = f"{self.stage}" + if self.stage in ["Training", "Validation"]: + prefix += f" epoch {self.module.current_epoch}" + if self.stage == "Training": + prefix += f" (step {self.module.global_step})" + prefix += ": " + if math.isinf(self.max_batch_count): + # Can't print out per-cent progress or time estimates if the data is infinite + logging.info(f"{prefix}{batches_processed} batches completed") + else: + fraction_completed = batches_processed / self.max_batch_count + percent_completed = int(fraction_completed * 100) + time_elapsed = time.time() - self.stage_start_time + estimated_epoch_time = time_elapsed / fraction_completed + + def to_minutes(time_sec: float) -> str: + minutes = int(time_sec / 60) + seconds = int(time_sec % 60) + return f"{minutes:02}:{seconds:02}" + + logging.info(f"{prefix}{batches_processed:4}/{self.max_batch_count} ({percent_completed:3}%) completed. " + f"{to_minutes(time_elapsed)} done, epoch expected to take {to_minutes(estimated_epoch_time)}") + sys.stdout.flush() diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 0b64dcc62..9f4f44db3 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -21,7 +21,7 @@ from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER from InnerEye.ML.lightning_base import BatchTimeCallback, InnerEyeContainer, InnerEyeLightning from InnerEye.ML.lightning_container import LightningContainer -from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger +from InnerEye.ML.lightning_loggers import AzureMLLogger, AzureMLProgressBar, StoringLogger from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \ get_subject_output_file_per_rank @@ -167,14 +167,20 @@ def create_lightning_trainer(container: LightningContainer, callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True)) # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers if "callbacks" in kwargs: - callbacks.append(kwargs.pop("callbacks")) # type: ignore + more_callbacks = kwargs.pop("callbacks") + if isinstance(more_callbacks, list): + callbacks.extend(more_callbacks) # type: ignore + else: + callbacks.append(more_callbacks) # type: ignore is_azureml_run = not is_offline_run_context(RUN_CONTEXT) progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate - if progress_bar_refresh_rate is None and is_azureml_run: - # When running in AzureML, the default progress bar clutters the output files with thousands of lines. - progress_bar_refresh_rate = 50 - logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. " - f"To change, modify the pl_progress_bar_refresh_rate field of the container.") + if is_azureml_run: + if progress_bar_refresh_rate is None: + # When running in AzureML, the default progress bar clutters the output files with thousands of lines. + progress_bar_refresh_rate = 50 + logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. " + f"To change, modify the pl_progress_bar_refresh_rate field of the container.") + callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate)) # Read out additional model-specific args here. # We probably want to keep essential ones like numgpu and logging. trainer = Trainer(default_root_dir=str(container.outputs_folder), From f190669bb7090615bc34f8ed7e295a72e0161ecc Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 14 Oct 2021 20:47:22 +0100 Subject: [PATCH 07/41] docu and cleanup --- InnerEye/ML/lightning_loggers.py | 89 ++++++++++++++++++++++---------- InnerEye/ML/model_training.py | 3 +- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index 1ddbd1af2..ad471a289 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -179,18 +179,36 @@ def version(self) -> int: return 0 +PROGRESS_STAGE_TRAIN = "Training" +PROGRESS_STAGE_VAL = "Validation" +PROGRESS_STAGE_TEST = "Testing" +PROGRESS_STAGE_PREDICT = "Prediction" + + class AzureMLProgressBar(ProgressBarBase): """ A PL progress bar that works better in AzureML. It prints timestamps for each message, and works well with a setup where there is no direct access to the console. """ - def __init__(self, refresh_rate: Optional[int] = 1): + + def __init__(self, + refresh_rate: int = 50, + write_to_logging_info: bool = False + ): + """ + Creates a new AzureML progress bar. + :param refresh_rate: The number of steps after which the progress should be printed out. + :param write_to_logging_info: If True, the progress information will be printed via logging.info. If False, + it will be printed to stdout via print. + """ super().__init__() - self._refresh_rate = refresh_rate or 1 + self._refresh_rate = refresh_rate self._enabled = True self.stage = "" self.stage_start_time = 0.0 self.max_batch_count = 0 + self.progress_print_fn = logging.info if write_to_logging_info else print + self.flush_fn = None if write_to_logging_info else sys.stdout.flush @property def refresh_rate(self) -> int: @@ -213,68 +231,85 @@ def enable(self) -> None: def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self.module = pl_module - def on_train_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_epoch_start(trainer, pl_module) - self.start_stage("Training", self.total_train_batches) + self.start_stage(PROGRESS_STAGE_TRAIN, self.total_train_batches) - def on_validation_epoch_start(self, trainer, pl_module): + def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_validation_epoch_start(trainer, pl_module) - self.start_stage("Validation", self.total_val_batches) + self.start_stage(PROGRESS_STAGE_VAL, self.total_val_batches) - def on_test_epoch_start(self, trainer, pl_module): + def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_test_epoch_start(trainer, pl_module) - self.start_stage("Testing", self.total_test_batches) + self.start_stage(PROGRESS_STAGE_TEST, self.total_test_batches) - def on_predict_epoch_start(self, trainer, pl_module): + def on_predict_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_predict_epoch_start(trainer, pl_module) - self.start_stage("Prediction", self.total_predict_batches) + self.start_stage(PROGRESS_STAGE_PREDICT, self.total_predict_batches) - def start_stage(self, stage: str, max_batch_count: int): + def start_stage(self, stage: str, max_batch_count: int) -> None: + """ + Sets the information that a new stage of the PL loop is starting. The stage will be available in + self.stage, max_batch_count in self.max_batch_count. The time when this method was called is recorded in + self.stage_start_time + :param stage: The string name of the stage that has just started. + :param max_batch_count: The total number of batches that need to be processed in this stage. + """ self.stage = stage self.max_batch_count = max_batch_count self.stage_start_time = time.time() - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, + batch_idx: int, dataloader_idx: int) -> None: super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed = self.train_batch_idx) + self.update_progress(batches_processed=self.train_batch_idx) - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, + batch_idx: int, dataloader_idx: int) -> None: super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed = self.val_batch_idx) + self.update_progress(batches_processed=self.val_batch_idx) - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, + batch_idx: int, dataloader_idx: int) -> None: super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed = self.test_batch_idx) + self.update_progress(batches_processed=self.test_batch_idx) - def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_predict_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, + batch_idx: int, dataloader_idx: int) -> None: super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed = self.predict_batch_idx) + self.update_progress(batches_processed=self.predict_batch_idx) def update_progress(self, batches_processed: int): + """ + Writes progress information once the refresh interval is full. + :param batches_processed: The number of batches that have been processed for the current stage. + """ should_update = self.is_enabled and \ (batches_processed % self.refresh_rate == 0 or batches_processed == self.max_batch_count) if not should_update: return prefix = f"{self.stage}" - if self.stage in ["Training", "Validation"]: - prefix += f" epoch {self.module.current_epoch}" - if self.stage == "Training": + if self.stage in [PROGRESS_STAGE_TRAIN, PROGRESS_STAGE_VAL]: + prefix += f" epoch {self.module.current_epoch}" + if self.stage == PROGRESS_STAGE_TRAIN: prefix += f" (step {self.module.global_step})" prefix += ": " if math.isinf(self.max_batch_count): # Can't print out per-cent progress or time estimates if the data is infinite - logging.info(f"{prefix}{batches_processed} batches completed") + message = f"{prefix}{batches_processed} batches completed" else: fraction_completed = batches_processed / self.max_batch_count percent_completed = int(fraction_completed * 100) time_elapsed = time.time() - self.stage_start_time - estimated_epoch_time = time_elapsed / fraction_completed + estimated_epoch_duration = time_elapsed / fraction_completed def to_minutes(time_sec: float) -> str: minutes = int(time_sec / 60) seconds = int(time_sec % 60) return f"{minutes:02}:{seconds:02}" - logging.info(f"{prefix}{batches_processed:4}/{self.max_batch_count} ({percent_completed:3}%) completed. " - f"{to_minutes(time_elapsed)} done, epoch expected to take {to_minutes(estimated_epoch_time)}") - sys.stdout.flush() + message = (f"{prefix}{batches_processed:4}/{self.max_batch_count} ({percent_completed:3}%) completed. " + f"{to_minutes(time_elapsed)} elapsed, total epoch time ~ {to_minutes(estimated_epoch_duration)}") + self.progress_print_fn(message) + if self.flush_fn: + self.flush_fn() diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 9f4f44db3..157fb29cc 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -176,11 +176,10 @@ def create_lightning_trainer(container: LightningContainer, progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate if is_azureml_run: if progress_bar_refresh_rate is None: - # When running in AzureML, the default progress bar clutters the output files with thousands of lines. progress_bar_refresh_rate = 50 logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. " f"To change, modify the pl_progress_bar_refresh_rate field of the container.") - callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate)) + callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate, write_to_logging_info=True)) # Read out additional model-specific args here. # We probably want to keep essential ones like numgpu and logging. trainer = Trainer(default_root_dir=str(container.outputs_folder), From f90912d046ad499cfb2f436c9be8784cd9cfdf0a Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Fri, 15 Oct 2021 15:13:58 +0100 Subject: [PATCH 08/41] tests --- InnerEye/ML/lightning_loggers.py | 13 ++--- Tests/ML/test_loggers.py | 92 ++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 8 deletions(-) create mode 100644 Tests/ML/test_loggers.py diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index ad471a289..1f33ad618 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -228,15 +228,12 @@ def disable(self) -> None: def enable(self) -> None: self._enabled = True - def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self.module = pl_module - def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_epoch_start(trainer, pl_module) self.start_stage(PROGRESS_STAGE_TRAIN, self.total_train_batches) - def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_validation_epoch_start(trainer, pl_module) + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_validation_start(trainer, pl_module) self.start_stage(PROGRESS_STAGE_VAL, self.total_val_batches) def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -290,13 +287,13 @@ def update_progress(self, batches_processed: int): return prefix = f"{self.stage}" if self.stage in [PROGRESS_STAGE_TRAIN, PROGRESS_STAGE_VAL]: - prefix += f" epoch {self.module.current_epoch}" + prefix += f" epoch {self.trainer.current_epoch}" if self.stage == PROGRESS_STAGE_TRAIN: - prefix += f" (step {self.module.global_step})" + prefix += f" (step {self.trainer.lightning_module.global_step})" prefix += ": " if math.isinf(self.max_batch_count): # Can't print out per-cent progress or time estimates if the data is infinite - message = f"{prefix}{batches_processed} batches completed" + message = f"{prefix}{batches_processed:4} batches completed" else: fraction_completed = batches_processed / self.max_batch_count percent_completed = int(fraction_completed * 100) diff --git a/Tests/ML/test_loggers.py b/Tests/ML/test_loggers.py new file mode 100644 index 000000000..0cd68a492 --- /dev/null +++ b/Tests/ML/test_loggers.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +import math +from typing import List +from unittest import mock + +from InnerEye.ML.lightning_loggers import AzureMLProgressBar, PROGRESS_STAGE_PREDICT, PROGRESS_STAGE_TEST, \ + PROGRESS_STAGE_TRAIN, \ + PROGRESS_STAGE_VAL + + +def test_progress_bar_enable() -> None: + """ + Test the logic for disabling the progress bar. + """ + bar = AzureMLProgressBar(refresh_rate=0) + assert not bar.is_enabled + bar = AzureMLProgressBar(refresh_rate=1) + assert bar.is_enabled + bar.disable() + assert not bar.is_enabled + bar.enable() + assert bar.is_enabled + + +def test_progress_bar() -> None: + bar = AzureMLProgressBar(refresh_rate=1) + mock_trainer = mock.MagicMock(current_epoch=12, + lightning_module=mock.MagicMock(global_step=34), + num_training_batches=10, + emable_validation=False, + num_test_batches=[20], + num_predict_batches=[30]) + bar.on_init_end(mock_trainer) # type: ignore + assert bar.trainer == mock_trainer + messages: List[str] = [] + + def write_message(message: str) -> None: + messages.append(message) + + bar.progress_print_fn = write_message + bar.flush_fn = None + # Messages in training + bar.on_train_epoch_start(None, None) # type: ignore + assert bar.stage == PROGRESS_STAGE_TRAIN + assert bar.train_batch_idx == 0 + assert bar.val_batch_idx == 0 + assert bar.test_batch_idx == 0 + assert bar.predict_batch_idx == 0 + bar.on_train_batch_end(None, None, None, None, None, None) # type: ignore + assert bar.train_batch_idx == 1 + assert "Training epoch 12 (step 34)" in messages[-1] + assert "1/10 ( 10%) completed" in messages[-1] + # When starting the next training epoch, the counters should be reset + bar.on_train_epoch_start(None, None) # type: ignore + assert bar.train_batch_idx == 0 + # Messages in validation + bar.on_validation_start(None, None) # type: ignore + assert bar.stage == PROGRESS_STAGE_VAL + assert bar.max_batch_count == 0 + assert bar.val_batch_idx == 0 + # Number of validation batches is difficult to fake, tweak the field where it is stored in the progress bar + bar.max_batch_count = 5 + bar.on_validation_batch_end(None, None, None, None, None, None) # type: ignore + assert bar.val_batch_idx == 1 + assert "Validation epoch 12: " in messages[-1] + assert "1/5 ( 20%) completed" in messages[-1] + # Messages in testing + bar.on_test_epoch_start(None, None) # type: ignore + assert bar.stage == PROGRESS_STAGE_TEST + bar.on_test_batch_end(None, None, None, None, None, None) # type: ignore + bar.on_test_batch_end(None, None, None, None, None, None) # type: ignore + assert bar.test_batch_idx == 2 + assert "Testing:" in messages[-1] + assert "2/20 ( 10%)" in messages[-1] + # Messages in prediction + bar.on_predict_epoch_start(None, None) # type: ignore + assert bar.stage == PROGRESS_STAGE_PREDICT + bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore + bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore + bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore + assert bar.predict_batch_idx == 3 + assert "Prediction:" in messages[-1] + assert "3/30 ( 10%)" in messages[-1] + # Test behaviour when a batch count is infinity + bar.max_batch_count = math.inf + bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore + assert bar.predict_batch_idx == 4 + assert "4 batches completed" in messages[-1] + From 9bdbd7a432974c5f1091a4e462c12329daff3223 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Mon, 18 Oct 2021 09:57:59 +0100 Subject: [PATCH 09/41] test cleanup --- Tests/ML/test_loggers.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/Tests/ML/test_loggers.py b/Tests/ML/test_loggers.py index 0cd68a492..edbdd175c 100644 --- a/Tests/ML/test_loggers.py +++ b/Tests/ML/test_loggers.py @@ -70,20 +70,21 @@ def write_message(message: str) -> None: # Messages in testing bar.on_test_epoch_start(None, None) # type: ignore assert bar.stage == PROGRESS_STAGE_TEST - bar.on_test_batch_end(None, None, None, None, None, None) # type: ignore - bar.on_test_batch_end(None, None, None, None, None, None) # type: ignore - assert bar.test_batch_idx == 2 + test_count = 2 + for _ in range(test_count): + bar.on_test_batch_end(None, None, None, None, None, None) # type: ignore + assert bar.test_batch_idx == test_count assert "Testing:" in messages[-1] - assert "2/20 ( 10%)" in messages[-1] + assert f"{test_count}/20 ( 10%)" in messages[-1] # Messages in prediction bar.on_predict_epoch_start(None, None) # type: ignore assert bar.stage == PROGRESS_STAGE_PREDICT - bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore - bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore - bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore - assert bar.predict_batch_idx == 3 + predict_count = 3 + for _ in range(predict_count): + bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore + assert bar.predict_batch_idx == predict_count assert "Prediction:" in messages[-1] - assert "3/30 ( 10%)" in messages[-1] + assert f"{predict_count}/30 ( 10%)" in messages[-1] # Test behaviour when a batch count is infinity bar.max_batch_count = math.inf bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore From d02ba0be0262a66d6d01099d14878aa4dd485c6f Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 19 Oct 2021 11:36:22 +0100 Subject: [PATCH 10/41] test for timers --- InnerEye/ML/metrics.py | 17 ++++++------ Tests/ML/test_loggers.py | 60 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 8 deletions(-) diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py index aa215a08e..bad2c397a 100644 --- a/InnerEye/ML/metrics.py +++ b/InnerEye/ML/metrics.py @@ -81,6 +81,14 @@ class EpochTimers: """ Contains all information necessary to compute the IO metrics: Epoch times, batch times, loading times. """ + epoch_start_time: float = 0.0 + epoch_end_time: float = 0.0 + batch_start_time: float = 0.0 + num_load_time_warnings: int = 0 + num_load_time_exceeded: int = 0 + total_extra_load_time: float = 0.0 + total_load_time: float = 0.0 + num_batches: int = 0 def __init__(self, max_item_load_time_seconds: float = 0.5, @@ -97,15 +105,8 @@ def __init__(self, self.max_item_load_time_seconds = max_item_load_time_seconds self.max_load_time_warnings = max_load_time_warnings self.max_load_time_epochs = max_load_time_epochs - self.epoch_start_time: float = time.time() - self.epoch_end_time: float = time.time() - self.batch_start_time: float = time.time() - self.num_load_time_warnings: int = 0 - self.num_load_time_exceeded: int = 0 - self.total_extra_load_time: float = 0.0 - self.total_load_time: float = 0.0 - self.num_batches: int = 0 self.load_time_warning_epochs: Set[int] = set() + self.reset() def reset(self) -> None: """ diff --git a/Tests/ML/test_loggers.py b/Tests/ML/test_loggers.py index edbdd175c..8aef6e1c0 100644 --- a/Tests/ML/test_loggers.py +++ b/Tests/ML/test_loggers.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import logging import math from typing import List from unittest import mock @@ -9,6 +10,8 @@ from InnerEye.ML.lightning_loggers import AzureMLProgressBar, PROGRESS_STAGE_PREDICT, PROGRESS_STAGE_TEST, \ PROGRESS_STAGE_TRAIN, \ PROGRESS_STAGE_VAL +from InnerEye.ML.metrics import EpochTimers +from _pytest.logging import LogCaptureFixture def test_progress_bar_enable() -> None: @@ -91,3 +94,60 @@ def write_message(message: str) -> None: assert bar.predict_batch_idx == 4 assert "4 batches completed" in messages[-1] + +def test_epoch_timers(caplog: LogCaptureFixture) -> None: + caplog.set_level(logging.INFO) + batch_index = 123 + epoch = 24 + timer = EpochTimers(max_item_load_time_seconds=100) + assert timer.total_load_time == 0.0 + + # First batch should always generate a message + timer.batch_start(batch_index=0, epoch=epoch, message_prefix="prefix") + assert timer.total_load_time > 0.0 + message = caplog.messages[-1] + assert "prefix: Loaded the first minibatch of data in" in message + old_num_batches = timer.num_batches + old_batch_start_time = timer.batch_start_time + timer.batch_end() + assert timer.num_batches == old_num_batches + 1 + assert timer.batch_start_time > old_batch_start_time + + # Second minibatch should only generate a message if above load time threshold. Set threshold very high + old_num_messages = len(caplog.messages) + old_total_load_time = timer.total_load_time + timer.max_item_load_time_seconds = 10.0 + assert timer.num_load_time_exceeded == 0 + timer.batch_start(batch_index=batch_index, epoch=epoch, message_prefix="prefix") + # This should be updated in any case + assert timer.total_load_time > old_total_load_time + # But this batch should not be recognized as having gone over the threshold + assert timer.num_load_time_exceeded == 0 + assert len(timer.load_time_warning_epochs) == 0 + assert len(caplog.messages) == old_num_messages + assert timer.num_load_time_warnings == 0 + + # Third minibatch considered as above threshold: set threshold to 0 for that + old_total_load_time = timer.total_load_time + timer.max_item_load_time_seconds = 0.0 + timer.batch_start(batch_index=batch_index, epoch=epoch, message_prefix="prefix") + # This should be updated in any case + assert timer.total_load_time > old_total_load_time + # Batch should not be recognized as having gone over the threshold + assert timer.num_load_time_exceeded == 1 + assert epoch in timer.load_time_warning_epochs + message = caplog.messages[-1] + assert f"prefix: Loading minibatch { batch_index} took" in message + assert f"This message will be printed at most {timer.max_load_time_warnings} times" + assert timer.num_load_time_warnings > 0 + + # Epoch end time should be stored + assert timer.total_epoch_time == 0.0 + old_epoch_end_time = timer.epoch_end_time + timer.epoch_end() + assert timer.epoch_end_time > old_epoch_end_time + assert timer.total_epoch_time > 0.0 + + timer.reset() + assert timer.total_load_time == 0.0 + assert timer.num_load_time_warnings == 0 From 8cbc5f102ed44fc0ba0876f9930ceb74a3e27860 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 19 Oct 2021 11:37:05 +0100 Subject: [PATCH 11/41] cleanup --- Tests/ML/test_loggers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Tests/ML/test_loggers.py b/Tests/ML/test_loggers.py index 8aef6e1c0..73bccf9af 100644 --- a/Tests/ML/test_loggers.py +++ b/Tests/ML/test_loggers.py @@ -7,12 +7,12 @@ from typing import List from unittest import mock -from InnerEye.ML.lightning_loggers import AzureMLProgressBar, PROGRESS_STAGE_PREDICT, PROGRESS_STAGE_TEST, \ - PROGRESS_STAGE_TRAIN, \ - PROGRESS_STAGE_VAL -from InnerEye.ML.metrics import EpochTimers from _pytest.logging import LogCaptureFixture +from InnerEye.ML.lightning_loggers import (AzureMLProgressBar, PROGRESS_STAGE_PREDICT, PROGRESS_STAGE_TEST, + PROGRESS_STAGE_TRAIN, PROGRESS_STAGE_VAL) +from InnerEye.ML.metrics import EpochTimers + def test_progress_bar_enable() -> None: """ @@ -137,7 +137,7 @@ def test_epoch_timers(caplog: LogCaptureFixture) -> None: assert timer.num_load_time_exceeded == 1 assert epoch in timer.load_time_warning_epochs message = caplog.messages[-1] - assert f"prefix: Loading minibatch { batch_index} took" in message + assert f"prefix: Loading minibatch {batch_index} took" in message assert f"This message will be printed at most {timer.max_load_time_warnings} times" assert timer.num_load_time_warnings > 0 From 144698a6f3701c8a00843678dfcce7e330a6bcc0 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 19 Oct 2021 13:56:06 +0100 Subject: [PATCH 12/41] tests for callback --- InnerEye/ML/lightning_base.py | 13 ++--- InnerEye/ML/metrics.py | 21 ++++----- Tests/ML/test_loggers.py | 89 ++++++++++++++++++++++++++++++++++- 3 files changed, 100 insertions(+), 23 deletions(-) diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 9329abaad..2baaff4b8 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -226,10 +226,10 @@ def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self.module = pl_module def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self.train_timers.reset() + self.train_timers.epoch_start() def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self.val_timers.reset() + self.val_timers.epoch_start() # In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training # is done for this epoch, even though the on_training_epoch hook has not yet been called. self.train_timers.epoch_end() @@ -327,7 +327,7 @@ def write_and_log_epoch_time(self, is_training: bool) -> None: f"for data took {timers.total_load_time:0.2f} sec total.") if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch: logging.warning("The dataloaders were not fast enough to always supply the next batch in less than " - f"{timers.max_item_load_time_seconds}sec.") + f"{timers.max_item_load_time_seconds:0.2f}sec.") logging.warning( f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load " f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.") @@ -364,13 +364,6 @@ def get_timers(self, is_training: bool) -> EpochTimers: """ return self.train_timers if is_training else self.val_timers - def reset_timers(self) -> None: - """ - Resets all timers and counters, for both the validation and the training epoch. - """ - self.train_timers.reset() - self.val_timers.reset() - class InnerEyeLightning(LightningModule): """ diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py index bad2c397a..6394a832e 100644 --- a/InnerEye/ML/metrics.py +++ b/InnerEye/ML/metrics.py @@ -81,14 +81,6 @@ class EpochTimers: """ Contains all information necessary to compute the IO metrics: Epoch times, batch times, loading times. """ - epoch_start_time: float = 0.0 - epoch_end_time: float = 0.0 - batch_start_time: float = 0.0 - num_load_time_warnings: int = 0 - num_load_time_exceeded: int = 0 - total_extra_load_time: float = 0.0 - total_load_time: float = 0.0 - num_batches: int = 0 def __init__(self, max_item_load_time_seconds: float = 0.5, @@ -106,9 +98,16 @@ def __init__(self, self.max_load_time_warnings = max_load_time_warnings self.max_load_time_epochs = max_load_time_epochs self.load_time_warning_epochs: Set[int] = set() - self.reset() - - def reset(self) -> None: + self.epoch_start_time: float = 0.0 + self.epoch_end_time: float = 0.0 + self.batch_start_time: float = 0.0 + self.num_load_time_warnings: int = 0 + self.num_load_time_exceeded: int = 0 + self.total_extra_load_time: float = 0.0 + self.total_load_time: float = 0.0 + self.num_batches: int = 0 + + def epoch_start(self) -> None: """ Resets all timers to the current time, and all counters to 0. The set of epochs for which warnings about load time were produced will not be reset. diff --git a/Tests/ML/test_loggers.py b/Tests/ML/test_loggers.py index 73bccf9af..57f4c0f0a 100644 --- a/Tests/ML/test_loggers.py +++ b/Tests/ML/test_loggers.py @@ -4,11 +4,14 @@ # ------------------------------------------------------------------------------------------ import logging import math -from typing import List +from typing import Callable, Dict, List, Optional from unittest import mock +import torch from _pytest.logging import LogCaptureFixture +from InnerEye.Common.metrics_constants import MetricType, TRAIN_PREFIX, VALIDATION_PREFIX +from InnerEye.ML.lightning_base import BatchTimeCallback from InnerEye.ML.lightning_loggers import (AzureMLProgressBar, PROGRESS_STAGE_PREDICT, PROGRESS_STAGE_TEST, PROGRESS_STAGE_TRAIN, PROGRESS_STAGE_VAL) from InnerEye.ML.metrics import EpochTimers @@ -96,6 +99,9 @@ def write_message(message: str) -> None: def test_epoch_timers(caplog: LogCaptureFixture) -> None: + """ + Test the class that measures batch and epoch times. + """ caplog.set_level(logging.INFO) batch_index = 123 epoch = 24 @@ -140,6 +146,10 @@ def test_epoch_timers(caplog: LogCaptureFixture) -> None: assert f"prefix: Loading minibatch {batch_index} took" in message assert f"This message will be printed at most {timer.max_load_time_warnings} times" assert timer.num_load_time_warnings > 0 + # Test if the warnings disappear after the max number of warnings + assert timer.should_warn_in_this_epoch + timer.num_load_time_warnings = timer.max_load_time_warnings + 1 + assert not timer.should_warn_in_this_epoch # Epoch end time should be stored assert timer.total_epoch_time == 0.0 @@ -148,6 +158,81 @@ def test_epoch_timers(caplog: LogCaptureFixture) -> None: assert timer.epoch_end_time > old_epoch_end_time assert timer.total_epoch_time > 0.0 - timer.reset() + # Test the resetting logic + timer.epoch_start() assert timer.total_load_time == 0.0 assert timer.num_load_time_warnings == 0 + # The object should keep track of all epochs in which warnings were printed + assert len(timer.load_time_warning_epochs) > 0 + + +def test_batch_time_callback(caplog: LogCaptureFixture) -> None: + """ + Test the callback that measures data loading times. + """ + caplog.set_level(logging.INFO) + callback = BatchTimeCallback() + epoch = 1234 + # This dictionary stores all metrics that are written via module.log + logged_metrics = {} + + def mock_log(name: str, value: float, reduce_fx: Callable, **kwargs: Dict) -> None: + logged_metrics[name] = (value, reduce_fx) + + mock_module = mock.MagicMock(current_epoch=epoch, log=mock_log) + callback.on_fit_start(trainer=None, pl_module=mock_module) # type: ignore + assert callback.module == mock_module + + # Upon epoch start, the timers should be reset. We can check that by looking at epoch_start_time + assert callback.train_timers.epoch_start_time == 0.0 + callback.on_train_epoch_start(None, None) # type: ignore + assert callback.train_timers.epoch_start_time > 0.0 + assert callback.val_timers.epoch_start_time == 0.0 + old_train_epoch_end_time = callback.train_timers.epoch_end_time + callback.on_validation_epoch_start(None, None) # type: ignore + assert callback.val_timers.epoch_start_time > 0.0 + # When calling epoch_start for validation, training epoch should be ended + assert callback.train_timers.epoch_end_time > old_train_epoch_end_time + + # Run 1 training batch + callback.on_train_batch_start(None, None, None, batch_idx=0, dataloader_idx=0) # type: ignore + callback.on_train_batch_end(None, None, None, None, batch_idx=0, dataloader_idx=0) # type: ignore + assert len(logged_metrics) == 2 + # Upon batch end, we should see metrics being logged. Batch level timings should be logged both as averages and max + def check_batch_metrics(train_or_val: str) -> None: + for suffix in [" avg", " max"]: + name = f"timing/{train_or_val}/SecondsPerBatch" + suffix + assert name in logged_metrics + assert logged_metrics[name][1] == max if suffix == " max" else torch.mean + check_batch_metrics("train") + assert caplog.messages[-1].startswith(f"Epoch {epoch} training: Loaded the first") + # Run 2 validation batches + for batch_idx in range(2): + callback.on_validation_batch_start(None, None, None, batch_idx=batch_idx, dataloader_idx=0) # type: ignore + callback.on_validation_batch_end(None, None, None, None, batch_idx=batch_idx, dataloader_idx=0) # type: ignore + assert caplog.messages[-1].startswith(f"Epoch {epoch} validation: Loaded the first") + assert callback.train_timers.num_batches == 1 + assert callback.val_timers.num_batches == 2 + check_batch_metrics("val") + + # Check that the metrics are written at the end of the validation epoch. + # Hack the timers to trigger the warning message for validation only + callback.val_timers.num_load_time_exceeded = 1 + callback.val_timers.total_extra_load_time = 100.00 + callback.val_timers.max_item_load_time_seconds = 2.0 + assert callback.val_timers.should_warn_in_this_epoch + old_val_epoch_end_time = callback.train_timers.epoch_end_time + callback.on_validation_epoch_end(None, None) # type: ignore + assert callback.val_timers.epoch_end_time > old_val_epoch_end_time + assert len(logged_metrics) > 0 + + assert f"Epoch {epoch} training took " in caplog.messages[-4] + assert f"Epoch {epoch} validation took " in caplog.messages[-3] + assert "The dataloaders were not fast enough" in caplog.messages[-2] + assert "in less than 2.00sec" in caplog.messages[-2] + assert "1 out of 2 batches exceeded the load time threshold" in caplog.messages[-1] + assert "Total loading time for the slow batches was 100.00sec" in caplog.messages[-1] + + for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]: + for metric in [MetricType.SECONDS_PER_EPOCH.value, MetricType.EXCESS_BATCH_LOADING_TIME.value]: + assert f"timing/{prefix}{metric}" in logged_metrics From 6674b706b511f44e612a1364bb1d8c223e8bc2ff Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 19 Oct 2021 15:06:56 +0100 Subject: [PATCH 13/41] hyperparams logging --- InnerEye/ML/lightning_loggers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index 1f33ad618..7ae5dfda0 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -167,7 +167,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> @rank_zero_only def log_hyperparams(self, params: Any) -> None: - pass + # Convert from Namespace to dictionary + params = self._convert_params(params) + # Convert nested dictionaries to folder-like structure + params = self._flatten_dict(params) + # Convert anything that is not a primitive type to str + params = self._sanitize_params(params) + RUN_CONTEXT.log_table("hyperparams", params) def experiment(self) -> Any: return None From ebf8b2510cb9165b2d57064196f3b899241a4bc8 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 19 Oct 2021 16:25:34 +0100 Subject: [PATCH 14/41] flags --- InnerEye/ML/deep_learning_config.py | 12 +- InnerEye/ML/lightning_base.py | 185 ++++++++++++++++++++++++++++ InnerEye/ML/model_training.py | 8 +- 3 files changed, 200 insertions(+), 5 deletions(-) diff --git a/InnerEye/ML/deep_learning_config.py b/InnerEye/ML/deep_learning_config.py index b7ceb3ce1..90d8a4aef 100644 --- a/InnerEye/ML/deep_learning_config.py +++ b/InnerEye/ML/deep_learning_config.py @@ -218,7 +218,7 @@ class WorkflowParams(param.Parameterized): doc="If set, enable/disable full image inference on test set after ensemble training.") weights_url: List[str] = param.List(default=[], class_=str, doc="If provided, a set of urls from which checkpoints will be downloaded" - "and used for inference.") + "and used for inference.") local_weights_path: List[Path] = param.List(default=[], class_=Path, doc="A list of checkpoints paths to use for inference, " "when the job is running outside Azure.") @@ -590,6 +590,16 @@ class TrainerParams(param.Parameterized): param.Boolean(default=False, doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. " "Setting it to True comes with a performance hit.") + monitor_gpu: bool = param.Boolean(default=False, + doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. " + "This will write GPU utilization metrics every 50 batches by default.") + monitor_loading: bool = param.Boolean(default=True, + doc="If True, add the BatchTimeCallback callback to the Lightning trainer " + "object. This will monitor how long individual batches take to load.") + pl_profiler: Optional[str] = \ + param.String(default=None, + doc="The value to use for the 'profiler' argument for the Lightning trainer. " + "Set to either 'simple', 'advanced', or 'pytorch'") @property def use_gpu(self) -> bool: diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 2baaff4b8..049e4a2b5 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -365,6 +365,191 @@ def get_timers(self, is_training: bool) -> EpochTimers: return self.train_timers if is_training else self.val_timers +class GPUStatsMonitor2(Callback): + r""" + Automatically monitors and logs GPU stats during training and validation stage. ``GPUStatsMonitor`` + is a callback and in order to use it you need to assign a logger in the ``Trainer``. + + Args: + memory_utilization: Set to ``True`` to monitor used, free and percentage of memory + utilization at the start and end of each step. Default: ``True``. + gpu_utilization: Set to ``True`` to monitor percentage of GPU utilization + at the start and end of each step. Default: ``True``. + intra_step_time: Set to ``True`` to monitor the time of each step. Default: ``False``. + inter_step_time: Set to ``True`` to monitor the time between the end of one step + and the start of the next step. Default: ``False``. + fan_speed: Set to ``True`` to monitor percentage of fan speed. Default: ``False``. + temperature: Set to ``True`` to monitor the memory and gpu temperature in degree Celsius. + Default: ``False``. + + Raises: + MisconfigurationException: + If NVIDIA driver is not installed, not running on GPUs, or ``Trainer`` has no logger. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import GPUStatsMonitor + >>> gpu_stats = GPUStatsMonitor() # doctest: +SKIP + >>> trainer = Trainer(callbacks=[gpu_stats]) # doctest: +SKIP + + GPU stats are mainly based on `nvidia-smi --query-gpu` command. The description of the queries is as follows: + + - **fan.speed** – The fan speed value is the percent of maximum speed that the device's fan is currently + intended to run at. It ranges from 0 to 100 %. Note: The reported speed is the intended fan speed. + If the fan is physically blocked and unable to spin, this output will not match the actual fan speed. + Many parts do not report fan speeds because they rely on cooling via fans in the surrounding enclosure. + - **memory.used** – Total memory allocated by active contexts. + - **memory.free** – Total free memory. + - **utilization.gpu** – Percent of time over the past sample period during which one or more kernels was + executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product. + - **utilization.memory** – Percent of time over the past sample period during which global (device) memory was + being read or written. The sample period may be between 1 second and 1/6 second depending on the product. + - **temperature.gpu** – Core GPU temperature, in degrees C. + - **temperature.memory** – HBM memory temperature, in degrees C. + + """ + + def __init__( + self, + memory_utilization: bool = True, + gpu_utilization: bool = True, + intra_step_time: bool = False, + inter_step_time: bool = False, + fan_speed: bool = False, + temperature: bool = False + ): + super().__init__() + + if shutil.which('nvidia-smi') is None: + raise MisconfigurationException( + 'Cannot use GPUStatsMonitor callback because NVIDIA driver is not installed.' + ) + + self._log_stats = AttributeDict({ + 'memory_utilization': memory_utilization, + 'gpu_utilization': gpu_utilization, + 'intra_step_time': intra_step_time, + 'inter_step_time': inter_step_time, + 'fan_speed': fan_speed, + 'temperature': temperature + }) + + def on_train_start(self, trainer, pl_module) -> None: + if not trainer.logger: + raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.') + + if trainer._device_type != DeviceType.GPU: + raise MisconfigurationException( + 'You are using GPUStatsMonitor but are not running on GPU' + f' since gpus attribute in Trainer is set to {trainer.gpus}.' + ) + + self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids)) + + def on_train_epoch_start(self, trainer, pl_module) -> None: + self._snap_intra_step_time = None + self._snap_inter_step_time = None + + @rank_zero_only + def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + if self._log_stats.intra_step_time: + self._snap_intra_step_time = time.time() + + if not self._should_log(trainer): + return + + gpu_stat_keys = self._get_gpu_stat_keys() + gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) + logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) + + if self._log_stats.inter_step_time and self._snap_inter_step_time: + # First log at beginning of second step + logs['batch_time/inter_step (ms)'] = (time.time() - self._snap_inter_step_time) * 1000 + + trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_batch_end( + self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self._log_stats.inter_step_time: + self._snap_inter_step_time = time.time() + + if not self._should_log(trainer): + return + + gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys() + gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) + logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) + + if self._log_stats.intra_step_time and self._snap_intra_step_time: + logs['batch_time/intra_step (ms)'] = (time.time() - self._snap_intra_step_time) * 1000 + + trainer.logger.log_metrics(logs, step=trainer.global_step) + + def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]: + """Run nvidia-smi to get the gpu stats""" + gpu_query = ','.join(queries) + format = 'csv,nounits,noheader' + result = subprocess.run( + [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'], + encoding="utf-8", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 + check=True + ) + + def _to_float(x: str) -> float: + try: + return float(x) + except ValueError: + return 0. + + stats = result.stdout.strip().split(os.linesep) + stats = [[_to_float(x) for x in s.split(', ')] for s in stats] + return stats + + @staticmethod + def _parse_gpu_stats(gpu_ids: str, stats: List[List[float]], keys: List[Tuple[str, str]]) -> Dict[str, float]: + """Parse the gpu stats into a loggable dict""" + logs = {} + for i, gpu_id in enumerate(gpu_ids.split(',')): + for j, (x, unit) in enumerate(keys): + logs[f'gpu_id: {gpu_id}/{x} ({unit})'] = stats[i][j] + return logs + + def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]: + """Get the GPU stats keys""" + stat_keys = [] + + if self._log_stats.gpu_utilization: + stat_keys.append(('utilization.gpu', '%')) + + if self._log_stats.memory_utilization: + stat_keys.extend([('memory.used', 'MB'), ('memory.free', 'MB'), ('utilization.memory', '%')]) + + return stat_keys + + def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]: + """Get the device stats keys""" + stat_keys = [] + + if self._log_stats.fan_speed: + stat_keys.append(('fan.speed', '%')) + + if self._log_stats.temperature: + stat_keys.extend([('temperature.gpu', '°C'), ('temperature.memory', '°C')]) + + return stat_keys + + @staticmethod + def _should_log(trainer) -> bool: + should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) + + return should_log + + class InnerEyeLightning(LightningModule): """ The base class for all InnerEye models for training in PyTorch Lightning. The base class handles all shared diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 157fb29cc..ce3cb81e9 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -159,10 +159,10 @@ def create_lightning_trainer(container: LightningContainer, callbacks = [ last_checkpoint_callback, recovery_checkpoint_callback, - BatchTimeCallback() ] - # TODO: Add a flag for that. - if num_gpus > 0: + if container.monitor_loading: + callbacks.append(BatchTimeCallback()) + if num_gpus > 0 and container.monitor_gpu: logging.info("Adding monitoring for GPU utilization") callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True)) # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers @@ -197,7 +197,7 @@ def create_lightning_trainer(container: LightningContainer, precision=precision, sync_batchnorm=True, terminate_on_nan=container.detect_anomaly, - profiler="simple", + profiler=container.pl_profiler, resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None, **kwargs) return trainer, storing_logger From b11f6ddcbc4eff8bf4a3e00119841358a891a265 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 19 Oct 2021 16:50:23 +0100 Subject: [PATCH 15/41] submodule --- .gitmodules | 3 +++ hi-ml | 1 + 2 files changed, 4 insertions(+) create mode 160000 hi-ml diff --git a/.gitmodules b/.gitmodules index a2a6b1f53..623bd23c7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "fastMRI"] path = fastMRI url = https://github.com/facebookresearch/fastMRI +[submodule "hi-ml"] + path = hi-ml + url = https://github.com/microsoft/hi-ml diff --git a/hi-ml b/hi-ml new file mode 160000 index 000000000..1cd49695f --- /dev/null +++ b/hi-ml @@ -0,0 +1 @@ +Subproject commit 1cd49695f7b724753c986247075efb666a797804 From 8ac90f228f99f01a06656f99fb29018d9042a228 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 19 Oct 2021 22:17:02 +0100 Subject: [PATCH 16/41] update all usage --- InnerEye/ML/lightning_base.py | 351 +------------------------------ InnerEye/ML/lightning_loggers.py | 179 ---------------- InnerEye/ML/metrics.py | 110 ---------- InnerEye/ML/model_training.py | 5 +- Tests/ML/test_loggers.py | 238 --------------------- hi-ml | 2 +- 6 files changed, 5 insertions(+), 880 deletions(-) delete mode 100644 Tests/ML/test_loggers.py diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 049e4a2b5..5cebafd70 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -25,7 +25,7 @@ WorkflowParams from InnerEye.ML.lightning_container import LightningContainer from InnerEye.ML.lightning_loggers import StoringLogger -from InnerEye.ML.metrics import EpochTimers, store_epoch_metrics +from InnerEye.ML.metrics import store_epoch_metrics from InnerEye.ML.metrics_dict import DataframeLogger from InnerEye.ML.model_config_base import ModelConfigBase from InnerEye.ML.utils import model_util @@ -201,355 +201,6 @@ def load_checkpoint_and_modify(self, path_to_checkpoint: Path) -> Dict[str, Any] return self.config.load_checkpoint_and_modify(path_to_checkpoint=path_to_checkpoint) -class BatchTimeCallback(Callback): - """ - This callback provides tools to measure batch loading time and other diagnostic information. - It prints alerts if the batch loading time is over a threshold for several epochs. - All logging will only happen on rank 0. - - The loading time for a minibatch is estimated by the difference between the start time of a minibatch and the - end time of the previous minibatch. It will consequently also include other operations that happen between the - end of a batch and the start of the next one. For example, computationally expensive callbacks could also - drive up this time. - """ - - def __init__(self) -> None: - # Timers for monitoring data loading time - self.train_timers = EpochTimers() - self.val_timers = EpochTimers() - - def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """ - This is called at the start of training. It stores the model that is being trained, because it will be used - later to log values. - """ - self.module = pl_module - - def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self.train_timers.epoch_start() - - def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self.val_timers.epoch_start() - # In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training - # is done for this epoch, even though the on_training_epoch hook has not yet been called. - self.train_timers.epoch_end() - - def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """ - This is a hook called at the end of a training or validation epoch. In here, we can still write - metrics to a logger. - """ - # In validation epochs, mark that it has been completed. Training epochs are marked completed already - # at the start of the validation epoch. - self.val_timers.epoch_end() - # Write all IO stats here, so that the order on the console is Train start, train end, val start, val end. - self.write_and_log_epoch_time(is_training=True) - self.write_and_log_epoch_time(is_training=False) - - def on_train_batch_start(self, - trainer: Trainer, - pl_module: LightningModule, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - self.batch_start(batch_idx=batch_idx, is_training=True) - - def on_validation_batch_start(self, - trainer: Trainer, - pl_module: LightningModule, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - self.batch_start(batch_idx=batch_idx, is_training=False) - - def on_train_batch_end(self, - trainer: Trainer, - pl_module: LightningModule, - outputs: Any, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - self.batch_end(is_training=True) - - def on_validation_batch_end(self, - trainer: Trainer, - pl_module: LightningModule, - outputs: Any, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - self.batch_end(is_training=False) - - def batch_start(self, batch_idx: int, is_training: bool) -> None: - """ - Shared code to keep track of minibatch loading times. This is only done on rank zero. - :param batch_idx: The index of the current minibatch. - :param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from - `on_validation_batch_start`. - """ - timers = self.get_timers(is_training=is_training) - epoch = self.module.current_epoch - message_prefix = f"Epoch {epoch} {'training' if is_training else 'validation'}" - timers.batch_start(batch_index=batch_idx, epoch=epoch, message_prefix=message_prefix) - - def batch_end(self, is_training: bool) -> None: - """ - Shared code to keep track of IO-related metrics when loading a minibatch. - :param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from - `on_validation_batch_end`. - """ - timers = self.get_timers(is_training=is_training) - batch_time = timers.batch_end() - self.log_metric(MetricType.SECONDS_PER_BATCH.value + " avg", - value=batch_time, - is_training=is_training) - self.log_metric(MetricType.SECONDS_PER_BATCH.value + " max", - value=batch_time, - is_training=is_training, - reduce_max=True) - - @rank_zero_only - def write_and_log_epoch_time(self, is_training: bool) -> None: - """ - Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the - time per epoch. - :param is_training: If True, show and log the data for the training epoch. If False, use the data for the - validation epoch. - """ - timers = self.get_timers(is_training=is_training) - epoch_time_seconds = timers.total_epoch_time - status = "training" if is_training else "validation" - logging.info(f"Epoch {self.module.current_epoch} {status} took {epoch_time_seconds:0.2f}sec, of which waiting " - f"for data took {timers.total_load_time:0.2f} sec total.") - if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch: - logging.warning("The dataloaders were not fast enough to always supply the next batch in less than " - f"{timers.max_item_load_time_seconds:0.2f}sec.") - logging.warning( - f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load " - f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.") - # This metric is only written at rank zero, and hence must not be synchronized across workers. If attempted, - # training will get stuck. - self.log_metric(MetricType.SECONDS_PER_EPOCH.value, - value=epoch_time_seconds, - is_training=is_training) - self.log_metric(MetricType.EXCESS_BATCH_LOADING_TIME.value, - value=timers.total_extra_load_time, - is_training=is_training) - - @rank_zero_only - def log_metric(self, name_suffix: str, value: float, is_training: bool, reduce_max: bool = False) -> None: - """ - Write a metric given as a name/value pair to the currently trained module. The full name of the metric is - composed of a fixed prefix "timing/", followed by either "train/" or "val/", and then the given suffix. - :param name_suffix: The suffix for the logged metric name. - :param value: The value to log. - :param is_training: If True, use "train/" in the metric name, otherwise "val/" - :param reduce_max: If True, use torch.max as the aggregation function for the logged values. If False, use - torch.mean - """ - # Metrics are only written at rank 0, and hence must not be synchronized. Trying to synchronize will - # block training. - prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX - self.module.log(name="timing/" + prefix + name_suffix, value=value, - on_step=False, on_epoch=True, sync_dist=False, - reduce_fx=max if reduce_max else torch.mean) - - def get_timers(self, is_training: bool) -> EpochTimers: - """ - Gets the object that holds all metrics and timers, for either the validation or the training epoch. - """ - return self.train_timers if is_training else self.val_timers - - -class GPUStatsMonitor2(Callback): - r""" - Automatically monitors and logs GPU stats during training and validation stage. ``GPUStatsMonitor`` - is a callback and in order to use it you need to assign a logger in the ``Trainer``. - - Args: - memory_utilization: Set to ``True`` to monitor used, free and percentage of memory - utilization at the start and end of each step. Default: ``True``. - gpu_utilization: Set to ``True`` to monitor percentage of GPU utilization - at the start and end of each step. Default: ``True``. - intra_step_time: Set to ``True`` to monitor the time of each step. Default: ``False``. - inter_step_time: Set to ``True`` to monitor the time between the end of one step - and the start of the next step. Default: ``False``. - fan_speed: Set to ``True`` to monitor percentage of fan speed. Default: ``False``. - temperature: Set to ``True`` to monitor the memory and gpu temperature in degree Celsius. - Default: ``False``. - - Raises: - MisconfigurationException: - If NVIDIA driver is not installed, not running on GPUs, or ``Trainer`` has no logger. - - Example:: - - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.callbacks import GPUStatsMonitor - >>> gpu_stats = GPUStatsMonitor() # doctest: +SKIP - >>> trainer = Trainer(callbacks=[gpu_stats]) # doctest: +SKIP - - GPU stats are mainly based on `nvidia-smi --query-gpu` command. The description of the queries is as follows: - - - **fan.speed** – The fan speed value is the percent of maximum speed that the device's fan is currently - intended to run at. It ranges from 0 to 100 %. Note: The reported speed is the intended fan speed. - If the fan is physically blocked and unable to spin, this output will not match the actual fan speed. - Many parts do not report fan speeds because they rely on cooling via fans in the surrounding enclosure. - - **memory.used** – Total memory allocated by active contexts. - - **memory.free** – Total free memory. - - **utilization.gpu** – Percent of time over the past sample period during which one or more kernels was - executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product. - - **utilization.memory** – Percent of time over the past sample period during which global (device) memory was - being read or written. The sample period may be between 1 second and 1/6 second depending on the product. - - **temperature.gpu** – Core GPU temperature, in degrees C. - - **temperature.memory** – HBM memory temperature, in degrees C. - - """ - - def __init__( - self, - memory_utilization: bool = True, - gpu_utilization: bool = True, - intra_step_time: bool = False, - inter_step_time: bool = False, - fan_speed: bool = False, - temperature: bool = False - ): - super().__init__() - - if shutil.which('nvidia-smi') is None: - raise MisconfigurationException( - 'Cannot use GPUStatsMonitor callback because NVIDIA driver is not installed.' - ) - - self._log_stats = AttributeDict({ - 'memory_utilization': memory_utilization, - 'gpu_utilization': gpu_utilization, - 'intra_step_time': intra_step_time, - 'inter_step_time': inter_step_time, - 'fan_speed': fan_speed, - 'temperature': temperature - }) - - def on_train_start(self, trainer, pl_module) -> None: - if not trainer.logger: - raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.') - - if trainer._device_type != DeviceType.GPU: - raise MisconfigurationException( - 'You are using GPUStatsMonitor but are not running on GPU' - f' since gpus attribute in Trainer is set to {trainer.gpus}.' - ) - - self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids)) - - def on_train_epoch_start(self, trainer, pl_module) -> None: - self._snap_intra_step_time = None - self._snap_inter_step_time = None - - @rank_zero_only - def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - if self._log_stats.intra_step_time: - self._snap_intra_step_time = time.time() - - if not self._should_log(trainer): - return - - gpu_stat_keys = self._get_gpu_stat_keys() - gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) - logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) - - if self._log_stats.inter_step_time and self._snap_inter_step_time: - # First log at beginning of second step - logs['batch_time/inter_step (ms)'] = (time.time() - self._snap_inter_step_time) * 1000 - - trainer.logger.log_metrics(logs, step=trainer.global_step) - - @rank_zero_only - def on_train_batch_end( - self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: - if self._log_stats.inter_step_time: - self._snap_inter_step_time = time.time() - - if not self._should_log(trainer): - return - - gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys() - gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) - logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) - - if self._log_stats.intra_step_time and self._snap_intra_step_time: - logs['batch_time/intra_step (ms)'] = (time.time() - self._snap_intra_step_time) * 1000 - - trainer.logger.log_metrics(logs, step=trainer.global_step) - - def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]: - """Run nvidia-smi to get the gpu stats""" - gpu_query = ','.join(queries) - format = 'csv,nounits,noheader' - result = subprocess.run( - [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'], - encoding="utf-8", - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 - check=True - ) - - def _to_float(x: str) -> float: - try: - return float(x) - except ValueError: - return 0. - - stats = result.stdout.strip().split(os.linesep) - stats = [[_to_float(x) for x in s.split(', ')] for s in stats] - return stats - - @staticmethod - def _parse_gpu_stats(gpu_ids: str, stats: List[List[float]], keys: List[Tuple[str, str]]) -> Dict[str, float]: - """Parse the gpu stats into a loggable dict""" - logs = {} - for i, gpu_id in enumerate(gpu_ids.split(',')): - for j, (x, unit) in enumerate(keys): - logs[f'gpu_id: {gpu_id}/{x} ({unit})'] = stats[i][j] - return logs - - def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]: - """Get the GPU stats keys""" - stat_keys = [] - - if self._log_stats.gpu_utilization: - stat_keys.append(('utilization.gpu', '%')) - - if self._log_stats.memory_utilization: - stat_keys.extend([('memory.used', 'MB'), ('memory.free', 'MB'), ('utilization.memory', '%')]) - - return stat_keys - - def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]: - """Get the device stats keys""" - stat_keys = [] - - if self._log_stats.fan_speed: - stat_keys.append(('fan.speed', '%')) - - if self._log_stats.temperature: - stat_keys.extend([('temperature.gpu', '°C'), ('temperature.memory', '°C')]) - - return stat_keys - - @staticmethod - def _should_log(trainer) -> bool: - should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) - - return should_log - - class InnerEyeLightning(LightningModule): """ The base class for all InnerEye models for training in PyTorch Lightning. The base class handles all shared diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index 7ae5dfda0..fb49a9eed 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -2,18 +2,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -import logging -import math -import sys -import time from typing import Any, Dict, Iterable, List, Optional -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import ProgressBarBase from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.utilities import rank_zero_only -from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context from InnerEye.Common.metrics_constants import TRAIN_PREFIX, VALIDATION_PREFIX from InnerEye.Common.type_annotations import DictStrFloat @@ -144,175 +137,3 @@ def val_results_per_epoch(self) -> List[DictStrFloat]: Gets the full set of validation metrics that the logger stores, as a list of dictionaries per epoch. """ return list(self.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values()) - - -class AzureMLLogger(LightningLoggerBase): - """ - A Pytorch Lightning logger that stores metrics in the current AzureML run. If the present run is not - inside AzureML, nothing gets logged. - """ - - def __init__(self) -> None: - super().__init__() - self.is_azureml_run = not is_offline_run_context(RUN_CONTEXT) - - @rank_zero_only - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - is_epoch_metric = "epoch" in metrics - if self.is_azureml_run: - for key, value in metrics.items(): - # Log all epoch-level metrics without the step information - # All step-level metrics with step - RUN_CONTEXT.log(key, value, step=None if is_epoch_metric else step) - - @rank_zero_only - def log_hyperparams(self, params: Any) -> None: - # Convert from Namespace to dictionary - params = self._convert_params(params) - # Convert nested dictionaries to folder-like structure - params = self._flatten_dict(params) - # Convert anything that is not a primitive type to str - params = self._sanitize_params(params) - RUN_CONTEXT.log_table("hyperparams", params) - - def experiment(self) -> Any: - return None - - def name(self) -> Any: - return "" - - def version(self) -> int: - return 0 - - -PROGRESS_STAGE_TRAIN = "Training" -PROGRESS_STAGE_VAL = "Validation" -PROGRESS_STAGE_TEST = "Testing" -PROGRESS_STAGE_PREDICT = "Prediction" - - -class AzureMLProgressBar(ProgressBarBase): - """ - A PL progress bar that works better in AzureML. It prints timestamps for each message, and works well with a setup - where there is no direct access to the console. - """ - - def __init__(self, - refresh_rate: int = 50, - write_to_logging_info: bool = False - ): - """ - Creates a new AzureML progress bar. - :param refresh_rate: The number of steps after which the progress should be printed out. - :param write_to_logging_info: If True, the progress information will be printed via logging.info. If False, - it will be printed to stdout via print. - """ - super().__init__() - self._refresh_rate = refresh_rate - self._enabled = True - self.stage = "" - self.stage_start_time = 0.0 - self.max_batch_count = 0 - self.progress_print_fn = logging.info if write_to_logging_info else print - self.flush_fn = None if write_to_logging_info else sys.stdout.flush - - @property - def refresh_rate(self) -> int: - return self._refresh_rate - - @property - def is_enabled(self) -> bool: - return self._enabled and self.refresh_rate > 0 - - @property - def is_disabled(self) -> bool: - return not self.is_enabled - - def disable(self) -> None: - self._enabled = False - - def enable(self) -> None: - self._enabled = True - - def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_train_epoch_start(trainer, pl_module) - self.start_stage(PROGRESS_STAGE_TRAIN, self.total_train_batches) - - def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_validation_start(trainer, pl_module) - self.start_stage(PROGRESS_STAGE_VAL, self.total_val_batches) - - def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_test_epoch_start(trainer, pl_module) - self.start_stage(PROGRESS_STAGE_TEST, self.total_test_batches) - - def on_predict_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_predict_epoch_start(trainer, pl_module) - self.start_stage(PROGRESS_STAGE_PREDICT, self.total_predict_batches) - - def start_stage(self, stage: str, max_batch_count: int) -> None: - """ - Sets the information that a new stage of the PL loop is starting. The stage will be available in - self.stage, max_batch_count in self.max_batch_count. The time when this method was called is recorded in - self.stage_start_time - :param stage: The string name of the stage that has just started. - :param max_batch_count: The total number of batches that need to be processed in this stage. - """ - self.stage = stage - self.max_batch_count = max_batch_count - self.stage_start_time = time.time() - - def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, - batch_idx: int, dataloader_idx: int) -> None: - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed=self.train_batch_idx) - - def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, - batch_idx: int, dataloader_idx: int) -> None: - super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed=self.val_batch_idx) - - def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, - batch_idx: int, dataloader_idx: int) -> None: - super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed=self.test_batch_idx) - - def on_predict_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, - batch_idx: int, dataloader_idx: int) -> None: - super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self.update_progress(batches_processed=self.predict_batch_idx) - - def update_progress(self, batches_processed: int): - """ - Writes progress information once the refresh interval is full. - :param batches_processed: The number of batches that have been processed for the current stage. - """ - should_update = self.is_enabled and \ - (batches_processed % self.refresh_rate == 0 or batches_processed == self.max_batch_count) - if not should_update: - return - prefix = f"{self.stage}" - if self.stage in [PROGRESS_STAGE_TRAIN, PROGRESS_STAGE_VAL]: - prefix += f" epoch {self.trainer.current_epoch}" - if self.stage == PROGRESS_STAGE_TRAIN: - prefix += f" (step {self.trainer.lightning_module.global_step})" - prefix += ": " - if math.isinf(self.max_batch_count): - # Can't print out per-cent progress or time estimates if the data is infinite - message = f"{prefix}{batches_processed:4} batches completed" - else: - fraction_completed = batches_processed / self.max_batch_count - percent_completed = int(fraction_completed * 100) - time_elapsed = time.time() - self.stage_start_time - estimated_epoch_duration = time_elapsed / fraction_completed - - def to_minutes(time_sec: float) -> str: - minutes = int(time_sec / 60) - seconds = int(time_sec % 60) - return f"{minutes:02}:{seconds:02}" - - message = (f"{prefix}{batches_processed:4}/{self.max_batch_count} ({percent_completed:3}%) completed. " - f"{to_minutes(time_elapsed)} elapsed, total epoch time ~ {to_minutes(estimated_epoch_duration)}") - self.progress_print_fn(message) - if self.flush_fn: - self.flush_fn() diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py index 6394a832e..ff3f2256c 100644 --- a/InnerEye/ML/metrics.py +++ b/InnerEye/ML/metrics.py @@ -77,116 +77,6 @@ def log_metrics(self, run_context: Run = None) -> None: }) -class EpochTimers: - """ - Contains all information necessary to compute the IO metrics: Epoch times, batch times, loading times. - """ - - def __init__(self, - max_item_load_time_seconds: float = 0.5, - max_load_time_warnings: int = 3, - max_load_time_epochs: int = 5 - ) -> None: - """ - Creates a new instance of the class. - :param max_item_load_time_seconds: The maximum expected loading time for a minibatch (given in seconds). - If the loading time exceeds this threshold, a warning is printed. - :param max_load_time_warnings: The maximum number of warnings that will be printed per epoch. - :param max_load_time_epochs: The maximum number of epochs where warnings about the loading time are printed. - """ - self.max_item_load_time_seconds = max_item_load_time_seconds - self.max_load_time_warnings = max_load_time_warnings - self.max_load_time_epochs = max_load_time_epochs - self.load_time_warning_epochs: Set[int] = set() - self.epoch_start_time: float = 0.0 - self.epoch_end_time: float = 0.0 - self.batch_start_time: float = 0.0 - self.num_load_time_warnings: int = 0 - self.num_load_time_exceeded: int = 0 - self.total_extra_load_time: float = 0.0 - self.total_load_time: float = 0.0 - self.num_batches: int = 0 - - def epoch_start(self) -> None: - """ - Resets all timers to the current time, and all counters to 0. The set of epochs for which warnings about - load time were produced will not be reset. - """ - current_time = time.time() - self.epoch_start_time = current_time - self.epoch_end_time = current_time - self.batch_start_time = current_time - self.num_load_time_warnings = 0 - self.num_load_time_exceeded = 0 - self.total_extra_load_time = 0.0 - self.total_load_time = 0.0 - self.num_batches = 0 - - def epoch_end(self) -> None: - """ - Stores the present time in the epoch_end_time field of the object. - """ - self.epoch_end_time = time.time() - - @property - def total_epoch_time(self) -> float: - """ - Gets the time in seconds between epoch start and epoch end. - """ - return self.epoch_end_time - self.epoch_start_time - - @property - def should_warn_in_this_epoch(self) -> bool: - """ - Returns True if warnings about loading time should be printed in the present epoch. Returns False if - this warning has been printed already in more than self.max_load_time_epochs epochs. - :return: - """ - return len(self.load_time_warning_epochs) <= self.max_load_time_epochs - - def batch_start(self, batch_index: int, epoch: int, message_prefix: str) -> float: - """ - Called when a minibatch of data has been loaded. This computes the time it took to load the minibatch - (computed between now and the end of the previous minibatch) - and adds it to the internal bookkeeping. If the minibatch loading time exceeds a threshold, then warnings - are printed (unless too many warnings have been printed already) - :param message_prefix: A prefix string that is added to all diagnostic output. - :param epoch: The index of the current epoch. - :param batch_index: The index of the current minibatch. - :return: The time it took to load the minibatch, in seconds. - """ - item_finish_time = time.time() - item_load_time = item_finish_time - self.batch_start_time - self.total_load_time += item_load_time - # Having slow minibatch loading is OK in the very first batch of the every epoch, where processes - # are spawned. Later, the load time should be zero. - if batch_index == 0: - logging.info(f"{message_prefix}: Loaded the first minibatch of data in {item_load_time:0.2f} sec.") - elif item_load_time > self.max_item_load_time_seconds: - self.load_time_warning_epochs.add(epoch) - self.num_load_time_exceeded += 1 - self.total_extra_load_time += item_load_time - if self.num_load_time_warnings < self.max_load_time_warnings and self.should_warn_in_this_epoch: - logging.warning(f"{message_prefix}: Loading minibatch {batch_index} took {item_load_time:0.2f} sec. " - "This can mean that there are not enough data loader worker processes, or that there " - "is a performance problem in loading. This warning will be printed at most " - f"{self.max_load_time_warnings} times in at most {self.max_load_time_epochs} epochs.") - self.num_load_time_warnings += 1 - return item_load_time - - def batch_end(self) -> float: - """ - Called after a minibatch has been processed (training or validation step completed). Returns the time it took - to process the current batch (including loading). - :return: The time it took to process the current batch, in seconds. - """ - current_time = time.time() - elapsed = current_time - self.batch_start_time - self.batch_start_time = current_time - self.num_batches += 1 - return elapsed - - def surface_distance(seg: sitk.Image, reference_segmentation: sitk.Image) -> float: """ Symmetric surface distances taking into account the image spacing diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index ce3cb81e9..e8b283121 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -19,11 +19,12 @@ from InnerEye.Common.resource_monitor import ResourceMonitor from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, create_best_checkpoint from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER -from InnerEye.ML.lightning_base import BatchTimeCallback, InnerEyeContainer, InnerEyeLightning +from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning from InnerEye.ML.lightning_container import LightningContainer -from InnerEye.ML.lightning_loggers import AzureMLLogger, AzureMLProgressBar, StoringLogger +from InnerEye.ML.lightning_loggers import StoringLogger from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \ get_subject_output_file_per_rank +from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback TEMP_PREFIX = "temp/" diff --git a/Tests/ML/test_loggers.py b/Tests/ML/test_loggers.py deleted file mode 100644 index 57f4c0f0a..000000000 --- a/Tests/ML/test_loggers.py +++ /dev/null @@ -1,238 +0,0 @@ -# ------------------------------------------------------------------------------------------ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. -# ------------------------------------------------------------------------------------------ -import logging -import math -from typing import Callable, Dict, List, Optional -from unittest import mock - -import torch -from _pytest.logging import LogCaptureFixture - -from InnerEye.Common.metrics_constants import MetricType, TRAIN_PREFIX, VALIDATION_PREFIX -from InnerEye.ML.lightning_base import BatchTimeCallback -from InnerEye.ML.lightning_loggers import (AzureMLProgressBar, PROGRESS_STAGE_PREDICT, PROGRESS_STAGE_TEST, - PROGRESS_STAGE_TRAIN, PROGRESS_STAGE_VAL) -from InnerEye.ML.metrics import EpochTimers - - -def test_progress_bar_enable() -> None: - """ - Test the logic for disabling the progress bar. - """ - bar = AzureMLProgressBar(refresh_rate=0) - assert not bar.is_enabled - bar = AzureMLProgressBar(refresh_rate=1) - assert bar.is_enabled - bar.disable() - assert not bar.is_enabled - bar.enable() - assert bar.is_enabled - - -def test_progress_bar() -> None: - bar = AzureMLProgressBar(refresh_rate=1) - mock_trainer = mock.MagicMock(current_epoch=12, - lightning_module=mock.MagicMock(global_step=34), - num_training_batches=10, - emable_validation=False, - num_test_batches=[20], - num_predict_batches=[30]) - bar.on_init_end(mock_trainer) # type: ignore - assert bar.trainer == mock_trainer - messages: List[str] = [] - - def write_message(message: str) -> None: - messages.append(message) - - bar.progress_print_fn = write_message - bar.flush_fn = None - # Messages in training - bar.on_train_epoch_start(None, None) # type: ignore - assert bar.stage == PROGRESS_STAGE_TRAIN - assert bar.train_batch_idx == 0 - assert bar.val_batch_idx == 0 - assert bar.test_batch_idx == 0 - assert bar.predict_batch_idx == 0 - bar.on_train_batch_end(None, None, None, None, None, None) # type: ignore - assert bar.train_batch_idx == 1 - assert "Training epoch 12 (step 34)" in messages[-1] - assert "1/10 ( 10%) completed" in messages[-1] - # When starting the next training epoch, the counters should be reset - bar.on_train_epoch_start(None, None) # type: ignore - assert bar.train_batch_idx == 0 - # Messages in validation - bar.on_validation_start(None, None) # type: ignore - assert bar.stage == PROGRESS_STAGE_VAL - assert bar.max_batch_count == 0 - assert bar.val_batch_idx == 0 - # Number of validation batches is difficult to fake, tweak the field where it is stored in the progress bar - bar.max_batch_count = 5 - bar.on_validation_batch_end(None, None, None, None, None, None) # type: ignore - assert bar.val_batch_idx == 1 - assert "Validation epoch 12: " in messages[-1] - assert "1/5 ( 20%) completed" in messages[-1] - # Messages in testing - bar.on_test_epoch_start(None, None) # type: ignore - assert bar.stage == PROGRESS_STAGE_TEST - test_count = 2 - for _ in range(test_count): - bar.on_test_batch_end(None, None, None, None, None, None) # type: ignore - assert bar.test_batch_idx == test_count - assert "Testing:" in messages[-1] - assert f"{test_count}/20 ( 10%)" in messages[-1] - # Messages in prediction - bar.on_predict_epoch_start(None, None) # type: ignore - assert bar.stage == PROGRESS_STAGE_PREDICT - predict_count = 3 - for _ in range(predict_count): - bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore - assert bar.predict_batch_idx == predict_count - assert "Prediction:" in messages[-1] - assert f"{predict_count}/30 ( 10%)" in messages[-1] - # Test behaviour when a batch count is infinity - bar.max_batch_count = math.inf - bar.on_predict_batch_end(None, None, None, None, None, None) # type: ignore - assert bar.predict_batch_idx == 4 - assert "4 batches completed" in messages[-1] - - -def test_epoch_timers(caplog: LogCaptureFixture) -> None: - """ - Test the class that measures batch and epoch times. - """ - caplog.set_level(logging.INFO) - batch_index = 123 - epoch = 24 - timer = EpochTimers(max_item_load_time_seconds=100) - assert timer.total_load_time == 0.0 - - # First batch should always generate a message - timer.batch_start(batch_index=0, epoch=epoch, message_prefix="prefix") - assert timer.total_load_time > 0.0 - message = caplog.messages[-1] - assert "prefix: Loaded the first minibatch of data in" in message - old_num_batches = timer.num_batches - old_batch_start_time = timer.batch_start_time - timer.batch_end() - assert timer.num_batches == old_num_batches + 1 - assert timer.batch_start_time > old_batch_start_time - - # Second minibatch should only generate a message if above load time threshold. Set threshold very high - old_num_messages = len(caplog.messages) - old_total_load_time = timer.total_load_time - timer.max_item_load_time_seconds = 10.0 - assert timer.num_load_time_exceeded == 0 - timer.batch_start(batch_index=batch_index, epoch=epoch, message_prefix="prefix") - # This should be updated in any case - assert timer.total_load_time > old_total_load_time - # But this batch should not be recognized as having gone over the threshold - assert timer.num_load_time_exceeded == 0 - assert len(timer.load_time_warning_epochs) == 0 - assert len(caplog.messages) == old_num_messages - assert timer.num_load_time_warnings == 0 - - # Third minibatch considered as above threshold: set threshold to 0 for that - old_total_load_time = timer.total_load_time - timer.max_item_load_time_seconds = 0.0 - timer.batch_start(batch_index=batch_index, epoch=epoch, message_prefix="prefix") - # This should be updated in any case - assert timer.total_load_time > old_total_load_time - # Batch should not be recognized as having gone over the threshold - assert timer.num_load_time_exceeded == 1 - assert epoch in timer.load_time_warning_epochs - message = caplog.messages[-1] - assert f"prefix: Loading minibatch {batch_index} took" in message - assert f"This message will be printed at most {timer.max_load_time_warnings} times" - assert timer.num_load_time_warnings > 0 - # Test if the warnings disappear after the max number of warnings - assert timer.should_warn_in_this_epoch - timer.num_load_time_warnings = timer.max_load_time_warnings + 1 - assert not timer.should_warn_in_this_epoch - - # Epoch end time should be stored - assert timer.total_epoch_time == 0.0 - old_epoch_end_time = timer.epoch_end_time - timer.epoch_end() - assert timer.epoch_end_time > old_epoch_end_time - assert timer.total_epoch_time > 0.0 - - # Test the resetting logic - timer.epoch_start() - assert timer.total_load_time == 0.0 - assert timer.num_load_time_warnings == 0 - # The object should keep track of all epochs in which warnings were printed - assert len(timer.load_time_warning_epochs) > 0 - - -def test_batch_time_callback(caplog: LogCaptureFixture) -> None: - """ - Test the callback that measures data loading times. - """ - caplog.set_level(logging.INFO) - callback = BatchTimeCallback() - epoch = 1234 - # This dictionary stores all metrics that are written via module.log - logged_metrics = {} - - def mock_log(name: str, value: float, reduce_fx: Callable, **kwargs: Dict) -> None: - logged_metrics[name] = (value, reduce_fx) - - mock_module = mock.MagicMock(current_epoch=epoch, log=mock_log) - callback.on_fit_start(trainer=None, pl_module=mock_module) # type: ignore - assert callback.module == mock_module - - # Upon epoch start, the timers should be reset. We can check that by looking at epoch_start_time - assert callback.train_timers.epoch_start_time == 0.0 - callback.on_train_epoch_start(None, None) # type: ignore - assert callback.train_timers.epoch_start_time > 0.0 - assert callback.val_timers.epoch_start_time == 0.0 - old_train_epoch_end_time = callback.train_timers.epoch_end_time - callback.on_validation_epoch_start(None, None) # type: ignore - assert callback.val_timers.epoch_start_time > 0.0 - # When calling epoch_start for validation, training epoch should be ended - assert callback.train_timers.epoch_end_time > old_train_epoch_end_time - - # Run 1 training batch - callback.on_train_batch_start(None, None, None, batch_idx=0, dataloader_idx=0) # type: ignore - callback.on_train_batch_end(None, None, None, None, batch_idx=0, dataloader_idx=0) # type: ignore - assert len(logged_metrics) == 2 - # Upon batch end, we should see metrics being logged. Batch level timings should be logged both as averages and max - def check_batch_metrics(train_or_val: str) -> None: - for suffix in [" avg", " max"]: - name = f"timing/{train_or_val}/SecondsPerBatch" + suffix - assert name in logged_metrics - assert logged_metrics[name][1] == max if suffix == " max" else torch.mean - check_batch_metrics("train") - assert caplog.messages[-1].startswith(f"Epoch {epoch} training: Loaded the first") - # Run 2 validation batches - for batch_idx in range(2): - callback.on_validation_batch_start(None, None, None, batch_idx=batch_idx, dataloader_idx=0) # type: ignore - callback.on_validation_batch_end(None, None, None, None, batch_idx=batch_idx, dataloader_idx=0) # type: ignore - assert caplog.messages[-1].startswith(f"Epoch {epoch} validation: Loaded the first") - assert callback.train_timers.num_batches == 1 - assert callback.val_timers.num_batches == 2 - check_batch_metrics("val") - - # Check that the metrics are written at the end of the validation epoch. - # Hack the timers to trigger the warning message for validation only - callback.val_timers.num_load_time_exceeded = 1 - callback.val_timers.total_extra_load_time = 100.00 - callback.val_timers.max_item_load_time_seconds = 2.0 - assert callback.val_timers.should_warn_in_this_epoch - old_val_epoch_end_time = callback.train_timers.epoch_end_time - callback.on_validation_epoch_end(None, None) # type: ignore - assert callback.val_timers.epoch_end_time > old_val_epoch_end_time - assert len(logged_metrics) > 0 - - assert f"Epoch {epoch} training took " in caplog.messages[-4] - assert f"Epoch {epoch} validation took " in caplog.messages[-3] - assert "The dataloaders were not fast enough" in caplog.messages[-2] - assert "in less than 2.00sec" in caplog.messages[-2] - assert "1 out of 2 batches exceeded the load time threshold" in caplog.messages[-1] - assert "Total loading time for the slow batches was 100.00sec" in caplog.messages[-1] - - for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]: - for metric in [MetricType.SECONDS_PER_EPOCH.value, MetricType.EXCESS_BATCH_LOADING_TIME.value]: - assert f"timing/{prefix}{metric}" in logged_metrics diff --git a/hi-ml b/hi-ml index 1cd49695f..b3c48013b 160000 --- a/hi-ml +++ b/hi-ml @@ -1 +1 @@ -Subproject commit 1cd49695f7b724753c986247075efb666a797804 +Subproject commit b3c48013ba067a951d67b47626de66b0063bc385 From 547626e7ca17a0f976a82c5fe84bf1f717cc2aba Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 20 Oct 2021 06:55:51 +0100 Subject: [PATCH 17/41] fix --- InnerEye/Common/fixed_paths.py | 4 ++-- InnerEye/ML/model_training.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/InnerEye/Common/fixed_paths.py b/InnerEye/Common/fixed_paths.py index dc0ee2a56..0e9e09c85 100755 --- a/InnerEye/Common/fixed_paths.py +++ b/InnerEye/Common/fixed_paths.py @@ -107,8 +107,8 @@ def add_submodules_to_path() -> None: innereye_root = repository_root_directory() folders_to_add = [(innereye_root, "InnerEye"), (innereye_root / "fastMRI", "fastmri"), - (innereye_root / "hi-ml" / "hi-ml-azure" / "src", "health"), - (innereye_root / "hi-ml" / "hi-ml" / "src", "health")] + (innereye_root / "hi-ml" / "hi-ml-azure" / "src", "health_azure"), + (innereye_root / "hi-ml" / "hi-ml" / "src", "health_ml")] for (folder, subfolder_that_must_exist) in folders_to_add: if (folder / subfolder_that_must_exist).is_dir(): folder_str = str(folder) diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index e8b283121..f4b15b3dc 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -180,7 +180,9 @@ def create_lightning_trainer(container: LightningContainer, progress_bar_refresh_rate = 50 logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. " f"To change, modify the pl_progress_bar_refresh_rate field of the container.") - callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate, write_to_logging_info=True)) + callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate, + write_to_logging_info=True, + print_timestamp=False)) # Read out additional model-specific args here. # We probably want to keep essential ones like numgpu and logging. trainer = Trainer(default_root_dir=str(container.outputs_folder), From 986fba8b188a2aaa651553566a4ed2292b40b84b Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 20 Oct 2021 14:46:59 +0100 Subject: [PATCH 18/41] cleanup --- InnerEye/ML/lightning_base.py | 2 +- hi-ml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 5cebafd70..687f2bc77 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -9,7 +9,7 @@ import param import torch -from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities import rank_zero_only from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler diff --git a/hi-ml b/hi-ml index b3c48013b..f26ea5e58 160000 --- a/hi-ml +++ b/hi-ml @@ -1 +1 @@ -Subproject commit b3c48013ba067a951d67b47626de66b0063bc385 +Subproject commit f26ea5e58e7ba70df4f9dc1436a9465b2d09f5f2 From a823af08fe7577e0be5ac4eda088a8abe8d6970c Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 2 Nov 2021 16:10:54 +0000 Subject: [PATCH 19/41] callback save and load --- .../datamodules_and_datasets/datamodules.py | 3 ++ .../SSL/lightning_containers/ssl_container.py | 28 ++++++++++--------- .../ssl_image_classifier.py | 5 +--- .../lightning_modules/ssl_online_evaluator.py | 22 ++++++++++++++- InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py | 6 ++-- InnerEye/ML/configs/ssl/CXR_SSL_configs.py | 4 +-- InnerEye/ML/configs/ssl/CovidContainers.py | 2 +- InnerEye/ML/deep_learning_config.py | 17 ++++++++--- InnerEye/ML/model_training.py | 8 ++++-- Tests/SSL/test_ssl_containers.py | 14 ++++++++++ 10 files changed, 78 insertions(+), 31 deletions(-) diff --git a/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py b/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py index 3641a0c98..4a023203b 100644 --- a/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py +++ b/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py @@ -136,6 +136,9 @@ def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, """ The train dataloaders """ + # This code may be superseded in current versions of PL. Using this dictionary syntax will effectively + # use a CombinedLoader(dataloaders, mode="max_size_cycle"), similar to what we need to do explicitly for + # the validation data loader. dataloaders = { SSLDataModuleType.ENCODER: self.encoder_module.train_dataloader(), SSLDataModuleType.LINEAR_HEAD: self.linear_head_module.train_dataloader()} diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py index d3f934042..02319fa41 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py @@ -75,9 +75,9 @@ class SSLContainer(LightningContainer): "augmentations. Ignored for CIFAR10 example") ssl_training_dataset_name = param.ClassSelector(class_=SSLDatasetName, doc="The name of the dataset") ssl_training_batch_size = param.Integer( - doc="Total training batch size, will be divided across the number of gpus used for training. For example: if " - "you specify ssl_training_batch_size=1600 and use 4 nodes with 4 gpus each (i.e. total of 16 GPUs), " - "the code will provide a per-gpu batch size of 100") + doc="Training batch size per GPU. The effective batch size will be the number of GPUs times this number. " + "For example, if you specify ssl_training_batch_size=100 and use 4 nodes with 4 gpus each, " + "the effective batch size will be 1600.") ssl_training_type = param.ClassSelector(class_=SSLTrainingType, doc="Which algorithm to use for SSL training") ssl_encoder = param.ClassSelector(class_=EncoderName, doc="Which encoder to use for SSL") use_balanced_binary_loss_for_linear_head = param.Boolean(default=False, @@ -100,7 +100,10 @@ class SSLContainer(LightningContainer): def setup(self) -> None: from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer - self.total_num_gpus = self.num_gpus_per_node * self.num_nodes + if self.is_debug_model: + self.pl_limit_train_batches = 1 + self.pl_limit_val_batches = 1 + self.total_num_gpus = self.num_gpus_per_node() * self.num_nodes self._load_config() # If you're using the same data for training and linear head, allow the user to specify the dataset only # once. Or if you are doing just finetuning of linear head, the user should be able to specify dataset via @@ -199,16 +202,17 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio train_transforms, val_transforms = self._get_transforms(datamodule_args.augmentation_params, datamodule_args.dataset_name, is_ssl_encoder_module) - batch_size_per_gpu = datamodule_args.batch_size // self.total_num_gpus if self.total_num_gpus > 0 else \ - datamodule_args.batch_size - logging.info(f"Batch size per gpu: {batch_size_per_gpu}") + batch_multiplier = self.total_num_gpus if self.total_num_gpus > 0 else 1 + effective_batch_size = datamodule_args.batch_size * batch_multiplier + logging.info(f"Batch size per GPU: {datamodule_args.batch_size}") + logging.info(f"Effective batch size on {batch_multiplier} GPUs: {effective_batch_size}") dm = InnerEyeVisionDataModule(dataset_cls=self._SSLDataClassMappings[datamodule_args.dataset_name], return_index=not is_ssl_encoder_module, # index is only needed for linear head train_transforms=train_transforms, val_split=0.1, val_transforms=val_transforms, data_dir=str(datamodule_args.dataset_path), - batch_size=batch_size_per_gpu, + batch_size=datamodule_args.batch_size, num_workers=self.num_workers, seed=self.random_seed, drop_last=self.drop_last) @@ -226,9 +230,9 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode], :param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation pipeline to return. :param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is - applied on and it applies the training transformations also at validation time. Note that if your transformation - does not contain any randomness, the pipeline will return two identical copies. If False, it will return only one - transformation. + applied on and it applies the training transformations also at validation time. Note that if your + transformation does not contain any randomness, the pipeline will return two identical copies. If False, it + will return only one transformation. :return: training transformation pipeline and validation transformation pipeline. """ if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value, @@ -269,6 +273,4 @@ def get_trainer_arguments(self) -> Dict[str, Any]: drop_p=0.2, learning_rate=self.learning_rate_linear_head_during_ssl_training) trainer_kwargs: Dict[str, Any] = {"callbacks": self.online_eval} - if self.is_debug_model: - trainer_kwargs.update({"limit_train_batches": 1, "limit_val_batches": 1}) return trainer_kwargs diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py b/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py index 76c8de85f..e890b75ca 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py @@ -64,7 +64,4 @@ def get_data_module(self) -> InnerEyeDataModuleTypes: return self.data_module def get_trainer_arguments(self) -> Dict[str, Any]: - trained_kwargs = {} - if self.is_debug_model: - trained_kwargs.update({"limit_train_batches": 1, "limit_val_batches": 1}) - return trained_kwargs + return {} diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 3eddaff53..00f89b792 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -9,15 +9,18 @@ import torch from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.models.self_supervised.evaluator import SSLEvaluator -from torchmetrics import Metric from torch import Tensor as T from torch.nn import functional as F +from torchmetrics import Metric from InnerEye.ML.SSL.utils import SSLDataModuleType from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve BatchType = Union[Dict[SSLDataModuleType, Any], Any] +OPTIMIZER_STATE_NAME = "evaluator_optimizer" +EVALUATOR_STATE_NAME = "evaluator_weights" + class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator): def __init__(self, @@ -44,6 +47,23 @@ def __init__(self, if self.num_classes == 2 else [Accuracy05()] self.class_weights = class_weights + def on_save_checkpoint(self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: Dict[str, Any]) -> Dict[str, Any]: + # Each callback gets its own state dictionary, that are fed back in during load + return { + OPTIMIZER_STATE_NAME: self.optimizer.state_dict(), + EVALUATOR_STATE_NAME: pl_module.non_linear_evaluator.state_dict() + } + + def on_load_checkpoint(self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + callback_state: Dict[str, Any]) -> None: + self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME]) + pl_module.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME]) + def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ Initializes modules and moves metrics and class weights to module device diff --git a/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py b/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py index aa2eeb741..8ca2cb58c 100644 --- a/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py +++ b/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py @@ -15,7 +15,7 @@ class CIFAR10SimCLR(SSLContainer): def __init__(self) -> None: super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10, linear_head_dataset_name=SSLDatasetName.CIFAR10, - ssl_training_batch_size=512, + ssl_training_batch_size=128, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.SimCLR, random_seed=1, @@ -32,7 +32,7 @@ class CIFAR10BYOL(SSLContainer): def __init__(self) -> None: super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10, linear_head_dataset_name=SSLDatasetName.CIFAR10, - ssl_training_batch_size=512, + ssl_training_batch_size=128, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.BYOL, random_seed=1, @@ -49,7 +49,7 @@ class CIFAR10CIFAR100BYOL(SSLContainer): def __init__(self) -> None: super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10, linear_head_dataset_name=SSLDatasetName.CIFAR100, - ssl_training_batch_size=512, + ssl_training_batch_size=64, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.BYOL, random_seed=1, diff --git a/InnerEye/ML/configs/ssl/CXR_SSL_configs.py b/InnerEye/ML/configs/ssl/CXR_SSL_configs.py index 4a47f1905..24f912bc9 100644 --- a/InnerEye/ML/configs/ssl/CXR_SSL_configs.py +++ b/InnerEye/ML/configs/ssl/CXR_SSL_configs.py @@ -29,7 +29,7 @@ def __init__(self) -> None: random_seed=1, recovery_checkpoint_save_interval=200, num_epochs=1000, - ssl_training_batch_size=1200, + ssl_training_batch_size=75, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.BYOL, use_balanced_binary_loss_for_linear_head=True, @@ -45,7 +45,7 @@ def __init__(self) -> None: random_seed=1, recovery_checkpoint_save_interval=200, num_epochs=1000, - ssl_training_batch_size=1200, + ssl_training_batch_size=75, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.SimCLR, use_balanced_binary_loss_for_linear_head=True, diff --git a/InnerEye/ML/configs/ssl/CovidContainers.py b/InnerEye/ML/configs/ssl/CovidContainers.py index 2e79d35b8..f091876d5 100644 --- a/InnerEye/ML/configs/ssl/CovidContainers.py +++ b/InnerEye/ML/configs/ssl/CovidContainers.py @@ -23,7 +23,7 @@ def __init__(self, recovery_checkpoint_save_interval=50, recovery_checkpoints_save_last_k=3, num_epochs=500, - ssl_training_batch_size=1200, # This runs with 16 gpus (4 nodes) + ssl_training_batch_size=75, # This runs with 16 gpus (4 nodes) num_workers=12, ssl_encoder=EncoderName.densenet121, ssl_training_type=SSLTrainingType.BYOL, diff --git a/InnerEye/ML/deep_learning_config.py b/InnerEye/ML/deep_learning_config.py index 4a8161045..f088fd9d0 100644 --- a/InnerEye/ML/deep_learning_config.py +++ b/InnerEye/ML/deep_learning_config.py @@ -218,7 +218,7 @@ class WorkflowParams(param.Parameterized): doc="If set, enable/disable full image inference on test set after ensemble training.") weights_url: List[str] = param.List(default=[], class_=str, doc="If provided, a set of urls from which checkpoints will be downloaded" - "and used for inference.") + "and used for inference.") local_weights_path: List[Path] = param.List(default=[], class_=Path, doc="A list of checkpoints paths to use for inference, " "when the job is running outside Azure.") @@ -590,6 +590,14 @@ class TrainerParams(param.Parameterized): param.Boolean(default=False, doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. " "Setting it to True comes with a performance hit.") + pl_limit_train_batches: Optional[int] = \ + param.Integer(default=None, + doc="PyTorch Lightning trainer flag 'limit_train_batches': Limit the training dataset to the " + "given number of batches.") + pl_limit_val_batches: Optional[int] = \ + param.Integer(default=None, + doc="PyTorch Lightning trainer flag 'limit_val_batches': Limit the validation dataset to the " + "given number of batches.") @property def use_gpu(self) -> bool: @@ -602,15 +610,16 @@ def use_gpu(self) -> bool: from InnerEye.ML.utils.ml_util import is_gpu_available return is_gpu_available() - @property def num_gpus_per_node(self) -> int: """ Computes the number of gpus to use for each node: either the number of gpus available on the device or restrict it to max_num_gpu, whichever is smaller. Returns 0 if running on a CPU device. """ import torch - num_gpus = torch.cuda.device_count() if self.use_gpu else 0 - logging.info(f"Number of available GPUs: {num_gpus}") + available_gpus = torch.cuda.device_count() + num_gpus = available_gpus if self.use_gpu else 0 + message_suffix = "" if self.use_gpu else ", but not using them because use_gpu == False" + logging.info(f"Number of available GPUs: {available_gpus}{message_suffix}") if 0 <= self.max_num_gpus < num_gpus: num_gpus = self.max_num_gpus logging.info(f"Restricting the number of GPUs to {num_gpus}") diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 041e952b6..6a1c0fb21 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -63,7 +63,7 @@ class InnerEyeRecoveryCheckpointCallback(ModelCheckpoint): def __init__(self, container: LightningContainer): super().__init__(dirpath=str(container.checkpoint_folder), - monitor="epoch", + monitor="epoch_started", filename=RECOVERY_CHECKPOINT_FILE_NAME + "_{epoch}", period=container.recovery_checkpoint_save_interval, save_top_k=container.recovery_checkpoints_save_last_k, @@ -71,7 +71,7 @@ def __init__(self, container: LightningContainer): save_last=False) def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unused: bool = None) -> None: - pl_module.log(name="epoch", value=trainer.current_epoch) # type: ignore + pl_module.log(name="epoch_started", value=trainer.current_epoch) # type: ignore def create_lightning_trainer(container: LightningContainer, @@ -100,7 +100,7 @@ def create_lightning_trainer(container: LightningContainer, # recovery_checkpoints_save_last_k. recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container) - num_gpus = container.num_gpus_per_node + num_gpus = container.num_gpus_per_node() effective_num_gpus = num_gpus * num_nodes # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory). if effective_num_gpus > 1: @@ -153,6 +153,8 @@ def create_lightning_trainer(container: LightningContainer, accelerator=accelerator, plugins=plugins, max_epochs=container.num_epochs, + limit_train_batches=container.pl_limit_train_batches or 1.0, + limit_val_batches=container.pl_limit_val_batches or 1.0, num_sanity_val_steps=container.pl_num_sanity_val_steps, callbacks=callbacks, logger=loggers, diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 556305e5d..2f455ae38 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -20,6 +20,8 @@ from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier +from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import EVALUATOR_STATE_NAME, OPTIMIZER_STATE_NAME, \ + SSLOnlineEvaluatorInnerEye from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier @@ -84,7 +86,18 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10.value assert not loaded_config.use_balanced_binary_loss_for_linear_head assert isinstance(loaded_config.model.encoder.cnn_model, ResNet) + # Check that the checkpoint contains both the optimizer for the embedding and for the linear head checkpoint_path = loaded_config.outputs_folder / "checkpoints" / "best_checkpoint.ckpt" + checkpoint = torch.load(checkpoint_path) + assert len(checkpoint["optimizer_states"]) == 1 + assert len(checkpoint["lr_schedulers"]) == 1 + assert "callbacks" in checkpoint + assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"] + callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye] + assert OPTIMIZER_STATE_NAME in callback_state + assert EVALUATOR_STATE_NAME in callback_state + + # Now run the actual SSL classifier off the stored checkpoint args = common_test_args + ["--model=SSLClassifierCIFAR", f"--local_ssl_weights_path={checkpoint_path}"] with mock.patch("sys.argv", args): loaded_config, actual_run = default_runner().run() @@ -93,6 +106,7 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: assert loaded_config.model.class_weights is None assert loaded_config.model.num_classes == 10 + @pytest.mark.skipif(is_windows(), reason="Too slow on windows") def test_load_innereye_ssl_container_cifar10_cifar100_resnet_byol() -> None: """ From 17564f1c1612f3601cf5a549e748fb02389f7a2b Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 2 Nov 2021 19:59:00 +0000 Subject: [PATCH 20/41] find_unused --- InnerEye/ML/SSL/lightning_containers/ssl_container.py | 1 + 1 file changed, 1 insertion(+) diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py index 02319fa41..3d7a4d82d 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py @@ -103,6 +103,7 @@ def setup(self) -> None: if self.is_debug_model: self.pl_limit_train_batches = 1 self.pl_limit_val_batches = 1 + self.pl_find_unused_parameters = True self.total_num_gpus = self.num_gpus_per_node() * self.num_nodes self._load_config() # If you're using the same data for training and linear head, allow the user to specify the dataset only From 6064bc5bed7cffcf8e546e82a89fd756b7727379 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 2 Nov 2021 20:29:01 +0000 Subject: [PATCH 21/41] remove submodule --- hi-ml | 1 - 1 file changed, 1 deletion(-) delete mode 160000 hi-ml diff --git a/hi-ml b/hi-ml deleted file mode 160000 index f26ea5e58..000000000 --- a/hi-ml +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f26ea5e58e7ba70df4f9dc1436a9465b2d09f5f2 From 88ed46ce03b556904c7788c722d30f7bf3c51193 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 2 Nov 2021 20:51:45 +0000 Subject: [PATCH 22/41] storinglogger update --- .gitmodules | 3 -- InnerEye/Common/type_annotations.py | 1 + InnerEye/ML/lightning_loggers.py | 49 ++++++++++++++++++++--------- InnerEye/ML/model_training.py | 12 +++---- InnerEye/ML/run_ml.py | 8 +++-- InnerEye/ML/runner.py | 8 +++-- Tests/ML/test_model_training.py | 43 +++++++++++++++++++++---- docs/environment.md | 5 +-- 8 files changed, 89 insertions(+), 40 deletions(-) diff --git a/.gitmodules b/.gitmodules index 623bd23c7..a2a6b1f53 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ [submodule "fastMRI"] path = fastMRI url = https://github.com/facebookresearch/fastMRI -[submodule "hi-ml"] - path = hi-ml - url = https://github.com/microsoft/hi-ml diff --git a/InnerEye/Common/type_annotations.py b/InnerEye/Common/type_annotations.py index 7c768437f..2d56f675e 100644 --- a/InnerEye/Common/type_annotations.py +++ b/InnerEye/Common/type_annotations.py @@ -15,3 +15,4 @@ TupleFloat9 = Tuple[float, float, float, float, float, float, float, float, float] IntOrTuple3 = Union[int, TupleInt3, Iterable] DictStrFloat = Dict[str, float] +DictStrFloatOrFloatList = Dict[str, Union[float, List[float]]] diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index fb49a9eed..0b90ffa3f 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -19,28 +19,40 @@ class StoringLogger(LightningLoggerBase): def __init__(self) -> None: super().__init__() - self.results: Dict[int, DictStrFloat] = {} + self.results_per_epoch: Dict[int, DictStrFloatOrFloatList] = {} self.hyperparams: Any = None # Fields to store diagnostics for unit testing self.train_diagnostics: List[Any] = [] self.val_diagnostics: List[Any] = [] + self.results_without_epoch: List[DictStrFloat] = [] @rank_zero_only def log_metrics(self, metrics: DictStrFloat, step: Optional[int] = None) -> None: + logging.debug(f"StoringLogger step={step}: {metrics}") epoch_name = "epoch" if epoch_name not in metrics: - raise ValueError("Each of the logged metrics should have an 'epoch' key.") + # Metrics without an "epoch" key are logged during testing, for example + self.results_without_epoch.append(metrics) + return epoch = int(metrics[epoch_name]) del metrics[epoch_name] - if epoch in self.results: - current_results = self.results[epoch] - overlapping_keys = set(metrics.keys()).intersection(current_results.keys()) - if len(overlapping_keys) > 0: - raise ValueError(f"Unable to log metric with same name twice for epoch {epoch}: " - f"{', '.join(overlapping_keys)}") - current_results.update(metrics) + for key, value in metrics.items(): + if isinstance(value, int): + metrics[key] = float(value) + if epoch in self.results_per_epoch: + current_results = self.results_per_epoch[epoch] + for key, value in metrics.items(): + if key in current_results: + logging.debug(f"StoringLogger: appending results for metric {key}") + current_metrics = current_results[key] + if isinstance(current_metrics, list): + current_metrics.append(value) + else: + current_results[key] = [current_metrics, value] + else: + current_results[key] = value else: - self.results[epoch] = metrics + self.results_per_epoch[epoch] = metrics # type: ignore @rank_zero_only def log_hyperparams(self, params: Any) -> None: @@ -60,7 +72,7 @@ def epochs(self) -> Iterable[int]: """ Gets the epochs for which the present object holds any results. """ - return self.results.keys() + return self.results_per_epoch.keys() def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat: """ @@ -72,7 +84,7 @@ def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat have a name starting with `prefix`, and strip off the prefix. :return: A metrics dictionary. """ - epoch_results = self.results.get(epoch, None) + epoch_results = self.results_per_epoch.get(epoch, None) if epoch_results is None: raise KeyError(f"No results are stored for epoch {epoch}") filtered = {} @@ -82,8 +94,8 @@ def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat # filter is supplied and really matches the metric name if (not prefix_filter) or key.startswith(prefix_filter): stripped_key = key[len(prefix_filter):] - filtered[stripped_key] = value - return filtered + filtered[stripped_key] = value # type: ignore + return filtered # type: ignore def to_metrics_dicts(self, prefix_filter: str = "") -> Dict[int, DictStrFloat]: """ @@ -106,7 +118,14 @@ def get_metric(self, is_training: bool, metric_type: str) -> List[float]: :return: A list of floating point numbers, with one entry per entry in the the training or validation results. """ full_metric_name = (TRAIN_PREFIX if is_training else VALIDATION_PREFIX) + metric_type - return [self.results[epoch][full_metric_name] for epoch in self.epochs] + result = [] + for epoch in self.epochs: + value = self.results_per_epoch[epoch][full_metric_name] + if not isinstance(value, float): + raise ValueError(f"Expected a floating point value for metric {full_metric_name}, but got: " + f"{value}") + result.append(value) + return result def get_train_metric(self, metric_type: str) -> List[float]: """ diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 307aaac03..ec920c730 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -79,7 +79,7 @@ def create_lightning_trainer(container: LightningContainer, resume_from_checkpoint: Optional[Path] = None, num_nodes: int = 1, **kwargs: Dict[str, Any]) -> \ - Tuple[Trainer, Optional[StoringLogger]]: + Tuple[Trainer, StoringLogger]: """ Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second @@ -105,12 +105,8 @@ def create_lightning_trainer(container: LightningContainer, logging.info(f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'") tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="") loggers = [tensorboard_logger, AzureMLLogger()] - storing_logger: Optional[StoringLogger] - if isinstance(container, InnerEyeContainer): - storing_logger = StoringLogger() - loggers.append(storing_logger) - else: - storing_logger = None + storing_logger = StoringLogger() + loggers.append(storing_logger) # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag. precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32 # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark @@ -200,7 +196,7 @@ def start_resource_monitor(config: LightningContainer) -> ResourceMonitor: def model_train(checkpoint_path: Optional[Path], container: LightningContainer, - num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]: + num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]: """ The main training loop. It creates the Pytorch model based on the configuration options passed in, creates a Pytorch Lightning trainer, and trains the model. diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index 665e48b32..1028d3ecf 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -188,6 +188,7 @@ def __init__(self, self.post_cross_validation_hook = post_cross_validation_hook self.model_deployment_hook = model_deployment_hook self.output_subfolder = output_subfolder + self.storing_logger: Optional[StoringLogger] = None self._has_setup_run = False def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None: @@ -384,9 +385,10 @@ def run(self) -> None: # train a new model if required if self.azure_config.train: with logging_section("Model training"): - model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(), - container=self.container, - num_nodes=self.azure_config.num_nodes) + _, storing_logger = model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(), + container=self.container, + num_nodes=self.azure_config.num_nodes) + self.storing_logger = storing_logger # Since we have trained the model further, let the checkpoint_handler object know so it can handle # checkpoints correctly. self.checkpoint_handler.additional_training_done() diff --git a/InnerEye/ML/runner.py b/InnerEye/ML/runner.py index 28dd3085f..5672a2022 100755 --- a/InnerEye/ML/runner.py +++ b/InnerEye/ML/runner.py @@ -134,6 +134,8 @@ def __init__(self, self.model_config: Optional[DeepLearningConfig] = None self.azure_config: AzureConfig = AzureConfig() self.lightning_container: LightningContainer = None # type: ignore + # This field stores the MLRunner object that has been created in the most recent call to the run() method. + self.ml_runner: Optional[MLRunner] = None def parse_and_load_model(self) -> ParserResult: """ @@ -379,9 +381,9 @@ def run_in_situ(self, azure_run_info: AzureRunInfo) -> None: # Set environment variables for multi-node training if needed. This function will terminate early # if it detects that it is not in a multi-node environment. set_environment_variables_for_multi_node() - ml_runner = self.create_ml_runner() - ml_runner.setup(azure_run_info) - ml_runner.run() + self.ml_runner = self.create_ml_runner() + self.ml_runner.setup(azure_run_info) + self.ml_runner.run() def create_ml_runner(self) -> MLRunner: """ diff --git a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index 16c945f95..c2f0db7fe 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -2,17 +2,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -import logging -import os -import shutil -from pathlib import Path -from typing import Any, Dict, List - import h5py +import logging import numpy as np +import os import pandas as pd import pytest +import shutil +from pathlib import Path from torch.utils.data import DataLoader +from typing import Any, Dict, List from InnerEye.Common import fixed_paths from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, is_windows, logging_to_stdout @@ -368,3 +367,35 @@ def test_aggregate_and_create_subject_metrics_file(test_output_dirs: OutputFolde written_lines = pd.read_csv(outputs_folder / mode / SUBJECT_METRICS_FILE_NAME) expected_lines = pd.read_csv(outputs_folder / mode / "expected_metrics.csv") assert written_lines.equals(expected_lines) + + +def test_storing_logger() -> None: + """ + Test if the StoringLogger can correctly handle multiple metrics of the same name logged per epoch. + """ + logger = StoringLogger() + key1 = "key" + key2 = "key2" + value1 = 3.14 + value2 = 2.71 + value3 = 100.0 + assert value1 != value2 + epoch = 1 + # Add metrics in the same epoch in two calls, so that we test both the cases where the epoch is already present, + # and where not + logger.log_metrics({"epoch": 1, key1: value1}) + logger.log_metrics({"epoch": 1, key2: value2}) + # All results for epoch 1 should be collated into a single dictionary + assert logger.extract_by_prefix(epoch=epoch) == {key1: value1, key2: value2} + # When updating a metric that already exists, the result should not be a float anymore but a list. + logger.log_metrics({"epoch": epoch, key1: value3}) + assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3], key2: value2} + # Add more metrics for key1, so that we also test the case that the results are already a list + logger.log_metrics({"epoch": epoch, key1: value3}) + assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3, value3], key2: value2} + # Add metrics that don't have an epoch key: This happens for example during testing with trainer.test + other_metrics1 = {"foo": 1.0} + other_metrics2 = {"foo": 2.0} + logger.log_metrics(other_metrics1) + logger.log_metrics(other_metrics2) + assert logger.results_without_epoch == [other_metrics1, other_metrics2] diff --git a/docs/environment.md b/docs/environment.md index 28ca6faa5..d12d94288 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -202,8 +202,9 @@ and that costs 20min per run. * There is already code in `InnerEye.Common.fixed_paths.add_submodules_to_path` that will pick up the submodules and add them to `sys.path`. -Once you are done testing your changes, remove the entry for `hi-ml` from `.gitmodules` and execute these steps -from the repository root: +Once you are done testing your changes: +* Remove the entry for `hi-ml` from `.gitmodules` +* Execute these steps from the repository root: ```shell git submodule deinit -f hi-ml rmdir hi-ml From 857026c1123c425e0f23aa5a031db2986bbafad8 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 2 Nov 2021 21:45:28 +0000 Subject: [PATCH 23/41] head_batchsize --- InnerEye/ML/SSL/lightning_containers/ssl_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py index 3d7a4d82d..484999e32 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py @@ -92,7 +92,7 @@ class SSLContainer(LightningContainer): "augmentations") linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName, doc="Name of the dataset to use for the linear head training") - linear_head_batch_size = param.Integer(default=256, doc="Batch size for linear head tuning") + linear_head_batch_size = param.Integer(default=16, doc="Batch size for linear head tuning") learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4, doc="Learning rate for linear head training during " "SSL training.") From a71a476c0ed5e13e063bf51d83292fd85ff4ea06 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 2 Nov 2021 22:04:34 +0000 Subject: [PATCH 24/41] using submodule --- .gitmodules | 3 +++ .idea/InnerEye-DeepLearning.iml | 4 +++- InnerEye/Common/type_annotations.py | 2 +- InnerEye/ML/lightning_base.py | 2 +- InnerEye/ML/lightning_loggers.py | 1 + environment.yml | 1 - hi-ml | 1 + 7 files changed, 10 insertions(+), 4 deletions(-) create mode 160000 hi-ml diff --git a/.gitmodules b/.gitmodules index a2a6b1f53..623bd23c7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "fastMRI"] path = fastMRI url = https://github.com/facebookresearch/fastMRI +[submodule "hi-ml"] + path = hi-ml + url = https://github.com/microsoft/hi-ml diff --git a/.idea/InnerEye-DeepLearning.iml b/.idea/InnerEye-DeepLearning.iml index b44301926..aeaeafd64 100644 --- a/.idea/InnerEye-DeepLearning.iml +++ b/.idea/InnerEye-DeepLearning.iml @@ -4,8 +4,10 @@ + + - + diff --git a/InnerEye/Common/type_annotations.py b/InnerEye/Common/type_annotations.py index 2d56f675e..775161029 100644 --- a/InnerEye/Common/type_annotations.py +++ b/InnerEye/Common/type_annotations.py @@ -3,7 +3,7 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ from pathlib import Path -from typing import Dict, Iterable, Optional, Tuple, TypeVar, Union +from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union T = TypeVar('T') PathOrString = Union[Path, str] diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 687f2bc77..224a34c9b 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -302,7 +302,7 @@ def read_epoch_results_from_logger_and_store(self, epoch: int) -> None: Training and Validation metrics. """ if epoch >= 0: - if epoch in self.storing_logger.results: + if epoch in self.storing_logger.results_per_epoch: for is_training, prefix in [(True, TRAIN_PREFIX), (False, VALIDATION_PREFIX)]: metrics = self.storing_logger.extract_by_prefix(epoch, prefix) self.store_epoch_results(metrics, epoch, is_training) diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index 0b90ffa3f..f07496f7f 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import logging from typing import Any, Dict, Iterable, List, Optional from pytorch_lightning.loggers import LightningLoggerBase diff --git a/environment.yml b/environment.yml index 92d673d46..a312455a0 100644 --- a/environment.yml +++ b/environment.yml @@ -23,7 +23,6 @@ dependencies: - gitpython==3.1.7 - gputil==1.4.0 - h5py==2.10.0 - - hi-ml-azure>=0.1.9 - InnerEye-DICOM-RT==1.0.1 - joblib==0.16.0 - jupyter==1.0.0 diff --git a/hi-ml b/hi-ml new file mode 160000 index 000000000..e89e7d768 --- /dev/null +++ b/hi-ml @@ -0,0 +1 @@ +Subproject commit e89e7d768b2b5149eb1b898ae8ee5ada01184d46 From 69fe24753499a36b7218a32b87be54dbdbb45ca2 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Tue, 2 Nov 2021 22:07:03 +0000 Subject: [PATCH 25/41] import fix --- InnerEye/ML/lightning_loggers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py index f07496f7f..b4481b60c 100644 --- a/InnerEye/ML/lightning_loggers.py +++ b/InnerEye/ML/lightning_loggers.py @@ -9,7 +9,7 @@ from pytorch_lightning.utilities import rank_zero_only from InnerEye.Common.metrics_constants import TRAIN_PREFIX, VALIDATION_PREFIX -from InnerEye.Common.type_annotations import DictStrFloat +from InnerEye.Common.type_annotations import DictStrFloat, DictStrFloatOrFloatList class StoringLogger(LightningLoggerBase): From 3ef5d5dd871b3bae31875b5de63cebb9c0d08869 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 09:13:01 +0000 Subject: [PATCH 26/41] log_on_epoch --- InnerEye/ML/lightning_base.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 224a34c9b..ee988a899 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -3,7 +3,6 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ import logging -import numbers from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -35,6 +34,7 @@ from InnerEye.ML.utils.ml_util import RandomStateSnapshot, set_random_seed, validate_dataset_paths from InnerEye.ML.utils.model_util import generate_and_print_model_summary from InnerEye.ML.visualizers.patch_sampling import visualize_random_crops_for_dataset +from health_ml.utils import log_on_epoch class TrainAndValDataLightning(LightningDataModule): @@ -324,20 +324,19 @@ def log_on_epoch(self, :param name: The name of the metric to log :param value: The value of the metric. This can be a tensor, floating point value, or a Metric class. :param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix. - :param reduce_fx: The reduce function to apply after synchronizing the tensors across GPUs. + :param reduce_fx: The reduce function to apply to step values. Default: torch.mean :param sync_dist_op: The reduce operation to use when synchronizing the tensors across GPUs. This must be a value recognized by sync_ddp: Either 'None' to use 'sum' as aggregate, or 'mean' or 'avg' """ metric_name = name if isinstance(name, str) else name.value - if isinstance(value, numbers.Number): - value = torch.tensor(value, dtype=torch.float, device=self.device) prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX sync_dist = self.use_sync_dist if sync_dist_override is None else sync_dist_override - self.log(prefix + metric_name, value, - sync_dist=sync_dist, - on_step=False, on_epoch=True, - reduce_fx=reduce_fx, - sync_dist_op=sync_dist_op) + log_on_epoch(self, + name=prefix + metric_name, + value=value, + sync_dist=sync_dist, + reduce_fx=reduce_fx, + sync_dist_op=sync_dist_op) def store_epoch_results(self, metrics: DictStrFloat, epoch: int, is_training: bool) -> None: """ From 40b6d07fae6145bb852cd7e76fa2b72d4006aa41 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 13:53:38 +0000 Subject: [PATCH 27/41] cleanup of metrics --- .flake8 | 2 +- InnerEye/Common/metrics_constants.py | 9 ++----- InnerEye/ML/lightning_base.py | 4 ++-- InnerEye/ML/metrics.py | 14 ++++------- InnerEye/ML/metrics_dict.py | 2 -- InnerEye/ML/model_training.py | 8 +++---- InnerEye/ML/run_ml.py | 1 + Tests/Common/test_metrics_dict.py | 10 ++++---- Tests/ML/models/test_scalar_model.py | 5 ---- Tests/ML/test_model_training.py | 36 ++++++++++++++-------------- 10 files changed, 38 insertions(+), 53 deletions(-) diff --git a/.flake8 b/.flake8 index 290d0ea9a..0097ca829 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,4 @@ ignore = E226,E302,E41,W391, E701, W291, E722, W503, E128, E126, E127, E731, E401 max-line-length = 160 max-complexity = 25 -exclude = fastMRI/ test_outputs/ +exclude = fastMRI/ test_outputs/ hi-ml/ diff --git a/InnerEye/Common/metrics_constants.py b/InnerEye/Common/metrics_constants.py index 4961e98fe..75d2ed366 100644 --- a/InnerEye/Common/metrics_constants.py +++ b/InnerEye/Common/metrics_constants.py @@ -7,6 +7,8 @@ # String prefixes when writing training or validation set metrics to a logger from typing import Union +from health_ml.utils import BatchTimeCallback + TRAIN_PREFIX = "train/" VALIDATION_PREFIX = "val/" @@ -45,8 +47,6 @@ class LoggingColumns(Enum): AccuracyAtThreshold05 = "accuracy_at_threshold_05" Loss = "loss" CrossEntropy = "cross_entropy" - SecondsPerEpoch = "seconds_per_epoch" - SecondsPerBatch = "seconds_per_batch" AreaUnderRocCurve = "area_under_roc_curve" AreaUnderPRCurve = "area_under_pr_curve" CrossValidationSplitIndex = "cross_validation_split_index" @@ -100,9 +100,6 @@ class MetricType(Enum): EXPLAINED_VAR = "ExplainedVariance" # Common metrics - SECONDS_PER_BATCH = "SecondsPerBatch" - SECONDS_PER_EPOCH = "SecondsPerEpoch" - EXCESS_BATCH_LOADING_TIME = "TotalExcessLoadingTimeSeconds" SUBJECT_COUNT = "SubjectCount" LEARNING_RATE = "LearningRate" @@ -115,8 +112,6 @@ class MetricType(Enum): MetricType.LOSS.value: LoggingColumns.Loss, MetricType.ACCURACY_AT_THRESHOLD_05.value: LoggingColumns.AccuracyAtThreshold05, MetricType.CROSS_ENTROPY.value: LoggingColumns.CrossEntropy, - MetricType.SECONDS_PER_BATCH.value: LoggingColumns.SecondsPerBatch, - MetricType.SECONDS_PER_EPOCH.value: LoggingColumns.SecondsPerEpoch, MetricType.AREA_UNDER_ROC_CURVE.value: LoggingColumns.AreaUnderRocCurve, MetricType.AREA_UNDER_PR_CURVE.value: LoggingColumns.AreaUnderPRCurve, MetricType.SUBJECT_COUNT.value: LoggingColumns.SubjectCount, diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index ee988a899..b67810b58 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -262,7 +262,7 @@ def training_epoch_end(self, outputs: List[Any]) -> None: # Write out all the metrics that have been accumulated in the StoringLogger in the previous epoch. # Metrics for the very last epoch are written in on_train_end self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch - 1) - self.training_or_validation_epoch_end(is_training=True) + self.training_or_validation_epoch_end(is_training=True) # type: ignore def on_validation_epoch_start(self) -> None: """ @@ -285,7 +285,7 @@ def validation_epoch_end(self, outputs: List[Any]) -> None: # reset the random state for training, so that we get continue from where we were before the validation step. assert self.random_state is not None self.random_state.restore_random_state() - self.training_or_validation_epoch_end(is_training=False) + self.training_or_validation_epoch_end(is_training=False) # type: ignore @rank_zero_only def on_train_end(self) -> None: diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py index ff3f2256c..44217f170 100644 --- a/InnerEye/ML/metrics.py +++ b/InnerEye/ML/metrics.py @@ -6,16 +6,15 @@ import logging import math -import time -from dataclasses import dataclass, field -from typing import List, Optional, Sequence, Set +from dataclasses import dataclass +from typing import List, Optional, Sequence import SimpleITK as sitk import numpy as np -from numpy.core.numeric import NaN import torch import torch.nn.functional as F from azureml.core import Run +from numpy.core.numeric import NaN from InnerEye.Azure.azure_util import get_run_context_or_default from InnerEye.Common.metrics_constants import LoggingColumns, MetricType @@ -27,8 +26,8 @@ from InnerEye.ML.scalar_config import ScalarLoss from InnerEye.ML.utils.image_util import binaries_from_multi_label_array, is_binary_array from InnerEye.ML.utils.io_util import reverse_tuple_float3 -from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, mean_absolute_error, - r2_score, is_missing_ground_truth) +from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, is_missing_ground_truth, + mean_absolute_error, r2_score) from InnerEye.ML.utils.ml_util import check_size_matches from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels @@ -272,9 +271,6 @@ def store_epoch_metrics(metrics: DictStrFloat, hue_suffix = "/" + tokens[1] else: raise ValueError(f"Expected key to have format 'metric_name[/optional_suffix_for_hue]', got {key}") - - if metric_name == MetricType.SECONDS_PER_BATCH.value or metric_name == MetricType.SECONDS_PER_EPOCH.value: - continue if metric_name in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys(): logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[metric_name].value + hue_suffix] = value else: diff --git a/InnerEye/ML/metrics_dict.py b/InnerEye/ML/metrics_dict.py index fd2dc8fad..405fbc600 100644 --- a/InnerEye/ML/metrics_dict.py +++ b/InnerEye/ML/metrics_dict.py @@ -818,8 +818,6 @@ def flush(self, log_info: bool = False) -> None: df = pd.DataFrame.from_records(self.records, columns=columns) special_formatting = { MetricType.LEARNING_RATE.value: ".6e", - MetricType.SECONDS_PER_EPOCH.value: ".2f", - MetricType.SECONDS_PER_BATCH.value: ".2f", } for column, column_format in special_formatting.items(): if column in df: diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index ec920c730..efd7cf164 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -6,10 +6,9 @@ import os import sys from pathlib import Path -from typing import Any, Dict, Optional, Tuple, TypeVar +from typing import Any, Dict, List, Optional, Tuple, TypeVar -from health_azure.utils import is_global_rank_zero, is_local_rank_zero -from pytorch_lightning import LightningModule, Trainer, seed_everything +from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.plugins import DDPPlugin @@ -25,6 +24,7 @@ from InnerEye.ML.lightning_loggers import StoringLogger from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \ get_subject_output_file_per_rank +from health_azure.utils import is_global_rank_zero, is_local_rank_zero from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback TEMP_PREFIX = "temp/" @@ -129,7 +129,7 @@ def create_lightning_trainer(container: LightningContainer, # Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last # recovery_checkpoints_save_last_k. recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container) - callbacks = [ + callbacks: List[Callback] = [ last_checkpoint_callback, recovery_checkpoint_callback, ] diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index 1028d3ecf..250cfe661 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -43,6 +43,7 @@ FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint from InnerEye.ML.lightning_base import InnerEyeContainer from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer +from InnerEye.ML.lightning_loggers import StoringLogger from InnerEye.ML.metrics import InferenceMetrics, InferenceMetricsForSegmentation from InnerEye.ML.model_config_base import ModelConfigBase from InnerEye.ML.model_inference_config import ModelInferenceConfig diff --git a/Tests/Common/test_metrics_dict.py b/Tests/Common/test_metrics_dict.py index 98e73f0e1..049ad1681 100644 --- a/Tests/Common/test_metrics_dict.py +++ b/Tests/Common/test_metrics_dict.py @@ -531,8 +531,8 @@ def test_add_foreground_dice() -> None: def test_dataframe_logger() -> None: fixed_columns = {"cross_validation_split_index": 1} records = [ - {"bar": math.pi, MetricType.LEARNING_RATE.value: 1e-5, MetricType.SECONDS_PER_EPOCH.value: 123.123456}, - {"bar": math.pi, MetricType.LEARNING_RATE.value: 1, MetricType.SECONDS_PER_EPOCH.value: 123.123456}, + {"bar": math.pi, MetricType.LEARNING_RATE.value: 1e-5}, + {"bar": math.pi, MetricType.LEARNING_RATE.value: 1}, ] out_buffer = StringIO() df = DataframeLogger(csv_path=out_buffer, fixed_columns=fixed_columns) @@ -540,6 +540,6 @@ def test_dataframe_logger() -> None: df.add_record(r) df.flush() assert out_buffer.getvalue().splitlines() == [ - 'bar,LearningRate,SecondsPerEpoch,cross_validation_split_index', - '3.141593,1.000000e-05,123.12,1', - '3.141593,1.000000e+00,123.12,1'] + 'bar,LearningRate,cross_validation_split_index', + '3.141593,1.000000e-05,1', + '3.141593,1.000000e+00,1'] diff --git a/Tests/ML/models/test_scalar_model.py b/Tests/ML/models/test_scalar_model.py index a39bf399d..8b2d0c8f1 100644 --- a/Tests/ML/models/test_scalar_model.py +++ b/Tests/ML/models/test_scalar_model.py @@ -69,8 +69,6 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol val_results_per_epoch = model_training_result.val_results_per_epoch() assert len(train_results_per_epoch) == config.num_epochs assert len(val_results_per_epoch) == config.num_epochs - assert len(train_results_per_epoch[0]) >= 11 - assert len(val_results_per_epoch[0]) >= 11 for metric in [MetricType.ACCURACY_AT_THRESHOLD_05, MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD, @@ -78,8 +76,6 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol MetricType.AREA_UNDER_ROC_CURVE, MetricType.CROSS_ENTROPY, MetricType.LOSS, - MetricType.SECONDS_PER_BATCH, - MetricType.SECONDS_PER_EPOCH, MetricType.SUBJECT_COUNT]: assert metric.value in train_results_per_epoch[0], f"{metric.value} not in training" assert metric.value in val_results_per_epoch[0], f"{metric.value} not in validation" @@ -193,7 +189,6 @@ def test_train_classification_multilabel_model(test_output_dirs: OutputFolderFor assert f'{metric.value}/{class_name}' in train_results_per_epoch[0], f"{metric.value} not in training" assert f'{metric.value}/{class_name}' in val_results_per_epoch[0], f"{metric.value} not in validation" for metric in [MetricType.LOSS, - MetricType.SECONDS_PER_EPOCH, MetricType.SUBJECT_COUNT]: assert metric.value in train_results_per_epoch[0], f"{metric.value} not in training" assert metric.value in val_results_per_epoch[0], f"{metric.value} not in validation" diff --git a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index c2f0db7fe..c6423da9c 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -2,17 +2,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -import h5py import logging -import numpy as np import os -import pandas as pd -import pytest import shutil from pathlib import Path -from torch.utils.data import DataLoader from typing import Any, Dict, List +import h5py +import numpy as np +import pandas as pd +import pytest +from torch.utils.data import DataLoader + from InnerEye.Common import fixed_paths from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, is_windows, logging_to_stdout from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path @@ -113,15 +114,20 @@ def _mean_list(lists: List[List[float]]) -> List[float]: model_training_result, _ = model_train_unittest(train_config, dirs=output_dirs) assert isinstance(model_training_result, StoringLogger) - for epoch, epoch_results in model_training_result.results.items(): + # Check that all metrics from the BatchTimeCallback are present + for epoch, epoch_results in model_training_result.results_per_epoch.items(): for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]: - for metric_type in [MetricType.SECONDS_PER_EPOCH.value, - MetricType.SECONDS_PER_BATCH.value, - MetricType.EXCESS_BATCH_LOADING_TIME.value, - MetricType.SECONDS_PER_BATCH.value + " max"]: - expected = "timing/" + prefix + metric_type + for metric_type in [BatchTimeCallback.EPOCH_TIME, + BatchTimeCallback.BATCH_TIME + " avg", + BatchTimeCallback.BATCH_TIME + " max", + BatchTimeCallback.EXCESS_LOADING_TIME]: + expected = BatchTimeCallback.METRICS_PREFIX + prefix + metric_type assert expected in epoch_results, f"Expected {expected} in results for epoch {epoch}" - assert epoch_results[expected] > 0.0, "Time should be > 0" + # Excess loading time can be zero because that only measure batches over the threshold + if metric_type != BatchTimeCallback.EXCESS_LOADING_TIME: + value = epoch_results[expected] + assert isinstance(value, float) + assert value > 0.0, f"Time for {expected} should be > 0" actual_train_losses = model_training_result.get_train_metric(MetricType.LOSS.value) actual_val_losses = model_training_result.get_val_metric(MetricType.LOSS.value) @@ -200,12 +206,6 @@ def assert_all_close(metric: str, expected: List[float], **kwargs: Any) -> None: assert train_config.show_patch_sampling > 0 assert len(list(sampling_folder.rglob("*.png"))) == 3 * train_config.show_patch_sampling - # Time per epoch: Test that we have all these times logged. - model_training_result.get_train_metric(MetricType.SECONDS_PER_EPOCH.value) - model_training_result.get_val_metric(MetricType.SECONDS_PER_EPOCH.value) - model_training_result.get_val_metric(MetricType.SECONDS_PER_BATCH.value) - model_training_result.get_train_metric(MetricType.SECONDS_PER_BATCH.value) - # # Test for saving of example images assert train_config.example_images_folder.is_dir() if train_config.store_dataset_sample else True example_files = list(train_config.example_images_folder.rglob("*.*")) From 80805446e4db119e7694060b1a034b28571f40a0 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 13:57:10 +0000 Subject: [PATCH 28/41] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f62760fc0..d6aab2fe2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,8 @@ in inference-only runs when using lightning containers. ### Removed +- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Removing the monitoring of batch loading time, + use the `BatchTimeCallback` from `hi-ml` instead - ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline. - ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameters `local_weights_path` and `weights_url` can no longer be used to initialize a training run, only inference runs. From 8cbafe7943a1a5ca74b1533392b28fc8e58467ba Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 14:03:00 +0000 Subject: [PATCH 29/41] removing submodule --- .gitmodules | 3 --- docs/environment.md | 2 +- environment.yml | 2 ++ hi-ml | 1 - 4 files changed, 3 insertions(+), 5 deletions(-) delete mode 160000 hi-ml diff --git a/.gitmodules b/.gitmodules index 623bd23c7..a2a6b1f53 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ [submodule "fastMRI"] path = fastMRI url = https://github.com/facebookresearch/fastMRI -[submodule "hi-ml"] - path = hi-ml - url = https://github.com/microsoft/hi-ml diff --git a/docs/environment.md b/docs/environment.md index d12d94288..36d01666e 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -207,7 +207,7 @@ Once you are done testing your changes: * Execute these steps from the repository root: ```shell git submodule deinit -f hi-ml -rmdir hi-ml +rm -rf hi-ml rm -rf .git/modules/hi-ml ``` diff --git a/environment.yml b/environment.yml index a312455a0..0f831cfa6 100644 --- a/environment.yml +++ b/environment.yml @@ -23,6 +23,8 @@ dependencies: - gitpython==3.1.7 - gputil==1.4.0 - h5py==2.10.0 + - hi-ml==0.1.10 + - hi-ml-azure==0.1.10 - InnerEye-DICOM-RT==1.0.1 - joblib==0.16.0 - jupyter==1.0.0 diff --git a/hi-ml b/hi-ml deleted file mode 160000 index e89e7d768..000000000 --- a/hi-ml +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e89e7d768b2b5149eb1b898ae8ee5ada01184d46 From 70771d3e9581e201bdc2b35139860b5a644d4ab6 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 14:11:16 +0000 Subject: [PATCH 30/41] fix import --- InnerEye/Common/metrics_constants.py | 1 - Tests/ML/test_model_training.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/InnerEye/Common/metrics_constants.py b/InnerEye/Common/metrics_constants.py index 75d2ed366..f484fb8e6 100644 --- a/InnerEye/Common/metrics_constants.py +++ b/InnerEye/Common/metrics_constants.py @@ -3,7 +3,6 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ from enum import Enum, unique - # String prefixes when writing training or validation set metrics to a logger from typing import Union diff --git a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index c6423da9c..995e83ddb 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd import pytest +from health_ml.utils import BatchTimeCallback from torch.utils.data import DataLoader from InnerEye.Common import fixed_paths From c964d8461bf48d770fb6b85aea70764d659aa4ec Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 14:15:23 +0000 Subject: [PATCH 31/41] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d6aab2fe2..de3851de4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ created. ## Upcoming ### Added +- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Commandline switch `monitor_gpu` to monitor + GPU utilization via Lightning's `GpuStatsMonitor`, switch `monitor_loading` to check batch loading times via + `BatchTimeCallback`, and `pl_profiler` to turn on the Lightning profiler (`simple`, `advanced`, or `pytorch`) - ([#544](https://github.com/microsoft/InnerEye-DeepLearning/pull/544)) Add documentation for segmentation model evaluation. - ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference module on test data with partial ground truth files. (Also [522](https://github.com/microsoft/InnerEye-DeepLearning/pull/522).) From a432fe22aaf57dcffb8c9e19aacd6f7e6df0c995 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 14:36:17 +0000 Subject: [PATCH 32/41] flake fix --- InnerEye/Common/metrics_constants.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/InnerEye/Common/metrics_constants.py b/InnerEye/Common/metrics_constants.py index f484fb8e6..d70a3aac1 100644 --- a/InnerEye/Common/metrics_constants.py +++ b/InnerEye/Common/metrics_constants.py @@ -6,8 +6,6 @@ # String prefixes when writing training or validation set metrics to a logger from typing import Union -from health_ml.utils import BatchTimeCallback - TRAIN_PREFIX = "train/" VALIDATION_PREFIX = "val/" From d6053358683bb767619041bc72ce47a260837883 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 3 Nov 2021 15:49:10 +0000 Subject: [PATCH 33/41] fixed logging --- .../SSL/lightning_modules/byol/byol_module.py | 20 ++-- .../ML/SSL/lightning_modules/simclr_module.py | 15 ++- .../ssl_classifier_module.py | 9 +- .../lightning_modules/ssl_online_evaluator.py | 42 ++++---- InnerEye/ML/SSL/utils.py | 8 +- InnerEye/ML/model_training.py | 7 +- Tests/SSL/test_ssl_containers.py | 99 ++++++++++++++++++- docs/WSL.md | 48 ++++++--- 8 files changed, 186 insertions(+), 62 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py b/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py index 603f45063..89b6182ad 100644 --- a/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py +++ b/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py @@ -9,14 +9,15 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F +from health_ml.utils import log_learning_rate, log_on_epoch from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from pytorch_lightning import Trainer from torch import Tensor as T from torch.optim import Adam from InnerEye.ML.SSL.lightning_modules.byol.byol_models import SiameseArm from InnerEye.ML.SSL.lightning_modules.byol.byol_moving_average import ByolMovingAverageWeightUpdate from InnerEye.ML.SSL.utils import SSLDataModuleType -from pytorch_lightning import Trainer SingleBatchType = Tuple[List, T] BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType] @@ -98,14 +99,15 @@ def shared_step(self, batch: BatchType, batch_idx: int) -> T: return loss - def training_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> T: # type: ignore + def training_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> torch.Tensor: # type: ignore loss = self.shared_step(batch, batch_idx) - self.log_dict({'byol/train/loss': loss, 'byol/tau': self.weight_callback.current_tau}) + log_on_epoch(self, metrics={'byol/train/loss': loss, 'byol/tau': self.weight_callback.current_tau}) + log_learning_rate(self, name="byol/learning_rate") return loss def validation_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> T: # type: ignore loss = self.shared_step(batch, batch_idx) - self.log_dict({'byol/val/loss': loss}) + log_on_epoch(self, 'byol/val/loss', loss) return loss def setup(self, *args: Any, **kwargs: Any) -> None: @@ -116,9 +118,12 @@ def configure_optimizers(self) -> Any: # exclude certain parameters parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(), weight_decay=self.hparams.weight_decay) # type: ignore - optimizer = Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore - scheduler = LinearWarmupCosineAnnealingLR( - optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs) # type: ignore + optimizer = Adam(parameters, + lr=self.hparams.learning_rate, # type: ignore + weight_decay=self.hparams.weight_decay) # type: ignore + scheduler = LinearWarmupCosineAnnealingLR(optimizer, + warmup_epochs=self.hparams.warmup_epochs, + max_epochs=self.hparams.max_epochs) # type: ignore return [optimizer], [scheduler] def exclude_from_wt_decay(self, @@ -144,4 +149,3 @@ def exclude_from_wt_decay(self, {'params': params, 'weight_decay': weight_decay}, {'params': excluded_params, 'weight_decay': 0.} ] - diff --git a/InnerEye/ML/SSL/lightning_modules/simclr_module.py b/InnerEye/ML/SSL/lightning_modules/simclr_module.py index f53446a9b..1a305951e 100644 --- a/InnerEye/ML/SSL/lightning_modules/simclr_module.py +++ b/InnerEye/ML/SSL/lightning_modules/simclr_module.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from health_ml.utils import log_learning_rate, log_on_epoch from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR from torch import Tensor as T @@ -57,6 +58,17 @@ def __init__(self, encoder_name: str, dataset_name: str, use_7x7_first_conv_in_r def forward(self, x: torch.Tensor) -> torch.Tensor: return self.encoder(x) + def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: + loss = self.shared_step(batch) + log_on_epoch(self, "simclr/train/loss", loss, sync_dist=False) + log_learning_rate(self, name="simclr/learning_rate") + return loss + + def validation_step(self, batch: BatchType, batch_idx: int) -> T: # type: ignore + loss = self.shared_step(batch) + log_on_epoch(self, "simclr/val/loss", loss, sync_dist=False) + return loss + def shared_step(self, batch: BatchType) -> T: batch = batch[SSLDataModuleType.ENCODER] if isinstance(batch, dict) else batch @@ -72,6 +84,3 @@ def shared_step(self, batch: BatchType) -> T: loss = self.nt_xent_loss(z1, z2, self.temperature) return loss - - - diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py b/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py index 805366f22..8514183ce 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py @@ -7,6 +7,7 @@ import torch from torchmetrics import Metric from pl_bolts.models.self_supervised import SSLEvaluator +from health_ml.utils import log_on_epoch from torch.nn import functional as F from InnerEye.ML.SSL.encoders import get_encoder_output_dim @@ -79,16 +80,16 @@ def shared_step(self, batch: Any, is_training: bool) -> Any: def training_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> Any: # type: ignore loss = self.shared_step(batch, True) - self.log("train/loss", loss, on_step=False, on_epoch=True) + log_on_epoch(self, "train/loss", loss) for metric in self.train_metrics: - self.log(f"train/{metric.name}", metric, on_epoch=True, on_step=False) + log_on_epoch(self, f"train/{metric.name}", metric) return loss def validation_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> None: # type: ignore loss = self.shared_step(batch, is_training=False) - self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=True) + log_on_epoch(self, 'val/loss', loss) for metric in self.val_metrics: - self.log(f"val/{metric.name}", metric, on_epoch=True, on_step=False) + log_on_epoch(self, f"val/{metric.name}", metric) def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]: """ diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 00f89b792..8afa9fc37 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -10,6 +10,7 @@ from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.models.self_supervised.evaluator import SSLEvaluator from torch import Tensor as T +from health_ml.utils import log_on_epoch from torch.nn import functional as F from torchmetrics import Metric @@ -46,6 +47,13 @@ def __init__(self, Accuracy05()] \ if self.num_classes == 2 else [Accuracy05()] self.class_weights = class_weights + self.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim, + n_classes=self.num_classes, + p=self.drop_p, + n_hidden=self.hidden_dim) + self.optimizer = torch.optim.Adam(self.non_linear_evaluator.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay) def on_save_checkpoint(self, trainer: pl.Trainer, @@ -54,7 +62,7 @@ def on_save_checkpoint(self, # Each callback gets its own state dictionary, that are fed back in during load return { OPTIMIZER_STATE_NAME: self.optimizer.state_dict(), - EVALUATOR_STATE_NAME: pl_module.non_linear_evaluator.state_dict() + EVALUATOR_STATE_NAME: self.non_linear_evaluator.state_dict() } def on_load_checkpoint(self, @@ -62,7 +70,7 @@ def on_load_checkpoint(self, pl_module: pl.LightningModule, callback_state: Dict[str, Any]) -> None: self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME]) - pl_module.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME]) + self.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME]) def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ @@ -70,15 +78,7 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning """ for metric in [*self.train_metrics, *self.val_metrics]: metric.to(device=pl_module.device) # type: ignore - - pl_module.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim, - n_classes=self.num_classes, - p=self.drop_p, - n_hidden=self.hidden_dim).to(pl_module.device) - assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module) - self.optimizer = torch.optim.Adam(pl_module.non_linear_evaluator.parameters(), - lr=self.learning_rate, - weight_decay=self.weight_decay) + self.non_linear_evaluator = self.non_linear_evaluator.to(pl_module.device) @staticmethod def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]: @@ -106,10 +106,9 @@ def shared_step(self, batch: BatchType, pl_module: pl.LightningModule, is_traini with torch.no_grad(): representations = self.get_representations(pl_module, x) representations = representations.detach() - assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module) # Run the linear-head with SSL embeddings. - mlp_preds = pl_module.non_linear_evaluator(representations) + mlp_preds = self.non_linear_evaluator(representations) weights = None if self.class_weights is None else self.class_weights.to(device=pl_module.device) mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights) @@ -132,26 +131,29 @@ def on_validation_batch_end(self, trainer: pl.Trainer, ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist()) if ids_linear_head not in self.visited_ids: self.visited_ids.add(ids_linear_head) + old_mode = self.non_linear_evaluator.training + self.non_linear_evaluator.eval() loss = self.shared_step(batch, pl_module, is_training=False) - pl_module.log('ssl_online_evaluator/val/loss', loss, on_step=False, on_epoch=True, sync_dist=False) + log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss) for metric in self.val_metrics: - pl_module.log(f"ssl_online_evaluator/val/{metric.name}", metric, on_epoch=True, - on_step=False) # type: ignore + log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric) + self.non_linear_evaluator.train(old_mode) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore """ Get and log training metrics, perform network update. """ + # Similar code should also live in the encoder training. + # There is a silent assumption here that SSL data is larger than linear head data ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist()) if ids_linear_head not in self.visited_ids: self.visited_ids.add(ids_linear_head) loss = self.shared_step(batch, pl_module, is_training=True) self.optimizer.zero_grad() - loss.backward() + loss.backward(loss) self.optimizer.step() # log metrics - pl_module.log('ssl_online_evaluator/train/loss', loss) + log_on_epoch(pl_module, 'ssl_online_evaluator/train/loss', loss) for metric in self.train_metrics: - pl_module.log(f"ssl_online_evaluator/train/online_{metric.name}", metric, on_epoch=True, - on_step=False) # type: ignore + log_on_epoch(pl_module, f"ssl_online_evaluator/train/online_{metric.name}", metric) diff --git a/InnerEye/ML/SSL/utils.py b/InnerEye/ML/SSL/utils.py index 8cc757abb..0fa62f05b 100644 --- a/InnerEye/ML/SSL/utils.py +++ b/InnerEye/ML/SSL/utils.py @@ -81,14 +81,10 @@ def create_ssl_image_classifier(num_classes: int, logging.info(f"Loading pretrained {ssl_type} weights from:\n {pl_checkpoint_path}") if ssl_type == SSLTrainingType.BYOL.value or ssl_type == SSLTrainingType.BYOL: - # Here we need to indicate how many classes where used for linear evaluator at training time, to load the - # checkpoint (incl. linear evaluator) with strict = True - byol_module = SSLModelLoader(BYOLInnerEye, loaded_params["num_classes"]).load_from_checkpoint( - pl_checkpoint_path) + byol_module = BYOLInnerEye.load_from_checkpoint(pl_checkpoint_path) encoder = byol_module.target_network.encoder elif ssl_type == SSLTrainingType.SimCLR.value or ssl_type == SSLTrainingType.SimCLR: - simclr_module = SSLModelLoader(SimCLRInnerEye, loaded_params["num_classes"]).load_from_checkpoint( - pl_checkpoint_path) + simclr_module = SimCLRInnerEye.load_from_checkpoint(pl_checkpoint_path) encoder = simclr_module.encoder else: raise NotImplementedError(f"Unknown unsupervised model: {ssl_type}") diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index dea99f984..bc5297bbb 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, TypeVar +from health_azure.utils import is_global_rank_zero, is_local_rank_zero +from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback, log_on_epoch from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger @@ -24,8 +26,6 @@ from InnerEye.ML.lightning_loggers import StoringLogger from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \ get_subject_output_file_per_rank -from health_azure.utils import is_global_rank_zero, is_local_rank_zero -from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback TEMP_PREFIX = "temp/" @@ -72,7 +72,8 @@ def __init__(self, container: LightningContainer): save_last=False) def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unused: bool = None) -> None: - pl_module.log(name="epoch_started", value=trainer.current_epoch) # type: ignore + # The metric to monitor must be logged on all ranks in distributed training + log_on_epoch(pl_module, name="epoch_started", value=trainer.current_epoch, sync_dist=False) # type: ignore def create_lightning_trainer(container: LightningContainer, diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 2f455ae38..d56033a9e 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import math from pathlib import Path from typing import Dict from unittest import mock @@ -11,6 +12,7 @@ import pytest import torch from pl_bolts.models.self_supervised.resnets import ResNet +from torch.optim.lr_scheduler import _LRScheduler from InnerEye.Common import fixed_paths from InnerEye.Common.common_util import is_windows @@ -62,6 +64,28 @@ def default_runner() -> Runner: "--num_workers=0"] +def _compare_stored_metrics(runner: Runner, expected_metrics: Dict[str, float], abs: float = 1e-5) -> None: + """ + Checks if the StoringLogger in the given runner holds all the expected metrics as results of training + epoch 0, up to a given absolute precision. + :param runner: The Innereye runner. + :param expected_metrics: A dictionary with all metrics that are expected to be present. + """ + assert runner.ml_runner is not None + assert runner.ml_runner.storing_logger is not None + print(f"Actual metrics in epoch 0: {runner.ml_runner.storing_logger.results_per_epoch[0]}") + print(f"Expected metrics: {expected_metrics}") + for metric, expected in expected_metrics.items(): + actual = runner.ml_runner.storing_logger.results_per_epoch[0][metric] + if isinstance(actual, float): + if math.isnan(expected): + assert math.isnan(actual), f"Metric {metric}: Expected NaN, but got: {actual}" + else: + assert actual == pytest.approx(expected, abs=abs), f"Mismatch for metric {metric}" + else: + assert actual == expected, f"Mismatch for metric {metric}" + + @pytest.mark.skipif(is_windows(), reason="Too slow on windows") def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: """ @@ -72,8 +96,9 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: - training of image classifier for one epoch. """ args = common_test_args + ["--model=CIFAR10SimCLR"] + runner = default_runner() with mock.patch("sys.argv", args): - loaded_config, actual_run = default_runner().run() + loaded_config, actual_run = runner.run() assert loaded_config is not None assert isinstance(loaded_config.model, SimCLRInnerEye) assert loaded_config.encoder_output_dim == 2048 @@ -82,10 +107,24 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: assert loaded_config.recovery_checkpoint_save_interval == 200 assert loaded_config.ssl_training_type == SSLTrainingType.SimCLR assert loaded_config.online_eval.num_classes == 10 - assert loaded_config.ssl_training_dataset_name == SSLDatasetName.CIFAR10 assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10.value + assert loaded_config.ssl_training_dataset_name == SSLDatasetName.CIFAR10 assert not loaded_config.use_balanced_binary_loss_for_linear_head assert isinstance(loaded_config.model.encoder.cnn_model, ResNet) + + # Check the metrics that were recorded during training + expected_metrics = { + 'simclr/train/loss': 3.423144578933716, + 'simclr/learning_rate': 0.0, + 'ssl_online_evaluator/train/loss': 2.6143882274627686, + 'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.0, + 'epoch_started': 0.0, + 'simclr/val/loss': 2.886892795562744, + 'ssl_online_evaluator/val/loss': 2.2472469806671143, + 'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.20000000298023224 + } + _compare_stored_metrics(runner, expected_metrics) + # Check that the checkpoint contains both the optimizer for the embedding and for the linear head checkpoint_path = loaded_config.outputs_folder / "checkpoints" / "best_checkpoint.ckpt" checkpoint = torch.load(checkpoint_path) @@ -160,6 +199,24 @@ def test_innereye_ssl_container_rsna() -> None: assert loaded_config.datamodule_args[SSLDataModuleType.ENCODER].augmentation_params.augmentation.use_random_crop assert loaded_config.datamodule_args[SSLDataModuleType.ENCODER].augmentation_params.augmentation.use_random_affine + expected_metrics = { + 'byol/train/loss': 0.00401744619011879, + 'byol/tau': 0.9899999499320984, + 'byol/learning_rate/0/0': 0.0, + 'byol/learning_rate/0/1': 0.0, + 'ssl_online_evaluator/train/loss': 0.685592532157898, + 'ssl_online_evaluator/train/online_AreaUnderRocCurve': 0.5, + 'ssl_online_evaluator/train/online_AreaUnderPRCurve': 0.699999988079071, + 'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.4000000059604645, + 'epoch_started': 0.0, + 'byol/val/loss': -0.07644838094711304, + 'ssl_online_evaluator/val/loss': 0.6965796947479248, + 'ssl_online_evaluator/val/AreaUnderRocCurve': math.nan, + 'ssl_online_evaluator/val/AreaUnderPRCurve': math.nan, + 'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.0 + } + _compare_stored_metrics(runner, expected_metrics) + # Check that we are able to load the checkpoint and create classifier model checkpoint_path = loaded_config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX args = common_test_args + ["--model=CXRImageClassifier", @@ -173,3 +230,41 @@ def test_innereye_ssl_container_rsna() -> None: assert loaded_config.model.freeze_encoder assert torch.isclose(loaded_config.model.class_weights, torch.tensor([0.21, 0.79]), atol=1e-6).all() # type: ignore assert loaded_config.model.num_classes == 2 + + +def test_simclr_lr_scheduler() -> None: + """ + Test if the LR scheduler has the expected warmup behaviour. + """ + num_samples = 100 + batch_size = 20 + gpus = 1 + max_epochs = 10 + warmup_epochs = 2 + model = SimCLRInnerEye(encoder_name="resnet18", dataset_name="CIFAR10", + gpus=gpus, num_samples=num_samples, batch_size=batch_size, + max_epochs=max_epochs, warmup_epochs=warmup_epochs) + # The LR scheduler used here works per step. Scheduler computes the total number of steps, in this example that's 5 + train_iters_per_epoch = num_samples / (batch_size * gpus) + assert model.train_iters_per_epoch == train_iters_per_epoch + # Mock a second optimizer that is normally created in the SSL container + linear_head_optimizer = mock.MagicMock() + model.online_eval_optimizer = linear_head_optimizer + # Retrieve the scheduler and iterate it + _, scheduler_list = model.configure_optimizers() + assert isinstance(scheduler_list[0], dict) + assert scheduler_list[0]["interval"] == "step" + scheduler = scheduler_list[0]["scheduler"] + assert isinstance(scheduler, _LRScheduler) + lr = [] + for i in range(0, int(max_epochs * train_iters_per_epoch)): + scheduler.step() + lr.append(scheduler.get_last_lr()[0]) + # The highest learning rate is expected after the warmup epochs + highest_lr = np.argmax(lr) + assert highest_lr == int(warmup_epochs * train_iters_per_epoch - 1) + + for i in range(0, highest_lr): + assert lr[i] < lr[i + 1], f"Not strictly monotonically increasing at index {i}" + for i in range(highest_lr, len(lr) - 1): + assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}" diff --git a/docs/WSL.md b/docs/WSL.md index c0b581ec6..6d8d93525 100644 --- a/docs/WSL.md +++ b/docs/WSL.md @@ -1,18 +1,21 @@ # How to use the Windows Subsystem for Linux (WSL2) for development We are aware of two issues with running our toolbox on Windows: -- Conda and miniconda can be rather temperamental: Environment creation can fail with package conflict errors -of unclear origin, or internal conda errors. + +- Conda and miniconda can be rather temperamental: Environment creation can fail with package conflict errors of unclear + origin, or internal conda errors. - Some features of PyTorch are not supported, or not well supported, on Windows. -If you are facing issue of the above kind on a Windows machine, we would highly recommend working with the -Windows Subsystem for Linux (WSL2) or a plain Ubuntu Linux box. +If you are facing issue of the above kind on a Windows machine, we would highly recommend working with the Windows +Subsystem for Linux (WSL2) or a plain Ubuntu Linux box. ## Enable CUDA in WSL2 -If you are running a Windows box with a GPU, please follow the documentation + +If you are running a Windows box with a GPU, please follow the documentation [here](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) to access the GPU from within WSL2. -You can also find a video walkthrough of WSL2+CUDA installation here: https://channel9.msdn.com/Shows/Tabs-vs-Spaces/GPU-Accelerated-Machine-Learning-with-WSL-2 +You can also find a video walkthrough of WSL2+CUDA installation +here: https://channel9.msdn.com/Shows/Tabs-vs-Spaces/GPU-Accelerated-Machine-Learning-with-WSL-2 ## Install WSL2 @@ -21,15 +24,18 @@ Requirements: Windows 10 version 2004 or higher The instructions are [here](https://docs.microsoft.com/en-us/windows/wsl/install-win10), but summarized in copy/paste-able form below. When installing via the UI, pick Ubuntu version 20.04 LTS as your distribution. -To use the commandline setup, please first install +To use the commandline setup, please first install [winget via the appxbundle](https://github.com/microsoft/winget-cli/releases). Then, in PowerShell as Administrator: + ``` dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux /all /norestart dism.exe /online /enable-feature /featurename:VirtualMachinePlatform /all /norestart ``` + Restart your machine, then again in PowerShell: + ``` wsl --set-default-version 2 winget install ubuntu --version 20.04 @@ -39,23 +45,25 @@ wsl --set-default-version 2 winget install Microsoft.WindowsTerminal ``` -Remember to restart your machine if you were doing a fresh installation of WSL 2 before trying further steps. +Remember to restart your machine if you were doing a fresh installation of WSL 2 before trying further steps. -Since it is possible to choose the version of WSL that a particular distribution is running, -once you have WSL2 installed, ensure that your distribution is running on top of WSL2 by executing +Since it is possible to choose the version of WSL that a particular distribution is running, once you have WSL2 +installed, ensure that your distribution is running on top of WSL2 by executing `wsl --list --verbose` -If all is good, the output should look like this: +If all is good, the output should look like this: + ``` $> wsl --list -v NAME STATE VERSION * Ubuntu-20.04 Running 2 ``` -Note the "2" in Version column. +Note the "2" in Version column. ## Install git and Anaconda Start the Windows Terminal app, create an Ubuntu tab. In the shell, run the following commands: + - `sudo apt update` - `sudo apt install git git-lfs python-dev build-essential` - `wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh` @@ -64,6 +72,7 @@ Start the Windows Terminal app, create an Ubuntu tab. In the shell, run the foll - Clone repo or access your repos via /mnt/c/... - Create conda environment: `conda env create --file environment.yml` - Clean your pyc files (in case you have some left from Windows): + ``` find * -name '*.pyc' | xargs -d'\n' rm` ``` @@ -71,13 +80,20 @@ find * -name '*.pyc' | xargs -d'\n' rm` ## Configure PyCharm - https://www.jetbrains.com/help/pycharm/using-wsl-as-a-remote-interpreter.html -- You might need to reset all your firewall settings to make the debugger work with PyCharm. This can be done with these PowerShell commands (as Administrator): +- You might need to reset all your firewall settings to make the debugger work with PyCharm. This can be done with these + PowerShell commands (as Administrator): + ``` -Remove-NetFirewallRule -$myIp = (Ubuntu1804 run "cat /etc/resolv.conf | grep nameserver | cut -d' ' -f2") +$myIp = (Ubuntu2004 run "cat /etc/resolv.conf | grep nameserver | cut -d' ' -f2") New-NetFirewallRule -DisplayName "WSL" -Direction Inbound -LocalAddress $myIp -Action Allow ``` -- Then (re)start PyCharm. If asked whether to give it permission to communicate over domain, private and public networks, make sure all three are ticked. + +- Then (re)start PyCharm. If asked whether to give it permission to communicate over domain, private and public + networks, make sure all three are ticked. +- If you are still struggling with the firewall rules, consider removing all your current firewall rules, by running + `Remove-NetFirewallRule` in the PowerShell. WARNING: This will remove all your present firewall rules, and you may + need to repeat the firewall setup for other programs that you have installed! ## Configure VSCode + - https://code.visualstudio.com/docs/remote/wsl From 7f75ec3aa85c69a479acc1d046e80a68aab182ef Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Fri, 12 Nov 2021 16:47:02 +0000 Subject: [PATCH 34/41] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index de3851de4..4c0c8c696 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ gets uploaded to AzureML, by skipping all test folders. `ScalarModelBase`. - ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Updated report in CovidModel. Set parameters in the config to run inference on both the validation and test sets by default. +- ([#584](https://github.com/microsoft/InnerEye-DeepLearning/pull/584)) SSL models write the optimizer state for the linear head to the checkpoint now. - ([#566](https://github.com/microsoft/InnerEye-DeepLearning/pull/566)) Update `hi-ml` dependency to `hi-ml-azure`. - ([#572](https://github.com/microsoft/InnerEye-DeepLearning/pull/572)) Updated to new version of hi-ml package From c22362e711d6d9a5dcafad88107fb4d56022d682 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Fri, 12 Nov 2021 19:47:45 +0000 Subject: [PATCH 35/41] mypy --- InnerEye/ML/SSL/lightning_modules/byol/byol_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py b/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py index 89b6182ad..565e7c3c8 100644 --- a/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py +++ b/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py @@ -122,7 +122,7 @@ def configure_optimizers(self) -> Any: lr=self.hparams.learning_rate, # type: ignore weight_decay=self.hparams.weight_decay) # type: ignore scheduler = LinearWarmupCosineAnnealingLR(optimizer, - warmup_epochs=self.hparams.warmup_epochs, + warmup_epochs=self.hparams.warmup_epochs, # type: ignore max_epochs=self.hparams.max_epochs) # type: ignore return [optimizer], [scheduler] From 2da000b0e0f48a2e42b1e7ba94660bf61fd971a4 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Sat, 13 Nov 2021 04:53:42 +0000 Subject: [PATCH 36/41] test fix --- InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py | 4 ++-- Tests/SSL/test_ssl_containers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 8afa9fc37..3724ce773 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -78,7 +78,7 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning """ for metric in [*self.train_metrics, *self.val_metrics]: metric.to(device=pl_module.device) # type: ignore - self.non_linear_evaluator = self.non_linear_evaluator.to(pl_module.device) + self.non_linear_evaluator.to(pl_module.device) @staticmethod def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]: @@ -150,7 +150,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data self.visited_ids.add(ids_linear_head) loss = self.shared_step(batch, pl_module, is_training=True) self.optimizer.zero_grad() - loss.backward(loss) + loss.backward() self.optimizer.step() # log metrics diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index d56033a9e..0481df212 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -123,7 +123,7 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: 'ssl_online_evaluator/val/loss': 2.2472469806671143, 'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.20000000298023224 } - _compare_stored_metrics(runner, expected_metrics) + _compare_stored_metrics(runner, expected_metrics, abs=5e-5) # Check that the checkpoint contains both the optimizer for the embedding and for the linear head checkpoint_path = loaded_config.outputs_folder / "checkpoints" / "best_checkpoint.ckpt" From 588dd01aa607d5b688f5ddb02d0017b3b8e9f111 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 17 Nov 2021 20:22:40 +0000 Subject: [PATCH 37/41] PR comments --- .../ML/SSL/lightning_containers/ssl_image_classifier.py | 2 ++ InnerEye/ML/SSL/lightning_modules/simclr_module.py | 5 ++--- InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py | 4 ++++ InnerEye/ML/model_training.py | 4 ++-- docs/self_supervised_models.md | 6 ++++-- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py b/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py index e890b75ca..508266b25 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py @@ -64,4 +64,6 @@ def get_data_module(self) -> InnerEyeDataModuleTypes: return self.data_module def get_trainer_arguments(self) -> Dict[str, Any]: + # This class inherits from SSLContainer, where the get_trainer_arguments adds the online evaluator callback. + # We don't need that for the classifier, hence need to return an empty set of trainer arguments. return {} diff --git a/InnerEye/ML/SSL/lightning_modules/simclr_module.py b/InnerEye/ML/SSL/lightning_modules/simclr_module.py index 1a305951e..62b1ee8ec 100644 --- a/InnerEye/ML/SSL/lightning_modules/simclr_module.py +++ b/InnerEye/ML/SSL/lightning_modules/simclr_module.py @@ -10,7 +10,6 @@ import torch.nn.functional as F from health_ml.utils import log_learning_rate, log_on_epoch from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR -from torch import Tensor as T from InnerEye.ML.SSL.encoders import SSLEncoder from InnerEye.ML.SSL.utils import SSLDataModuleType @@ -64,12 +63,12 @@ def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: log_learning_rate(self, name="simclr/learning_rate") return loss - def validation_step(self, batch: BatchType, batch_idx: int) -> T: # type: ignore + def validation_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: # type: ignore loss = self.shared_step(batch) log_on_epoch(self, "simclr/val/loss", loss, sync_dist=False) return loss - def shared_step(self, batch: BatchType) -> T: + def shared_step(self, batch: BatchType) -> torch.Tensor: batch = batch[SSLDataModuleType.ENCODER] if isinstance(batch, dict) else batch (img1, img2), y = batch diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 3724ce773..16eb1161b 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -127,16 +127,20 @@ def on_validation_batch_end(self, trainer: pl.Trainer, dataloader_idx: int) -> None: # type: ignore """ Get and log validation metrics. + Metrics are computed only if the sample IDs in the batch have not yet been seen in this epoch (linear head + data may be repeated if the SSL data is longer than the linear head data). """ ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist()) if ids_linear_head not in self.visited_ids: self.visited_ids.add(ids_linear_head) + # Put the online evaluator into "eval" mode old_mode = self.non_linear_evaluator.training self.non_linear_evaluator.eval() loss = self.shared_step(batch, pl_module, is_training=False) log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss) for metric in self.val_metrics: log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric) + # Put the online evaluator back into the state (eval or train) that it was before calling this method self.non_linear_evaluator.train(old_mode) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index bc5297bbb..6115c4629 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -156,14 +156,14 @@ def create_lightning_trainer(container: LightningContainer, callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate, write_to_logging_info=True, print_timestamp=False)) - # Read out additional model-specific args here. - # We probably want to keep essential ones like numgpu and logging. trainer = Trainer(default_root_dir=str(container.outputs_folder), deterministic=deterministic, benchmark=benchmark, accelerator=accelerator, plugins=plugins, max_epochs=container.num_epochs, + # Both these arguments can be integers or floats. If integers, it is the number of batches. + # If float, it's the fraction of batches. We default to 1.0 (processing all batches). limit_train_batches=container.pl_limit_train_batches or 1.0, limit_val_batches=container.pl_limit_val_batches or 1.0, num_sanity_val_steps=container.pl_num_sanity_val_steps, diff --git a/docs/self_supervised_models.md b/docs/self_supervised_models.md index c93afee2b..dfbeba286 100644 --- a/docs/self_supervised_models.md +++ b/docs/self_supervised_models.md @@ -107,8 +107,10 @@ with the following available arguments: * `ssl_encoder`: name of the encoder to train, member of `EncoderName` class, currently supported are resnet50, resnet101 and densenet121, * `ssl_training_type`: which SSL algorithm to use, member of `SSLType` choice between BYOL and SimCLR, -* `ssl_training_batch_size`: batch size of SSL training -* `linear_head_batch_size`: batch size for linear head training (used for monitor of SSL embeddings quality) +* `ssl_training_batch_size`: batch size of SSL training. This is the number of examples processed by a single GPU. + Multiply this by the number of GPUs to get the effective batch size. +* `linear_head_batch_size`: batch size for linear head training (used for monitor of SSL embeddings quality). This is + the number of examples processed by a single GPU. Multiply this by the number of GPUs to get the effective batch size. * `ssl_augmentation_config`: path to yaml config for augmentation to use during SSL training. Only used for NIH/Kaggle datasets. * `linear_head_augmentation_config`: path to yaml config for augmentation to use for linear head training. Only used for From 44b8d0e80014c8502ee0bca3791afbfc30965177 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 17 Nov 2021 22:33:28 +0000 Subject: [PATCH 38/41] fix --- CHANGELOG.md | 1 + InnerEye/Azure/azure_runner.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac9f8f5a7..6ab371472 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,7 @@ gets uploaded to AzureML, by skipping all test folders. - ([#572](https://github.com/microsoft/InnerEye-DeepLearning/pull/572)) Updated to new version of hi-ml package ### Fixed +- ([#593](https://github.com/microsoft/InnerEye-DeepLearning/pull/593)) Bug fix for hi-ml 0.1.11 issue (#130): empty mount point is turned into ".", which fails the AML job - ([#587](https://github.com/microsoft/InnerEye-DeepLearning/pull/587)) Bug fix for regression in AzureML's handling of environments: upgrade to hi-ml 0.1.11 - ([#537](https://github.com/microsoft/InnerEye-DeepLearning/pull/537)) Print warning if inference is disabled but comparison requested. - ([#567](https://github.com/microsoft/InnerEye-DeepLearning/pull/567)) fix pillow version. diff --git a/InnerEye/Azure/azure_runner.py b/InnerEye/Azure/azure_runner.py index 62f7cb8aa..b0330c117 100644 --- a/InnerEye/Azure/azure_runner.py +++ b/InnerEye/Azure/azure_runner.py @@ -111,7 +111,9 @@ def create_dataset_configs(azure_config: AzureConfig, for i, (dataset_id, mount_point) in enumerate(zip(all_azure_dataset_ids, all_dataset_mountpoints)): if dataset_id: datasets.append(DatasetConfig(name=dataset_id, - target_folder=mount_point, + # Workaround for a bug in hi-ml 0.1.11: mount_point=="" creates invalid jobs, + # setting to None works. + target_folder=mount_point or None, use_mounting=azure_config.use_dataset_mount, datastore=azure_config.azureml_datastore)) elif mount_point: From ea77b470cacdbd1c3fb16d4a9cf2c87f713adedb Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 17 Nov 2021 23:25:15 +0000 Subject: [PATCH 39/41] fix --- InnerEye/Common/fixed_paths.py | 1 - .../ML/visualizers/plot_cross_validation.py | 1 - Tests/AfterTraining/test_after_training.py | 18 ++++++++++++------ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/InnerEye/Common/fixed_paths.py b/InnerEye/Common/fixed_paths.py index 34f543343..e21ea1400 100755 --- a/InnerEye/Common/fixed_paths.py +++ b/InnerEye/Common/fixed_paths.py @@ -33,7 +33,6 @@ def repository_root_directory(path: Optional[PathOrString] = None) -> Path: DEFAULT_RESULT_IMAGE_NAME = "segmentation.nii.gz" # Default filename if scoring produces a zipped DICOM-RT file. DEFAULT_RESULT_ZIP_DICOM_NAME = "segmentation.dcm.zip" -DEFAULT_AML_LOGS_DIR = "azureml-logs" DEFAULT_LOGS_DIR_NAME = "logs" diff --git a/InnerEye/ML/visualizers/plot_cross_validation.py b/InnerEye/ML/visualizers/plot_cross_validation.py index 932ed33c5..b1a2324b6 100644 --- a/InnerEye/ML/visualizers/plot_cross_validation.py +++ b/InnerEye/ML/visualizers/plot_cross_validation.py @@ -49,7 +49,6 @@ RUN_DICTIONARY_NAME = "RunDictionary.txt" MAX_STRUCTURES_PER_PLOT = 7 -DRIVER_LOG_BASENAME = "70_driver_log.txt" RUN_RECOVERY_ID_KEY = 'run_recovery_id' WILCOXON_RESULTS_FILE = "CrossValidationWilcoxonSignedRankTestResults.txt" MANN_WHITNEY_RESULTS_FILE = "CrossValidationMannWhitneyTestResults.txt" diff --git a/Tests/AfterTraining/test_after_training.py b/Tests/AfterTraining/test_after_training.py index 1a347291d..b1e00f81a 100644 --- a/Tests/AfterTraining/test_after_training.py +++ b/Tests/AfterTraining/test_after_training.py @@ -29,7 +29,7 @@ from InnerEye.Common import common_util, fixed_paths, fixed_paths_for_tests from InnerEye.Common.common_util import BEST_EPOCH_FOLDER_NAME, CROSSVAL_RESULTS_FOLDER, ENSEMBLE_SPLIT_NAME, \ get_best_epoch_results_path -from InnerEye.Common.fixed_paths import (DEFAULT_AML_LOGS_DIR, DEFAULT_RESULT_IMAGE_NAME, DEFAULT_RESULT_ZIP_DICOM_NAME, +from InnerEye.Common.fixed_paths import (DEFAULT_RESULT_IMAGE_NAME, DEFAULT_RESULT_ZIP_DICOM_NAME, PYTHON_ENVIRONMENT_NAME, repository_root_directory) from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path from InnerEye.Common.output_directories import OutputFolderForTests @@ -50,7 +50,7 @@ from InnerEye.Scripts import submit_for_inference from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_default_workspace, get_nifti_shape -FALLBACK_SINGLE_RUN = "refs_pull_545_merge:refs_pull_545_merge_1626538212_d2b07afd" +FALLBACK_SINGLE_RUN = "refs_pull_593_merge_1637188926_7ba554ba" FALLBACK_ENSEMBLE_RUN = "refs_pull_545_merge:HD_caea82ae-9603-48ba-8280-7d2bc6272411" FALLBACK_2NODE_RUN = "refs_pull_545_merge:refs_pull_545_merge_1626538178_9f3023b2" FALLBACK_CV_GLAUCOMA = "refs_pull_545_merge:HD_72ecc647-07c3-4353-a538-620346114ebd" @@ -200,10 +200,16 @@ def test_check_dataset_mountpoint(test_output_dirs: OutputFolderForTests) -> Non """ run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN) files = run.get_file_names() - driver_log = f"{DEFAULT_AML_LOGS_DIR}/70_driver_log.txt" - assert driver_log in files - downloaded = test_output_dirs.root_dir / "70_driver_log.txt" - run.download_file(driver_log, output_file_path=str(downloaded)) + + # Account for old and new job runtime: log files live in different places + driver_log_files = ["azureml-logs/70_driver_log.txt", "user_logs/std_log.txt"] + downloaded = test_output_dirs.root_dir / "driver_log.txt" + for f in driver_log_files: + if f in files: + run.download_file(f, output_file_path=str(downloaded)) + break + else: + raise ValueError("The run does not contain any of the driver log files") logs = downloaded.read_text() expected_mountpoint = BasicModel2Epochs().dataset_mountpoint assert f"local_dataset : {expected_mountpoint}" in logs From 5703c1c321571483ad698e682971ea76aed2969b Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 18 Nov 2021 14:47:14 +0000 Subject: [PATCH 40/41] fix --- InnerEye/ML/SSL/lightning_modules/simclr_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/InnerEye/ML/SSL/lightning_modules/simclr_module.py b/InnerEye/ML/SSL/lightning_modules/simclr_module.py index 62b1ee8ec..0ab1dd566 100644 --- a/InnerEye/ML/SSL/lightning_modules/simclr_module.py +++ b/InnerEye/ML/SSL/lightning_modules/simclr_module.py @@ -14,7 +14,7 @@ from InnerEye.ML.SSL.encoders import SSLEncoder from InnerEye.ML.SSL.utils import SSLDataModuleType -SingleBatchType = Tuple[List, T] +SingleBatchType = Tuple[List, torch.Tensor] BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType] From 4c4f4eb4118c43ca062f48fecfe7e89cf50d012f Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 18 Nov 2021 14:53:25 +0000 Subject: [PATCH 41/41] PR comments --- InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py | 2 ++ InnerEye/ML/configs/ssl/CXR_SSL_configs.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py b/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py index 8ca2cb58c..5eee96e2e 100644 --- a/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py +++ b/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py @@ -15,6 +15,7 @@ class CIFAR10SimCLR(SSLContainer): def __init__(self) -> None: super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10, linear_head_dataset_name=SSLDatasetName.CIFAR10, + # We usually train this model with 4 GPUs, giving an effective batch size of 512 ssl_training_batch_size=128, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.SimCLR, @@ -32,6 +33,7 @@ class CIFAR10BYOL(SSLContainer): def __init__(self) -> None: super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10, linear_head_dataset_name=SSLDatasetName.CIFAR10, + # We usually train this model with 4 GPUs, giving an effective batch size of 512 ssl_training_batch_size=128, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.BYOL, diff --git a/InnerEye/ML/configs/ssl/CXR_SSL_configs.py b/InnerEye/ML/configs/ssl/CXR_SSL_configs.py index 24f912bc9..876522328 100644 --- a/InnerEye/ML/configs/ssl/CXR_SSL_configs.py +++ b/InnerEye/ML/configs/ssl/CXR_SSL_configs.py @@ -29,6 +29,7 @@ def __init__(self) -> None: random_seed=1, recovery_checkpoint_save_interval=200, num_epochs=1000, + # We usually train this model with 16 GPUs, giving an effective batch size of 1200 ssl_training_batch_size=75, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.BYOL, @@ -45,6 +46,7 @@ def __init__(self) -> None: random_seed=1, recovery_checkpoint_save_interval=200, num_epochs=1000, + # We usually train this model with 16 GPUs, giving an effective batch size of 1200 ssl_training_batch_size=75, ssl_encoder=EncoderName.resnet50, ssl_training_type=SSLTrainingType.SimCLR,