diff --git a/CHANGELOG.md b/CHANGELOG.md index ff8b390c2..6921ad451 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ nodes in AzureML. Example: Add `--num_nodes=2` to the commandline arguments to t - ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) When registering a model, the name of the Python execution environment is added as a tag. This tag is read when running inference, and the execution environment is re-used. +- ([#411](https://github.com/microsoft/InnerEye-DeepLearning/pull/411)) Upgraded to PyTorch 1.8.0, PyTorch-Lightning +1.1.8 and AzureML SDK 1.23.0 ### Fixed diff --git a/InnerEye/Azure/azure_runner.py b/InnerEye/Azure/azure_runner.py index 9fcaf813f..853311720 100644 --- a/InnerEye/Azure/azure_runner.py +++ b/InnerEye/Azure/azure_runner.py @@ -15,12 +15,10 @@ from typing import Any, Dict, List, Optional from azureml.core import Dataset, Environment, Experiment, Run, ScriptRunConfig -from azureml.core.conda_dependencies import CondaDependencies from azureml.core.datastore import Datastore from azureml.core.runconfig import MpiConfiguration, RunConfiguration from azureml.core.workspace import WORKSPACE_DEFAULT_BLOB_STORE_NAME from azureml.data import FileDataset -from azureml.train.dnn import PyTorch from InnerEye.Azure import azure_util from InnerEye.Azure.azure_config import AzureConfig, ParserResult, SourceConfig @@ -219,22 +217,6 @@ def get_or_create_dataset(azure_config: AzureConfig, return azureml_dataset -def pytorch_version_from_conda_dependencies(conda_dependencies: CondaDependencies) -> Optional[str]: - """ - Given a CondaDependencies object, look for a spec of the form "pytorch=...", and return - whichever supported version is compatible with the value, or None if there isn't one. - """ - supported_versions = PyTorch.get_supported_versions() - for spec in conda_dependencies.conda_packages: - components = spec.split("=") - if len(components) == 2 and components[0] == "pytorch": - version = components[1] - for supported in supported_versions: - if version.startswith(supported) or supported.startswith(version): - return supported - return None - - def get_or_create_python_environment(azure_config: AzureConfig, source_config: SourceConfig, environment_name: str = "", diff --git a/InnerEye/ML/dataset/scalar_dataset.py b/InnerEye/ML/dataset/scalar_dataset.py index 9295f2f81..e9ae555e4 100644 --- a/InnerEye/ML/dataset/scalar_dataset.py +++ b/InnerEye/ML/dataset/scalar_dataset.py @@ -28,7 +28,6 @@ from InnerEye.ML.utils.features_util import FeatureStatistics from InnerEye.ML.utils.transforms import Compose3D, Transform3D - T = TypeVar('T', bound=ScalarDataSource) @@ -167,8 +166,8 @@ def load_single_data_source(subject_rows: pd.DataFrame, additional scalar values should be read from. THe keys should map each feature to its channels. :param numerical_columns: The names of all columns where additional scalar values should be read from. :param categorical_data_encoder: Encoding scheme for categorical data. - :param is_classification_dataset: If the current dataset is classification or not. - from. + :param is_classification_dataset: If True, the dataset will be used in a classification model. If False, + assume that the dataset will be used in a regression model. :param transform_labels: a label transformation or a list of label transformation to apply to the labels. If a list is provided, the transformations are applied in order from left to right. :param sequence_position_numeric: Numeric position of the data source in a data sequence. Assumed to be @@ -330,7 +329,7 @@ def __init__(self, :param subject_column: The name of the column that contains the subject identifier :param channel_column: The name of the column that contains the row identifier ("channels") that are expected to be loaded from disk later because they are large images. - :param is_classification_dataset: If the current dataset is classification or not from. + :param is_classification_dataset: If the current dataset is classification or not. :param categorical_data_encoder: Encoding scheme for categorical data. """ self.categorical_data_encoder = categorical_data_encoder diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 1006f2369..b26ab320a 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -105,10 +105,21 @@ def close_all_loggers(self) -> None: self.train_epoch_metrics_logger.flush() self.val_epoch_metrics_logger.flush() + @property + def use_sync_dist(self) -> bool: + """ + Returns True if metric logging should use sync_dist=True. This is read off from the use_ddp flag of the trainer. + """ + # For PL from version 1.2.0 on: self.trainer.accelerator_connector.use_ddp + return self.trainer.use_ddp + def on_train_epoch_start(self) -> None: self.train_timers.reset() def training_epoch_end(self, outputs: List[Any]) -> None: + # Write out all the metrics that have been accumulated in the StoringLogger in the previous epoch. + # Metrics for the very last epoch are written in on_train_end + self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch - 1) self.training_or_validation_epoch_end(is_training=True) def on_validation_epoch_start(self) -> None: @@ -141,14 +152,6 @@ def validation_epoch_end(self, outputs: List[Any]) -> None: self.random_state.restore_random_state() self.training_or_validation_epoch_end(is_training=False) - @rank_zero_only - def on_epoch_end(self) -> None: - """ - This hook is called once per epoch, before on_train_epoch_end. Use it to write out all the metrics - that have been accumulated in the StoringLogger in the previous epoch. - """ - self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch - 1) - @rank_zero_only def on_train_end(self) -> None: """ @@ -157,6 +160,7 @@ def on_train_end(self) -> None: """ self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch) + @rank_zero_only def read_epoch_results_from_logger_and_store(self, epoch: int) -> None: """ Reads the metrics for the previous epoch from the StoringLogger, and writes them to disk, broken down by @@ -167,8 +171,6 @@ def read_epoch_results_from_logger_and_store(self, epoch: int) -> None: for is_training, prefix in [(True, TRAIN_PREFIX), (False, VALIDATION_PREFIX)]: metrics = self.storing_logger.extract_by_prefix(epoch, prefix) self.store_epoch_results(metrics, epoch, is_training) - else: - print(f"Skipping, no results for {epoch}") @rank_zero_only def training_or_validation_epoch_end(self, is_training: bool) -> None: @@ -234,7 +236,7 @@ def log_on_epoch(self, if isinstance(value, numbers.Number): value = torch.tensor(value, dtype=torch.float, device=self.device) prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX - sync_dist = self.use_ddp if sync_dist_override is None else sync_dist_override + sync_dist = self.use_sync_dist if sync_dist_override is None else sync_dist_override self.log(prefix + metric_name, value, sync_dist=sync_dist, on_step=False, on_epoch=True, diff --git a/InnerEye/ML/lightning_metrics.py b/InnerEye/ML/lightning_metrics.py index 0881e18ff..32e710ebf 100644 --- a/InnerEye/ML/lightning_metrics.py +++ b/InnerEye/ML/lightning_metrics.py @@ -10,8 +10,7 @@ import torch.nn.functional as F from pytorch_lightning import metrics from pytorch_lightning.metrics import Metric -from pytorch_lightning.metrics.functional import roc -from pytorch_lightning.metrics.functional.classification import accuracy, auc, auroc, precision_recall_curve +from pytorch_lightning.metrics.functional import accuracy, auc, auroc, precision_recall_curve, roc from torch.nn import ModuleList from InnerEye.Common.metrics_constants import AVERAGE_DICE_SUFFIX, MetricType, TRAIN_PREFIX, VALIDATION_PREFIX @@ -162,6 +161,9 @@ def _get_metrics_at_optimal_cutoff(self) -> Tuple[torch.Tensor, torch.Tensor, to if torch.unique(targets).numel() == 1: return torch.tensor(np.nan), torch.tensor(np.nan), torch.tensor(np.nan), torch.tensor(np.nan) fpr, tpr, thresholds = roc(preds, targets) + assert isinstance(fpr, torch.Tensor) + assert isinstance(tpr, torch.Tensor) + assert isinstance(thresholds, torch.Tensor) optimal_idx = torch.argmax(tpr - fpr) optimal_threshold = thresholds[optimal_idx] acc = accuracy(preds > optimal_threshold, targets) @@ -246,13 +248,13 @@ def compute(self) -> torch.Tensor: if torch.unique(targets).numel() == 1: return torch.tensor(np.nan) prec, recall, _ = precision_recall_curve(preds, targets) - return auc(recall, prec) + return auc(recall, prec) # type: ignore class BinaryCrossEntropyWithLogits(ScalarMetricsBase): """ Computes the cross entropy for binary classification. - This metric must be computed off the output logits + This metric must be computed off the model output logits. """ def __init__(self) -> None: @@ -260,7 +262,8 @@ def __init__(self) -> None: def compute(self) -> torch.Tensor: preds, targets = self._get_preds_and_targets() - return F.binary_cross_entropy_with_logits(input=preds, target=targets) + # All classification metrics work with integer targets, but this one does not. Convert to float. + return F.binary_cross_entropy_with_logits(input=preds, target=targets.to(dtype=preds.dtype)) class MetricForMultipleStructures(torch.nn.Module): @@ -320,5 +323,12 @@ def compute_all(self) -> Iterator[Tuple[str, torch.Tensor]]: of (metric name, metric value) tuples. This will automatically also call .reset() on the metrics. The first returned metric is the average across all structures, then come the per-structure values. """ - for d in iter(self): + for d in self: yield d.name, d.compute() # type: ignore + + def reset(self) -> None: + """ + Calls the .reset() method on all the metrics that the present object holds. + """ + for d in self: + d.reset() diff --git a/InnerEye/ML/lightning_models.py b/InnerEye/ML/lightning_models.py index deb5e80f0..58bd926c2 100644 --- a/InnerEye/ML/lightning_models.py +++ b/InnerEye/ML/lightning_models.py @@ -148,12 +148,14 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None: """ Writes all training or validation metrics that were aggregated over the epoch to the loggers. """ - dice = list((self.train_dice if is_training else self.val_dice).compute_all()) - for name, value in dice: + dice = self.train_dice if is_training else self.val_dice + for name, value in dice.compute_all(): self.log(name, value) - voxel_count = list((self.train_voxels if is_training else self.val_voxels).compute_all()) - for name, value in voxel_count: + dice.reset() + voxel_count = self.train_voxels if is_training else self.val_voxels + for name, value in voxel_count.compute_all(): self.log(name, value) + voxel_count.reset() super().training_or_validation_epoch_end(is_training=is_training) @@ -303,8 +305,9 @@ def compute_and_log_metrics(self, if masked is not None: _logits = masked.model_outputs.data _posteriors = self.logits_to_posterior(_logits) - # Image encoders already prepare images in float16, but the labels are not yet in that dtype - _labels = masked.labels.data.to(dtype=_posteriors.dtype) + # Classification metrics expect labels as integers, but they are float throughout the rest of the code + labels_dtype = torch.int if self.is_classification_model else _posteriors.dtype + _labels = masked.labels.data.to(dtype=labels_dtype) _subject_ids = masked.subject_ids assert _subject_ids is not None for metric in metric_list: @@ -339,13 +342,14 @@ def training_or_validation_epoch_end(self, is_training: bool) -> None: for metric in metric_list: if metric.has_predictions: # Sequence models can have no predictions at all for particular positions, depending on the data. - # Hence, only log if anything really has been accumula + # Hence, only log if anything has been accumulated. self.log(name=prefix + metric.name + target_suffix, value=metric.compute()) + metric.reset() logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger logger.flush() super().training_or_validation_epoch_end(is_training) - def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: + def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: # type: ignore """ For sequence models, transfer the nested lists of items to the given GPU device. For all other models, this relies on the superclass to move the batch of data to the GPU. diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py index 5e6b72d40..4532c32ad 100644 --- a/InnerEye/ML/metrics.py +++ b/InnerEye/ML/metrics.py @@ -331,13 +331,16 @@ def store_epoch_metrics(metrics: DictStrFloat, epoch: int, file_logger: DataframeLogger) -> None: """ - Writes all metrics into a CSV file, with an additional columns for epoch number. + Writes all metrics (apart from ones that measure run time) into a CSV file, + with an additional columns for epoch number. :param file_logger: An instance of DataframeLogger, for logging results to csv. :param epoch: The epoch corresponding to the results. :param metrics: The metrics of the specified epoch, averaged along its batches. """ logger_row = {} for key, value in metrics.items(): + if key == MetricType.SECONDS_PER_BATCH.value or key == MetricType.SECONDS_PER_EPOCH.value: + continue if key in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys(): logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[key].value] = value else: diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 530be6c14..72f9cd3bd 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -18,6 +18,7 @@ from InnerEye.Common.metrics_constants import TRAIN_PREFIX, VALIDATION_PREFIX from InnerEye.Common.resource_monitor import ResourceMonitor from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, cleanup_checkpoint_folder +from InnerEye.ML.config import SegmentationModelBase from InnerEye.ML.deep_learning_config import VISUALIZATION_FOLDER from InnerEye.ML.lightning_base import TrainingAndValidationDataLightning from InnerEye.ML.lightning_helpers import create_lightning_model @@ -186,8 +187,9 @@ def model_train(config: ModelConfigBase, ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization") # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't # want training to depend on how many patients we visualized, and hence set the random seed again right after. - with logging_section("Visualizing the effect of sampling random crops for training"): - visualize_random_crops_for_dataset(config) + if isinstance(config, SegmentationModelBase): + with logging_section("Visualizing the effect of sampling random crops for training"): + visualize_random_crops_for_dataset(config) # Print out a detailed breakdown of layers, memory consumption and time. generate_and_print_model_summary(config, lightning_model.model) diff --git a/InnerEye/ML/models/layers/weight_standardization.py b/InnerEye/ML/models/layers/weight_standardization.py index a276125ed..79fa6590e 100644 --- a/InnerEye/ML/models/layers/weight_standardization.py +++ b/InnerEye/ML/models/layers/weight_standardization.py @@ -51,4 +51,4 @@ def standardize(weights: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore standardized_weights = WeightStandardizedConv2d.standardize(self.weight) - return self._conv_forward(input, standardized_weights) # type: ignore + return self._conv_forward(input, standardized_weights, bias=None) # type: ignore diff --git a/InnerEye/ML/runner.py b/InnerEye/ML/runner.py index e0344e779..4dfbcf3fc 100755 --- a/InnerEye/ML/runner.py +++ b/InnerEye/ML/runner.py @@ -265,8 +265,13 @@ def run_in_situ(self) -> None: print_exception(ex, "Unable to run PyTest.") error_messages.append(f"Unable to run PyTest: {ex}") else: - try: + # Set environment variables for multi-node training if needed. + # In particular, the multi-node environment variables should NOT be set in single node + # training, otherwise this might lead to errors with the c10 distributed backend + # (https://github.com/microsoft/InnerEye-DeepLearning/issues/395) + if self.azure_config.num_nodes > 1: set_environment_variables_for_multi_node() + try: logging_to_file(self.model_config.logs_folder / LOG_FILE_NAME) try: self.create_ml_runner().run() diff --git a/InnerEye/ML/visualizers/model_summary.py b/InnerEye/ML/visualizers/model_summary.py index 6c423348f..2d3e87e4b 100644 --- a/InnerEye/ML/visualizers/model_summary.py +++ b/InnerEye/ML/visualizers/model_summary.py @@ -10,8 +10,8 @@ import numpy as np import torch -import torchprof from torch.utils.hooks import RemovableHandle +from torchprof.profile import Profile from InnerEye.Common.common_util import logging_only_to_file from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH @@ -188,7 +188,7 @@ def print_summary() -> None: # Register the forward-pass hooks, profile the model, and restore its state self.model.apply(self._register_hook) - with torchprof.Profile(self.model, use_cuda=self.use_gpu) as prof: + with Profile(self.model, use_cuda=self.use_gpu) as prof: forward_preserve_state(self.model, input_tensors) # type: ignore # Log the model summary: tensor shapes, num of parameters, memory requirement, and forward pass time diff --git a/InnerEye/ML/visualizers/patch_sampling.py b/InnerEye/ML/visualizers/patch_sampling.py index 0962f8732..3dc4cc30e 100644 --- a/InnerEye/ML/visualizers/patch_sampling.py +++ b/InnerEye/ML/visualizers/patch_sampling.py @@ -15,7 +15,6 @@ from InnerEye.ML.dataset.cropping_dataset import CroppingDataset from InnerEye.ML.dataset.full_image_dataset import FullImageDataset from InnerEye.ML.dataset.sample import Sample -from InnerEye.ML.deep_learning_config import DeepLearningConfig from InnerEye.ML.plotting import resize_and_save, scan_with_transparent_overlay from InnerEye.ML.utils import augmentation, io_util, ml_util from InnerEye.ML.utils.config_util import ModelConfigLoader @@ -108,8 +107,7 @@ def visualize_random_crops(sample: Sample, return heatmap -def visualize_random_crops_for_dataset(config: DeepLearningConfig, - output_folder: Optional[Path] = None) -> None: +def visualize_random_crops_for_dataset(config: SegmentationModelBase, output_folder: Optional[Path] = None) -> None: """ For segmentation models only: This function generates visualizations of the effect of sampling random patches for training. Visualizations are stored in both Nifti format, and as 3 PNG thumbnail files, in the output folder. @@ -117,8 +115,6 @@ def visualize_random_crops_for_dataset(config: DeepLearningConfig, :param output_folder: The folder in which the visualizations should be written. If not provided, use a subfolder "patch_sampling" in the models's default output folder """ - if not isinstance(config, SegmentationModelBase): - return dataset_splits = config.get_dataset_splits() # Load a sample using the full image data loader full_image_dataset = FullImageDataset(config, dataset_splits.train) diff --git a/Tests/Azure/test_azure_util.py b/Tests/Azure/test_azure_util.py index a8c96c9b8..a7c63784f 100644 --- a/Tests/Azure/test_azure_util.py +++ b/Tests/Azure/test_azure_util.py @@ -8,19 +8,15 @@ import pytest from azureml.core import Run -from azureml.core.conda_dependencies import CondaDependencies from azureml.core.workspace import Workspace from InnerEye.Azure.azure_config import AzureConfig, SourceConfig -from InnerEye.Azure.azure_runner import create_experiment_name, get_or_create_python_environment, \ - pytorch_version_from_conda_dependencies +from InnerEye.Azure.azure_runner import create_experiment_name, get_or_create_python_environment from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, fetch_child_runs, fetch_run, \ get_cross_validation_split_index, is_cross_validation_child_run, is_run_and_child_runs_completed, \ - merge_conda_dependencies, \ - merge_conda_files, to_azure_friendly_container_path -from InnerEye.Common import fixed_paths + merge_conda_dependencies, merge_conda_files, to_azure_friendly_container_path from InnerEye.Common.common_util import logging_to_stdout -from InnerEye.Common.fixed_paths import ENVIRONMENT_YAML_FILE_NAME, PRIVATE_SETTINGS_FILE, PROJECT_SECRETS_FILE, \ +from InnerEye.Common.fixed_paths import PRIVATE_SETTINGS_FILE, PROJECT_SECRETS_FILE, \ get_environment_yaml_file, repository_root_directory from InnerEye.Common.output_directories import OutputFolderForTests from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, get_most_recent_run, get_most_recent_run_id @@ -157,19 +153,6 @@ def test_experiment_name() -> None: assert create_experiment_name(c) == "foo" -def test_framework_version(test_output_dirs: OutputFolderForTests) -> None: - """ - Test if the Pytorch framework version can be read correctly from the current environment file. - """ - environment_file = fixed_paths.repository_root_directory(ENVIRONMENT_YAML_FILE_NAME) - assert environment_file.is_file(), "Environment file must be present" - conda_dep = CondaDependencies(conda_dependencies_file_path=environment_file) - framework = pytorch_version_from_conda_dependencies(conda_dep) - # If this fails, it is quite likely that the AzureML SDK is behind pytorch, and does not yet know about a - # new version of pytorch that we are using here. - assert framework is not None - - def get_run_and_check(run_id: str, expected: bool, workspace: Workspace) -> None: run = fetch_run(workspace, run_id) status = is_run_and_child_runs_completed(run) diff --git a/Tests/ML/models/test_scalar_model.py b/Tests/ML/models/test_scalar_model.py index 22eb5cd1a..545c0e2cb 100644 --- a/Tests/ML/models/test_scalar_model.py +++ b/Tests/ML/models/test_scalar_model.py @@ -65,8 +65,7 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N MetricType.AREA_UNDER_ROC_CURVE, MetricType.CROSS_ENTROPY, MetricType.LOSS, - # For unknown reasons, we don't get seconds_per_batch for the training data. - # MetricType.SECONDS_PER_BATCH, + MetricType.SECONDS_PER_BATCH, MetricType.SECONDS_PER_EPOCH, MetricType.SUBJECT_COUNT, ]: @@ -92,34 +91,31 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N epoch_metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / EPOCH_METRICS_FILE_NAME # Auto-format will break the long header line, hence the strange way of writing it! expected_epoch_metrics = \ - "loss,cross_entropy,accuracy_at_threshold_05,seconds_per_epoch,learning_rate," + \ + "loss,cross_entropy,accuracy_at_threshold_05,learning_rate," + \ "area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \ "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \ "optimal_threshold,subject_count,epoch,cross_validation_split_index\n" + \ - """0.6866141557693481,0.6866141557693481,0.5,0,0.0001,1.0,1.0,0.5,0.0,0.0,0.529514,2.0,0,-1 - 0.6864652633666992,0.6864652633666992,0.5,0,9.999712322065557e-05,1.0,1.0,0.5,0.0,0.0,0.529475,2.0,1,-1 - 0.6863163113594055,0.6863162517547607,0.5,0,9.999306876841536e-05,1.0,1.0,0.5,0.0,0.0,0.529437,2.0,2,-1 - 0.6861673593521118,0.6861673593521118,0.5,0,9.998613801725043e-05,1.0,1.0,0.5,0.0,0.0,0.529399,2.0,3,-1 + """0.6866141557693481,0.6866141557693481,0.5,0.0001,1.0,1.0,0.5,0.0,0.0,0.529514,2.0,0,-1 + 0.6864652633666992,0.6864652633666992,0.5,9.999712322065557e-05,1.0,1.0,0.5,0.0,0.0,0.529475,2.0,1,-1 + 0.6863163113594055,0.6863162517547607,0.5,9.999306876841536e-05,1.0,1.0,0.5,0.0,0.0,0.529437,2.0,2,-1 + 0.6861673593521118,0.6861673593521118,0.5,9.998613801725043e-05,1.0,1.0,0.5,0.0,0.0,0.529399,2.0,3,-1 """ - # We cannot compare columns like "seconds_per_epoch" because timing will obviously vary between machines. - # Column must still be present, though. - check_log_file(epoch_metrics_path, expected_epoch_metrics, - ignore_columns=[LoggingColumns.SecondsPerEpoch.value]) + check_log_file(epoch_metrics_path, expected_epoch_metrics, ignore_columns=[]) # Check metrics.csv: This contains the per-subject per-epoch model outputs # Randomization comes out slightly different on Windows, hence only execute the test on Linux if common_util.is_windows(): return metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / SUBJECT_METRICS_FILE_NAME metrics_expected = \ - """prediction_target,epoch,subject,model_output,label,cross_validation_split_index,data_split -Default,0,S2,0.5295137763023376,1.0,-1,Train -Default,0,S4,0.5216594338417053,0.0,-1,Train -Default,1,S4,0.5214819312095642,0.0,-1,Train -Default,1,S2,0.5294750332832336,1.0,-1,Train -Default,2,S2,0.5294366478919983,1.0,-1,Train -Default,2,S4,0.5213046073913574,0.0,-1,Train -Default,3,S2,0.5293986201286316,1.0,-1,Train -Default,3,S4,0.5211275815963745,0.0,-1,Train + """epoch,subject,prediction_target,model_output,label,data_split,cross_validation_split_index +0,S2,Default,0.529514,1,Train,-1 +0,S4,Default,0.521659,0,Train,-1 +1,S4,Default,0.521482,0,Train,-1 +1,S2,Default,0.529475,1,Train,-1 +2,S4,Default,0.521305,0,Train,-1 +2,S2,Default,0.529437,1,Train,-1 +3,S2,Default,0.529399,1,Train,-1 +3,S4,Default,0.521128,0,Train,-1 """ check_log_file(metrics_path, metrics_expected, ignore_columns=[]) # Check log METRICS_FILE_NAME inside of the folder epoch_004/Train, which is written when we run model_test. @@ -134,9 +130,19 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N check_log_file(inference_metrics_path, inference_metrics_expected, ignore_columns=[]) +def _count_lines(s: str) -> int: + lines = [line for line in s.splitlines() if line.strip()] + return len(lines) + + def check_log_file(path: Path, expected_csv: str, ignore_columns: List[str]) -> None: df_expected = pd.read_csv(StringIO(expected_csv)) df_epoch_metrics_actual = pd.read_csv(path) + # Add a separate check for number of lines. Data frames with lines are exact duplicates are not caught + # as mismatches. + num_expected_lines = _count_lines(expected_csv) + num_actual_lines = _count_lines(path.read_text()) + assert num_actual_lines == num_expected_lines, "Number of lines does not match" for ignore_column in ignore_columns: assert ignore_column in df_epoch_metrics_actual, f"Column {ignore_column} will be ignored, but must still be" \ f"present in the dataframe" diff --git a/Tests/ML/test_metrics.py b/Tests/ML/test_metrics.py index 2a456113e..b30a754fb 100644 --- a/Tests/ML/test_metrics.py +++ b/Tests/ML/test_metrics.py @@ -168,15 +168,15 @@ def test_classification_metrics() -> None: metrics = classification_module._get_metrics_computers() logits = [torch.tensor([2.1972, 1.3863, 0.4055]), torch.tensor([-0.8473, 2.1972, -0.4055])] posteriors = [torch.sigmoid(logit) for logit in logits] - labels = [torch.tensor([1., 1., 0.]), torch.tensor([0., 0., 0.])] + labels = [torch.tensor([1, 1, 0]), torch.tensor([0, 0, 0])] for logit, posterior, label in zip(logits, posteriors, labels): for metric in metrics: if isinstance(metric, ScalarMetricsBase) and metric.compute_from_logits: metric.update(logit, label) else: metric.update(posterior, label) - accuracy_05, accuracy_opt, threshold, fpr, fnr, roc_auc, pr_auc, cross_entropy_with_logits = [metric.compute() for metric in - metrics] + accuracy_05, accuracy_opt, threshold, fpr, fnr, roc_auc, pr_auc, cross_entropy_with_logits = \ + [metric.compute() for metric in metrics] all_labels = torch.cat(labels).numpy() all_posteriors = torch.cat(posteriors).numpy() expected_accuracy_at_05 = np.mean((all_posteriors > 0.5) == all_labels) @@ -237,8 +237,7 @@ def test_average_without_nan() -> None: # Return value is a scalar, but should be a tensor assert torch.is_tensor(actual1) assert actual1 == expected - # .compute() has a special wrapper that calls .reset() right after calling .compute(). Hence, now it seems - # that the average has not seen any values + average.reset() assert average.count == 0 # Store the same set of values twice, we should still see the same mean average.update(torch.tensor(values)) @@ -252,6 +251,7 @@ def test_average_without_nan() -> None: # This is a weird side effect of Lightning's way of caching metric results. As long as we don't call # .update, the last computed value will be kept and returned, even though we have called .reset() already. assert average.compute() == expected + average.reset() # Update with a tensor that does not contain any values: Can't compute the average then. average.update(torch.zeros((0,))) with pytest.raises(ValueError) as ex: diff --git a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index 2427e9d93..c7ef5c055 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -35,7 +35,7 @@ from InnerEye.ML.utils.training_util import ModelTrainingResults from InnerEye.ML.visualizers.patch_sampling import PATCH_SAMPLING_FOLDER from Tests.ML.configs.DummyModel import DummyModel -from Tests.ML.util import get_default_checkpoint_handler +from Tests.ML.util import get_default_checkpoint_handler, machine_has_gpu config_path = full_ml_test_data_path() base_path = full_ml_test_data_path() @@ -73,14 +73,15 @@ def _check_patch_centers(diagnostics_per_epoch: List[np.ndarray], should_equal: assert np.array_equal(patch_centers_epoch1, diagnostic) == should_equal def _check_voxel_count(results_per_epoch: List[Dict[str, float]], - expected_voxel_count_per_epoch: List[float]) -> None: + expected_voxel_count_per_epoch: List[float], + prefix: str) -> None: assert len(results_per_epoch) == len(expected_voxel_count_per_epoch) - for (results, voxel_count) in zip(results_per_epoch, expected_voxel_count_per_epoch): + for epoch, (results, voxel_count) in enumerate(zip(results_per_epoch, expected_voxel_count_per_epoch)): # In the test data, both structures "region" and "region_1" are read from the same nifti file, hence # their voxel counts must be identical. for structure in ["region", "region_1"]: assert results[f"{MetricType.VOXEL_COUNT.value}/{structure}"] == pytest.approx(voxel_count, abs=1e-2), \ - f"Voxel count mismatch for '{structure}'" + f"{prefix} voxel count mismatch for '{structure}' epoch {epoch}" def _mean(a: List[float]) -> float: return sum(a) / len(a) @@ -100,8 +101,12 @@ def _mean_list(lists: List[List[float]]) -> List[float]: train_config.store_dataset_sample = True train_config.recovery_checkpoint_save_interval = 1 - expected_train_losses = [0.4552295, 0.4548622] - expected_val_losses = [0.4553889, 0.4553044] + if machine_has_gpu: + expected_train_losses = [0.4553468, 0.454904] + expected_val_losses = [0.4553881, 0.4553041] + else: + expected_train_losses = [0.4553469, 0.4548947] + expected_val_losses = [0.4553880, 0.4553041] loss_absolute_tolerance = 1e-6 expected_learning_rates = [train_config.l_rate, 5.3589e-4] @@ -129,10 +134,11 @@ def assert_all_close(metric: str, expected: List[float], **kwargs: Any) -> None: # Simple regression test: Voxel counts should be the same in both epochs on the validation set, # and be the same across 'region' and 'region_1' because they derive from the same Nifti files. # The following values are read off directly from the results of compute_dice_across_patches in the training loop - train_voxels = [[83014.0, 83255.0, 82946.0], [83000.0, 82881.0, 83309.0]] + # This checks that averages are computed correctly, and that metric computers are reset after each epoch. + train_voxels = [[83092.0, 83212.0, 82946.0], [83000.0, 82881.0, 83309.0]] val_voxels = [[82765.0, 83212.0], [82765.0, 83212.0]] - _check_voxel_count(model_training_result.train_results_per_epoch, _mean_list(train_voxels)) - _check_voxel_count(model_training_result.val_results_per_epoch, _mean_list(val_voxels)) + _check_voxel_count(model_training_result.train_results_per_epoch, _mean_list(train_voxels), "Train") + _check_voxel_count(model_training_result.val_results_per_epoch, _mean_list(val_voxels), "Val") actual_train_losses = model_training_result.get_training_metric(MetricType.LOSS.value) actual_val_losses = model_training_result.get_validation_metric(MetricType.LOSS.value) @@ -145,17 +151,20 @@ def assert_all_close(metric: str, expected: List[float], **kwargs: Any) -> None: tracked_metric = TrackedMetrics.Val_Loss.value[len(VALIDATION_PREFIX):] for val_result in model_training_result.val_results_per_epoch: assert tracked_metric in val_result - # The following values are read off directly from the results of compute_dice_across_patches in the training loop - train_dice_region = [[0.0, 0.0, 0.0], [0.01922884, 0.01918082, 0.07752819]] - train_dice_region1 = [[0.48280242, 0.48337635, 0.4974504], [0.5024475, 0.5007884, 0.48952717]] + + # The following values are read off directly from the results of compute_dice_across_patches in the + # training loop. Results are slightly different for CPU, hence use a larger tolerance there. + dice_tolerance = 1e-4 if machine_has_gpu else 4.5e-4 + train_dice_region = [[0.0, 0.0, 4.0282e-04], [0.0309, 0.0334, 0.0961]] + train_dice_region1 = [[0.4806, 0.4800, 0.4832], [0.4812, 0.4842, 0.4663]] # There appears to be some amount of non-determinism here: When using a tolerance of 1e-4, we get occasional # test failures on Linux in the cloud (not on Windows, not on AzureML) Unclear where it comes from. Even when # failing here, the losses match up to the expected tolerance. - assert_all_close("Dice/region", _mean_list(train_dice_region), atol=1.3e-4) - assert_all_close("Dice/region_1", _mean_list(train_dice_region1), atol=1.3e-4) + assert_all_close("Dice/region", _mean_list(train_dice_region), atol=dice_tolerance) + assert_all_close("Dice/region_1", _mean_list(train_dice_region1), atol=dice_tolerance) expected_average_dice = [_mean(train_dice_region[i] + train_dice_region1[i]) # type: ignore for i in range(len(train_dice_region))] - assert_all_close("Dice/AverageAcrossStructures", expected_average_dice, atol=1e-4) + assert_all_close("Dice/AverageAcrossStructures", expected_average_dice, atol=dice_tolerance) # check output files/directories assert train_config.outputs_folder.is_dir() @@ -186,9 +195,7 @@ def assert_all_close(metric: str, expected: List[float], **kwargs: Any) -> None: model_training_result.get_training_metric(MetricType.SECONDS_PER_EPOCH.value) model_training_result.get_validation_metric(MetricType.SECONDS_PER_EPOCH.value) model_training_result.get_validation_metric(MetricType.SECONDS_PER_BATCH.value) - # We should have time per batch also for training, but it does not appear in the logs somehow? - # Logging the metric is called, but they never make it to the logger object. - # model_training_result.get_training_metric(MetricType.SECONDS_PER_BATCH.value) + model_training_result.get_training_metric(MetricType.SECONDS_PER_BATCH.value) # Issue #372 # # Test for saving of example images diff --git a/azure_runner.yml b/azure_runner.yml index 599480ea7..103d999d7 100644 --- a/azure_runner.yml +++ b/azure_runner.yml @@ -6,8 +6,8 @@ dependencies: - python=3.7.3 - pip: - azure-mgmt-resource==10.2.0 - - azureml-sdk==1.19.0 - - azureml-tensorboard==1.19.0 + - azureml-sdk==1.23.0 + - azureml-tensorboard==1.23.0 - conda-merge==0.1.5 - gitpython==3.1.7 - numpy==1.19.1 diff --git a/environment.yml b/environment.yml index ce75cee15..1124ebdfa 100644 --- a/environment.yml +++ b/environment.yml @@ -5,15 +5,15 @@ channels: dependencies: - pip=20.1.1 - python=3.7.3 - - pytorch=1.6.0 + - pytorch=1.8.0 - python-blosc==1.7.0 - - torchvision=0.7.0 + - torchvision=0.9.0 - pip: - git+https://github.com/analysiscenter/radio.git@6d53e25#egg=radio - azure-mgmt-resource==10.2.0 - - azureml-mlflow==1.19.0 - - azureml-sdk==1.19.0 - - azureml-tensorboard==1.19.0 + - azureml-mlflow==1.23.0 + - azureml-sdk==1.23.0 + - azureml-tensorboard==1.23.0 - conda-merge==0.1.5 - dataclasses-json==0.5.2 - flake8==3.8.3 @@ -42,7 +42,7 @@ dependencies: - pytest-cov==2.10.1 - pytest-forked==1.3.0 - pytest-xdist==1.34.0 - - pytorch-lightning==1.0.6 + - pytorch-lightning==1.1.8 - rich==5.1.1 - rpdb==0.1.6 - scikit-image==0.17.2 @@ -55,4 +55,4 @@ dependencies: - tabulate==0.8.7 - tensorboard==2.3.0 - tensorboardX==2.1 - - torchprof==1.1.1 + - torchprof==1.3.3