
Commit

Split validation and test infer config (#502)
Split validation, test, ensemble inference flags
JonathanTripp committed Jul 5, 2021
1 parent 43d31ce commit cab68cc
Showing 12 changed files with 371 additions and 108 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,10 +13,12 @@ created.
## Upcoming

### Added
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
- ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Added the capability to run regression tests on test jobs that run in AzureML.

### Changed
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) Renamed command line option 'perform_training_set_inference' to 'inference_on_train_set'. Replaced command line option 'perform_validation_and_test_set_inference' with the pair of options 'inference_on_val_set' and 'inference_on_test_set'. An example invocation is shown after this list.
- ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
- ([#497](https://github.com/microsoft/InnerEye-DeepLearning/pull/497)) Reducing the size of the code snapshot that
gets uploaded to AzureML, by skipping all test folders.
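
A before/after sketch of the renamed switches (a hypothetical invocation, assuming the usual `InnerEye/ML/runner.py` entry point; the model name and all other arguments are placeholders):

    # Before this change (hypothetical):
    python InnerEye/ML/runner.py --model=Lung --perform_training_set_inference=True --perform_validation_and_test_set_inference=False
    # After this change (hypothetical):
    python InnerEye/ML/runner.py --model=Lung --inference_on_train_set=True --inference_on_val_set=False --inference_on_test_set=False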
28 changes: 2 additions & 26 deletions InnerEye/Azure/azure_config.py
@@ -19,7 +19,7 @@
from azureml.train.hyperdrive import HyperDriveConfig
from git import Repo

from InnerEye.Azure.azure_util import fetch_run, is_offline_run_context
from InnerEye.Azure.azure_util import fetch_run, is_offline_run_context, remove_arg
from InnerEye.Azure.secrets_handling import SecretsHandling, read_all_settings
from InnerEye.Common import fixed_paths
from InnerEye.Common.generic_parsing import GenericConfig
@@ -324,31 +324,7 @@ def set_script_params_except_submit_flag(self) -> None:
Populates the script_params field of the present object from the arguments in sys.argv, with the exception
of the "azureml" flag.
"""
args = sys.argv[1:]
submit_flag = f"--{AZURECONFIG_SUBMIT_TO_AZUREML}"
retained_args = []
i = 0
while i < len(args):
arg = args[i]
if arg.startswith(submit_flag):
if len(arg) == len(submit_flag):
# The commandline argument is "--azureml", with something possibly following: This can either be
# "--azureml True" or "--azureml --some_other_param"
if i < (len(args) - 1):
# If the next argument starts with a "-" then assume that it does not belong to the --azureml
# flag. If there is no "-", assume it belongs to the --azureml flag, and skip both
if not args[i + 1].startswith("-"):
i = i + 1
elif arg[len(submit_flag)] == "=":
# The commandline argument is "--azureml=True" or "--azureml=False": Continue with next arg
pass
else:
# The argument list contains a flag like "--azureml_foo": Keep that.
retained_args.append(arg)
else:
retained_args.append(arg)
i = i + 1
self.script_params = retained_args
self.script_params = remove_arg(AZURECONFIG_SUBMIT_TO_AZUREML, sys.argv[1:])


@dataclass
40 changes: 40 additions & 0 deletions InnerEye/Azure/azure_util.py
@@ -455,3 +455,43 @@ def step_up_directories(path: Path) -> Generator[Path, None, None]:
if parent == path:
break
path = parent


def remove_arg(arg: str, args: List[str]) -> List[str]:
    """
    Remove an argument from a list of arguments. The argument list is assumed to contain
    elements of the form:
    "-a", "--arg1", "--arg2", "value2", or "--arg3=value"
    Any item matching "--arg" (together with a following value, if there is one), "--arg=value",
    or "--no-arg" is removed from the list.
    :param arg: Argument to look for, without the leading "--".
    :param args: List of arguments to scan.
    :return: List of arguments with all occurrences of --arg removed, if present.
    """
    arg_opt = f"--{arg}"
    no_arg_opt = f"--no-{arg}"
    retained_args = []
    i = 0
    while i < len(args):
        current = args[i]
        if current.startswith(arg_opt):
            if len(current) == len(arg_opt):
                # The commandline argument is exactly "--arg", with something possibly following: This can
                # either be "--arg value" or "--arg --some_other_param"
                if i < (len(args) - 1):
                    # If the next argument starts with a "-", assume that it does not belong to "--arg".
                    # Otherwise, assume it is the value of "--arg", and skip both.
                    if not args[i + 1].startswith("-"):
                        i = i + 1
            elif current[len(arg_opt)] == "=":
                # The commandline argument is "--arg=value": Drop it and continue with the next argument.
                pass
            else:
                # The argument merely shares the prefix, like "--arg_other_param": Keep it.
                retained_args.append(current)
        elif current == no_arg_opt:
            # The commandline argument is "--no-arg": Drop it.
            pass
        else:
            retained_args.append(current)
        i = i + 1
    return retained_args
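
A minimal usage sketch of remove_arg (illustrative argument values only):

    from InnerEye.Azure.azure_util import remove_arg

    args = ["--azureml", "True", "--model=Lung", "--no-azureml", "--azureml_foo"]
    # "--azureml" is dropped together with its value "True", "--no-azureml" is dropped,
    # and "--azureml_foo" is kept because it merely shares the prefix.
    assert remove_arg("azureml", args) == ["--model=Lung", "--azureml_foo"]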
5 changes: 3 additions & 2 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -121,8 +121,9 @@ def setup(self) -> None:
dataset_path=self.local_dataset,
batch_size=self.ssl_training_batch_size)})
self.data_module: InnerEyeDataModuleTypes = self.get_data_module()
self.perform_validation_and_test_set_inference = False
if self.number_of_cross_validation_splits > 1:
self.inference_on_val_set = False
self.inference_on_test_set = False
if self.perform_cross_validation:
raise NotImplementedError("Cross-validation logic is not implemented for this module.")

def _load_config(self) -> None:
82 changes: 69 additions & 13 deletions InnerEye/ML/deep_learning_config.py
@@ -5,17 +5,16 @@
from __future__ import annotations

import logging
from enum import Enum, unique
from pathlib import Path
from typing import Any, Dict, List, Optional

import param
from enum import Enum, unique
from pandas import DataFrame
from param import Parameterized
from pathlib import Path
from typing import Any, Dict, List, Optional

from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, RUN_CONTEXT, is_offline_run_context
from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import is_windows
from InnerEye.Common.common_util import ModelProcessing, is_windows
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, DEFAULT_LOGS_DIR_NAME
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.Common.type_annotations import PathOrString, TupleFloat2
@@ -199,14 +198,24 @@ class WorkflowParams(param.Parameterized):
cross_validation_split_index: int = param.Integer(DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, bounds=(-1, None),
doc="The index of the cross validation fold this model is "
"associated with when performing k-fold cross validation")
perform_training_set_inference: bool = \
param.Boolean(False,
doc="If True, run full image inference on the training set at the end of training. If False and "
"perform_validation_and_test_set_inference is True (default), only run inference on "
"validation and test set. If both flags are False do not run inference.")
perform_validation_and_test_set_inference: bool = \
param.Boolean(True,
doc="If True (default), run full image inference on validation and test set after training.")
inference_on_train_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on training set after training.")
inference_on_val_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on validation set after training.")
inference_on_test_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on test set after training.")
ensemble_inference_on_train_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on the training set after ensemble training.")
ensemble_inference_on_val_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on validation set after ensemble training.")
ensemble_inference_on_test_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on test set after ensemble training.")
weights_url: str = param.String(doc="If provided, a url from which weights will be downloaded and used for model "
"initialization.")
local_weights_path: Optional[Path] = param.ClassSelector(class_=Path,
@@ -254,6 +263,53 @@ def validate(self) -> None:
f"found number_of_cross_validation_splits = {self.number_of_cross_validation_splits} "
f"and cross_validation_split_index={self.cross_validation_split_index}")

""" Defaults for when to run inference in the absence of any command line switches. """
INFERENCE_DEFAULTS: Dict[ModelProcessing, Dict[ModelExecutionMode, bool]] = {
ModelProcessing.DEFAULT: {
ModelExecutionMode.TRAIN: False,
ModelExecutionMode.TEST: True,
ModelExecutionMode.VAL: True,
},
ModelProcessing.ENSEMBLE_CREATION: {
ModelExecutionMode.TRAIN: False,
ModelExecutionMode.TEST: True,
ModelExecutionMode.VAL: False,
}
}

def inference_options(self) -> Dict[ModelProcessing, Dict[ModelExecutionMode, Optional[bool]]]:
"""
Return a mapping from ModelProcessing and ModelExecutionMode to the corresponding command line switch value.
:return: Command line switch for each combination of ModelProcessing and ModelExecutionMode.
"""
return {
ModelProcessing.DEFAULT: {
ModelExecutionMode.TRAIN: self.inference_on_train_set,
ModelExecutionMode.TEST: self.inference_on_test_set,
ModelExecutionMode.VAL: self.inference_on_val_set,
},
ModelProcessing.ENSEMBLE_CREATION: {
ModelExecutionMode.TRAIN: self.ensemble_inference_on_train_set,
ModelExecutionMode.TEST: self.ensemble_inference_on_test_set,
ModelExecutionMode.VAL: self.ensemble_inference_on_val_set,
}
}

def inference_on_set(self, model_proc: ModelProcessing, data_split: ModelExecutionMode) -> bool:
"""
Returns True if inference is required for this model_proc and data_split.
:param model_proc: Whether we are testing an ensemble or single model.
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
:return: True if inference required.
"""
inference_option = self.inference_options()[model_proc][data_split]
if inference_option is not None:
return inference_option

return WorkflowParams.INFERENCE_DEFAULTS[model_proc][data_split]
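
A minimal sketch of how the switches resolve (assuming WorkflowParams can be instantiated on its own, and that ModelExecutionMode is importable from InnerEye.ML.common; both are assumptions, not shown in this diff):

    from InnerEye.Common.common_util import ModelProcessing
    from InnerEye.ML.common import ModelExecutionMode  # assumed import path
    from InnerEye.ML.deep_learning_config import WorkflowParams

    config = WorkflowParams()
    # No switches set: INFERENCE_DEFAULTS applies.
    assert config.inference_on_set(ModelProcessing.DEFAULT, ModelExecutionMode.VAL) is True
    assert config.inference_on_set(ModelProcessing.ENSEMBLE_CREATION, ModelExecutionMode.VAL) is False
    # An explicit switch overrides the default for that combination only:
    config.inference_on_val_set = False
    assert config.inference_on_set(ModelProcessing.DEFAULT, ModelExecutionMode.VAL) is False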

@property
def is_offline_run(self) -> bool:
"""
77 changes: 38 additions & 39 deletions InnerEye/ML/run_ml.py
@@ -11,6 +11,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import pandas as pd
from pytorch_lightning.core.datamodule import LightningDataModule
import stopit
import torch.multiprocessing
from azureml._restclient.constants import RunStatus
@@ -120,19 +121,16 @@ def download_dataset(azure_dataset_id: str,
return expected_dataset_path


def log_metrics(val_metrics: Optional[InferenceMetricsForSegmentation],
test_metrics: Optional[InferenceMetricsForSegmentation],
train_metrics: Optional[InferenceMetricsForSegmentation],
def log_metrics(metrics: Dict[ModelExecutionMode, InferenceMetrics],
run_context: Run) -> None:
"""
Log metrics for each split to the provided run, or to the current run context if none is provided.
:param val_metrics: Inference results for the validation split
:param test_metrics: Inference results for the test split
:param train_metrics: Inference results for the train split
:param metrics: Dictionary of inference results for each split.
:param run_context: Run to log the metrics to; uses the current run context if None is provided.
"""
for split in [x for x in [val_metrics, test_metrics, train_metrics] if x]:
split.log_metrics(run_context)
for split in metrics.values():
if isinstance(split, InferenceMetricsForSegmentation):
split.log_metrics(run_context)
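
A minimal sketch of the reshaped call (hypothetical values; seg_metrics stands in for an InferenceMetricsForSegmentation returned by model_test, and only such segmentation entries are logged, per the isinstance check above):

    log_metrics(metrics={ModelExecutionMode.TEST: seg_metrics}, run_context=RUN_CONTEXT)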


class MLRunner:
@@ -390,7 +388,7 @@ def run(self) -> None:

# If this is a cross-validation run, and the present run is child run 0, then wait for the sibling
# runs, build the ensemble model, and write a report for that.
if self.container.number_of_cross_validation_splits > 0:
if self.container.perform_cross_validation:
should_wait_for_other_child_runs = (not self.is_offline_run) and \
self.container.cross_validation_split_index == 0
if should_wait_for_other_child_runs:
@@ -420,10 +418,24 @@ def is_normal_run_or_crossval_child_0(self) -> bool:
"""
Returns True if the present run is a non-crossvalidation run, or child run 0 of a crossvalidation run.
"""
if self.container.number_of_cross_validation_splits > 0:
if self.container.perform_cross_validation:
return self.container.cross_validation_split_index == 0
return True

@staticmethod
def lightning_data_module_dataloaders(data: LightningDataModule) -> Dict[ModelExecutionMode, Callable]:
"""
Given a lightning data module, return a dictionary of dataloaders, one for each model execution mode.
:param data: Lightning data module.
:return: Data loader for each model execution mode.
"""
return {
ModelExecutionMode.TEST: data.test_dataloader,
ModelExecutionMode.VAL: data.val_dataloader,
ModelExecutionMode.TRAIN: data.train_dataloader
}

def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
"""
Run inference for all models that are specified via a LightningContainer, on the enabled data splits.
@@ -439,11 +451,10 @@ def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
# Read the data modules before changing the working directory, in case the code relies on relative paths
data = self.container.get_inference_data_module()
dataloaders: List[Tuple[DataLoader, ModelExecutionMode]] = []
if self.container.perform_validation_and_test_set_inference:
dataloaders.append((data.test_dataloader(), ModelExecutionMode.TEST)) # type: ignore
dataloaders.append((data.val_dataloader(), ModelExecutionMode.VAL)) # type: ignore
if self.container.perform_training_set_inference:
dataloaders.append((data.train_dataloader(), ModelExecutionMode.TRAIN)) # type: ignore
data_dataloaders = MLRunner.lightning_data_module_dataloaders(data)
for data_split, dataloader in data_dataloaders.items():
if self.container.inference_on_set(ModelProcessing.DEFAULT, data_split):
dataloaders.append((dataloader(), data_split))
checkpoint = load_checkpoint(checkpoint_paths[0], use_gpu=self.container.use_gpu)
lightning_model.load_state_dict(checkpoint['state_dict'])
lightning_model.eval()
@@ -491,8 +502,8 @@ def run_inference(self, checkpoint_handler: CheckpointHandler,
"""

# Run full image inference with the existing or newly trained model, on whichever data splits are enabled
test_metrics, val_metrics, _ = self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
model_proc=model_proc)
self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
model_proc=model_proc)

self.try_compare_scores_against_baselines(model_proc)

@@ -752,37 +763,25 @@ def copy_file(source: Path, destination_file: str) -> None:
def model_inference_train_and_test(self,
checkpoint_handler: CheckpointHandler,
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> \
Tuple[Optional[InferenceMetrics], Optional[InferenceMetrics], Optional[InferenceMetrics]]:
train_metrics = None
val_metrics = None
test_metrics = None
Dict[ModelExecutionMode, InferenceMetrics]:
metrics: Dict[ModelExecutionMode, InferenceMetrics] = {}

config = self.innereye_config

def run_model_test(data_split: ModelExecutionMode) -> Optional[InferenceMetrics]:
return model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler, # type: ignore
model_proc=model_proc)

if config.perform_validation_and_test_set_inference:
# perform inference on test set
test_metrics = run_model_test(ModelExecutionMode.TEST)
# perform inference on validation set (not for ensemble as current val is in the training fold
# for at least one of the models).
if model_proc != ModelProcessing.ENSEMBLE_CREATION:
val_metrics = run_model_test(ModelExecutionMode.VAL)

if config.perform_training_set_inference:
# perform inference on training set if required
train_metrics = run_model_test(ModelExecutionMode.TRAIN)
for data_split in ModelExecutionMode:
if self.container.inference_on_set(model_proc, data_split):
opt_metrics = model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,
model_proc=model_proc)
if opt_metrics is not None:
metrics[data_split] = opt_metrics

# log the metrics to AzureML experiment if possible. When doing ensemble runs, log to the Hyperdrive parent run,
# so that we get the metrics of child run 0 and the ensemble separated.
if config.is_segmentation_model and not self.is_offline_run:
run_for_logging = PARENT_RUN_CONTEXT if model_proc == ModelProcessing.ENSEMBLE_CREATION else RUN_CONTEXT
log_metrics(val_metrics=val_metrics, test_metrics=test_metrics, # type: ignore
train_metrics=train_metrics, run_context=run_for_logging) # type: ignore
log_metrics(metrics=metrics, run_context=run_for_logging) # type: ignore

return test_metrics, val_metrics, train_metrics
return metrics
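
A minimal sketch of consuming the new return type (hypothetical caller; runner is an MLRunner instance and checkpoint_handler is assumed to exist):

    metrics = runner.model_inference_train_and_test(checkpoint_handler=checkpoint_handler)
    # Only the splits on which inference actually ran appear as keys.
    for data_split, split_metrics in metrics.items():
        print(f"Inference metrics on {data_split.value}: {split_metrics}")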

@stopit.threading_timeoutable()
def wait_for_runs_to_finish(self, delay: int = 60) -> None:
2 changes: 1 addition & 1 deletion Tests/ML/configs/lightning_test_containers.py
@@ -190,7 +190,7 @@ class DummyContainerWithModel(LightningContainer):

def __init__(self) -> None:
super().__init__()
self.perform_training_set_inference = True
self.inference_on_train_set = True
self.num_epochs = 50
self.l_rate = 1e-1

5 changes: 3 additions & 2 deletions Tests/ML/models/test_scalar_model.py
@@ -313,8 +313,9 @@ def test_run_ml_with_segmentation_model(test_output_dirs: OutputFolderForTests)
# This is for a bug in an earlier version of the code where the wrong execution mode was used to
# compute the expected mask size at training time.
config.test_crop_size = (75, 75, 75)
config.perform_training_set_inference = False
config.perform_validation_and_test_set_inference = True
config.inference_on_train_set = False
config.inference_on_val_set = True
config.inference_on_test_set = True
config.set_output_to(test_output_dirs.root_dir)
azure_config = get_default_azure_config()
azure_config.train = True
(Diffs for the remaining 4 changed files are not shown.)
