This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable building an ensemble model from the cross validation checkpoints of a BYO Lightning model #529

Closed
wants to merge 101 commits into from
Changes from 1 commit
1250de0
Use registered model for inference
Shruthi42 Jun 23, 2021
efc34fb
Merge branch 'main' into shbannur/load_registered_models
Shruthi42 Jun 23, 2021
f94b2e9
Bug fix
Shruthi42 Jun 23, 2021
7430785
Fix tests
Shruthi42 Jun 23, 2021
229b7ea
Fix tests
Shruthi42 Jun 23, 2021
16a09c3
mypy
Shruthi42 Jun 23, 2021
ee957d8
Fix tests
Shruthi42 Jun 23, 2021
596efee
Add tests
Shruthi42 Jun 24, 2021
2c2160f
Fix tests
Shruthi42 Jun 24, 2021
fe6fa93
Fix tests
Shruthi42 Jun 24, 2021
45895e7
Add test
Shruthi42 Jun 24, 2021
41f1b48
Fix tests
Shruthi42 Jun 24, 2021
f6bb7a2
Fix test
Shruthi42 Jul 5, 2021
3e3b069
Remove unnecessary function
Shruthi42 Jul 5, 2021
c0de1e6
Update tests
Shruthi42 Jul 5, 2021
0889a88
Flake8
Shruthi42 Jul 5, 2021
f98b22e
Fix tests
Shruthi42 Jul 5, 2021
d73f769
mypy
Shruthi42 Jul 5, 2021
f4dfbe7
Merge branch 'main' into shbannur/load_registered_models
Shruthi42 Jul 5, 2021
2767e18
Loosening multiple checkpoint check in run_inference_for_lightning_mo…
Jul 6, 2021
8a63e0a
WiP very scrappy!
Jul 7, 2021
116e566
WiP more mess
Jul 8, 2021
7107e27
WiP: bones of test class
Jul 8, 2021
d0c7724
Refactoring run_inference_for_lightning_models
Jul 9, 2021
6735d2c
mypy fixes
Jul 9, 2021
585511b
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
Jul 9, 2021
4df8c09
WiP annotations and test
Jul 10, 2021
05a67dc
WiP: saving mid task for lunch
dumbledad Jul 10, 2021
8d93907
Correcting GPU -> CPU typo in comment
dumbledad Jul 11, 2021
e7ba7f1
Method can be static
dumbledad Jul 11, 2021
e4ebfcb
WiP fiddling
dumbledad Jul 11, 2021
7625a48
Example ensemble from InnerEyeInference
dumbledad Jul 11, 2021
f5b288c
flake8 and mypy fixes
dumbledad Jul 11, 2021
a79673f
mypy fixes
dumbledad Jul 11, 2021
3e9ec7e
tidying unused parameters
Jul 11, 2021
559a717
WiP simple temp test for train/test ensemble
Jul 11, 2021
f7be221
Renaming InnerEyeInference methods
Jul 12, 2021
a3179a8
tidy up
Jul 12, 2021
f7bdc7f
Matching new naming
Jul 12, 2021
939ec5a
renaming params
Jul 12, 2021
f206057
naming
Jul 12, 2021
d71f498
Unit test WiP
Jul 12, 2021
6987f99
don't be strict with state_dict
dumbledad Jul 12, 2021
f6e025f
first unit test takes shape
dumbledad Jul 12, 2021
500bf11
Unit test works, but doesn't check much
dumbledad Jul 12, 2021
0d97fe9
renaming unit test
Jul 13, 2021
b204f08
Merge branch 'main' into shbannur/load_registered_models
Shruthi42 Jul 13, 2021
dd17d78
Change docstring
Shruthi42 Jul 13, 2021
388e0a8
Update CHANGELOG.md
Shruthi42 Jul 13, 2021
d727fe1
Rename
Shruthi42 Jul 13, 2021
fa7a6e4
Fix test
Shruthi42 Jul 13, 2021
581c6a9
WiP ensemble unit test with value check
Jul 13, 2021
1134c0f
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
Jul 14, 2021
e35db5b
Address PR comments
Shruthi42 Jul 14, 2021
9093b7e
Use list of pytest markers
Shruthi42 Jul 14, 2021
bf072c0
Move model_id to WorkflowParams
Shruthi42 Jul 14, 2021
5a76cc1
missed some name changes
Jul 14, 2021
2d75d24
WiP swapping back to checkpoints not accruing child runs
Jul 14, 2021
e064483
Refactor extra_downloaded_run_id
Shruthi42 Jul 14, 2021
168eb29
unit test working
Jul 14, 2021
838bb48
Update documentation and argparser
Shruthi42 Jul 14, 2021
0f3690c
flake & mypy
Jul 14, 2021
9c7f6b4
Revert changes to generic_parsing
Shruthi42 Jul 14, 2021
d612c6e
Update documentation
Shruthi42 Jul 14, 2021
6861178
Flake8 and mypy
Shruthi42 Jul 14, 2021
98c7683
Shruthi's changes to run_ml
Jul 14, 2021
54c845e
Merge branch 'shbannur/load_registered_models' into timregan/527-ense…
Jul 14, 2021
48aca37
WiP abstracting ensemble inference
Jul 14, 2021
d3d8477
WiP
dumbledad Jul 15, 2021
a4ea25a
Ensemble inference base
dumbledad Jul 15, 2021
5e7f0b1
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
dumbledad Jul 15, 2021
a71fa90
Ended up with changes to 2 files I did not touch!
dumbledad Jul 15, 2021
0874a74
Restoring (and fixing) run_ml changes
dumbledad Jul 16, 2021
88bb96d
mypy
dumbledad Jul 16, 2021
424b2e7
WiP
dumbledad Jul 16, 2021
4903b96
run_ml unit test v1
Jul 16, 2021
547ed7d
WiP pre pruning ensemble stuff
Jul 16, 2021
538a7a4
refactored to avoid recursion blow-up
Jul 16, 2021
002886c
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
dumbledad Jul 17, 2021
c6ca034
additional comments and remove inheritance
dumbledad Jul 17, 2021
484fe4a
removing duplicated unit test
dumbledad Jul 17, 2021
b771137
more comments
dumbledad Jul 17, 2021
6d57d08
test tidy
dumbledad Jul 17, 2021
6a9fd75
flake fixes
dumbledad Jul 17, 2021
cea51e4
WiP
dumbledad Jul 17, 2021
9fc696a
on_ensemble_inference_start needn't call down
dumbledad Jul 17, 2021
b947452
Old WiP changes
dumbledad Jul 17, 2021
f7d22f4
Adding HelloEnsembleInference
dumbledad Jul 17, 2021
af079ae
run_ml changes with parameter
dumbledad Jul 17, 2021
38b12b7
WiP on mypy and tidy pre unit test fix
dumbledad Jul 18, 2021
c255774
mypy fixes
dumbledad Jul 18, 2021
2db1857
import fix so test discovery works
dumbledad Jul 18, 2021
21044e6
Fixing test discovery
dumbledad Jul 18, 2021
4f10169
WiP fixing unit test
dumbledad Jul 18, 2021
e0e31e6
Unit test works
dumbledad Jul 18, 2021
1d23735
file system test
dumbledad Jul 18, 2021
1486bff
Adding register and actually building ensemble
Jul 19, 2021
ff8fef1
Removed call to innnereye_config
Jul 19, 2021
ddb3c9f
unit test fix
dumbledad Jul 22, 2021
e71d49b
fix for the reference error on AzureML
dumbledad Jul 23, 2021
db18fdd
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
dumbledad Jul 23, 2021
Use registered model for inference
Shruthi42 committed Jun 23, 2021
commit 1250de038492201b0fd4ad97ad3841059701d3cb
6 changes: 4 additions & 2 deletions InnerEye/Azure/azure_config.py
@@ -77,8 +77,8 @@ class AzureConfig(GenericConfig):
pytest_mark: str = param.String(doc="If provided, run pytest instead of model training. pytest will only "
"run the tests that have the mark given in this argument "
"('--pytest_mark gpu' will run all tests marked with 'pytest.mark.gpu')")
run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id'"
" to use for inference or recovering a model training run.")
run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id' "
"to use for recovering a model training run or to register a model.")
pretraining_run_recovery_id: str = param.String(default=None,
allow_None=True,
doc="Extra run recovery id to download checkpoints from,"
@@ -122,6 +122,8 @@ class AzureConfig(GenericConfig):
_workspace: Workspace = param.ClassSelector(class_=Workspace,
doc="The cached workspace object that has been created in the first "
"call to get_workspace")
model_id: str = param.String(doc="A model id string in the form 'model name:version' "
"to use a registered model for inference.")

def __init__(self, **params: Any) -> None:
super().__init__(**params)
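
The new `model_id` parameter documents its format as 'model name:version'. As a minimal sketch of how such a string could be split into its two parts: `split_model_id` is a hypothetical helper that does not appear in this commit, and treating the version as an integer is an assumption.

```python
from typing import Tuple


def split_model_id(model_id: str) -> Tuple[str, int]:
    """
    Hypothetical helper (not part of this commit): splits a string of the
    form 'model name:version' into its name and version.
    Assumes the version is a non-negative integer.
    """
    # Split on the last colon, so that a colon inside the name is tolerated.
    name, separator, version = model_id.rpartition(":")
    if not separator or not name or not version.isdigit():
        raise ValueError(f"Expected 'model name:version', but got: '{model_id}'")
    return name, int(version)


if __name__ == "__main__":
    print(split_model_id("my_model:3"))
```

Splitting on the last colon keeps the version unambiguous even if a model name happens to contain a colon.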
3 changes: 0 additions & 3 deletions InnerEye/Common/fixed_paths.py
@@ -40,9 +40,6 @@ def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
# The folder at the project root directory that holds datasets for local execution.
DATASETS_DIR_NAME = "datasets"

# Points to a folder at the project root directory that holds model weights downloaded from URLs.
MODEL_WEIGHTS_DIR_NAME = "modelweights"

ML_RELATIVE_SOURCE_PATH = os.path.join("ML")
ML_RELATIVE_RUNNER_PATH = os.path.join(ML_RELATIVE_SOURCE_PATH, "runner.py")
ML_FULL_SOURCE_FOLDER_PATH = str(repository_root_directory() / ML_RELATIVE_SOURCE_PATH)
16 changes: 7 additions & 9 deletions InnerEye/ML/deep_learning_config.py
@@ -34,7 +34,6 @@
EXTRA_RUN_SUBFOLDER = "extra_run_id"

ARGS_TXT = "args.txt"
WEIGHTS_FILE = "weights.pth"


@unique
@@ -207,13 +206,12 @@ class WorkflowParams(param.Parameterized):
perform_validation_and_test_set_inference: bool = \
param.Boolean(True,
doc="If True (default), run full image inference on validation and test set after training.")
weights_url: str = param.String(doc="If provided, a url from which weights will be downloaded and used for model "
"initialization.")
local_weights_path: Optional[Path] = param.ClassSelector(class_=Path,
default=None,
allow_None=True,
doc="The path to the weights to use for model "
"initialization, when training outside AzureML.")
checkpoint_urls: List[str] = param.List(default=[],
doc="If provided, a set of urls from which checkpoints will be downloaded "
"and used for training/inference.")
local_checkpoint_paths: List[Path] = param.List(default=[], class_=Path,
doc="A list of checkpoint paths to use for training/inference, "
"when training is running outside Azure.")
generate_report: bool = param.Boolean(default=True,
doc="If True (default), write a modelling report in HTML format. If False, "
"do not write that report.")
@@ -239,7 +237,7 @@ class WorkflowParams(param.Parameterized):
"be relative to the repository root directory.")

def validate(self) -> None:
if self.weights_url and self.local_weights_path:
if self.checkpoint_urls and self.local_checkpoint_paths:
raise ValueError("Cannot specify both local_checkpoint_paths and checkpoint_urls.")

if self.number_of_cross_validation_splits == 1:
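
The `validate` check above treats checkpoint URLs and local checkpoint paths as mutually exclusive sources. A minimal stand-alone sketch of that rule follows. It uses a plain dataclass instead of `param.Parameterized`, and the class name `CheckpointSources` is hypothetical; only the two field names come from the diff.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import List


@dataclass
class CheckpointSources:
    """Hypothetical stand-in for the two checkpoint fields of WorkflowParams."""
    checkpoint_urls: List[str] = field(default_factory=list)
    local_checkpoint_paths: List[Path] = field(default_factory=list)

    def validate(self) -> None:
        # Empty lists are falsy, so this only fires when both sources are given.
        if self.checkpoint_urls and self.local_checkpoint_paths:
            raise ValueError("Cannot specify both local_checkpoint_paths and checkpoint_urls.")


if __name__ == "__main__":
    # A single source passes validation.
    CheckpointSources(local_checkpoint_paths=[Path("fold0.ckpt")]).validate()
```

Because both fields default to empty lists, leaving both unset is also valid under this check.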
28 changes: 10 additions & 18 deletions InnerEye/ML/model_testing.py
@@ -35,7 +35,6 @@
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils import io_util, ml_util
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from InnerEye.ML.utils.image_util import binaries_from_multi_label_array
from InnerEye.ML.utils.io_util import ImageHeader, MedicalImageFileType, load_nifti_image, save_lines_to_file
from InnerEye.ML.utils.metrics_util import MetricsPerPatientWriter
@@ -47,15 +46,15 @@

def model_test(config: ModelConfigBase,
data_split: ModelExecutionMode,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> Optional[InferenceMetrics]:
"""
Runs model inference on segmentation or classification models, using a given dataset (that could be training,
test or validation set). The inference results and metrics will be stored and logged in a way that may
differ for model categories (classification, segmentation).
:param config: The configuration of the model
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
:param checkpoint_paths: Checkpoint paths to initialize the model.
:param model_proc: whether we are testing an ensemble or single model; this affects where results are written.
:return: The metrics that the model achieved on the given data set, or None if the data set is empty.
"""
@@ -67,17 +66,19 @@ def model_test(config: ModelConfigBase,
"and additional data loaders are likely to block.")
return None
with logging_section(f"Running {model_proc.value} model on {data_split.name.lower()} set"):
if not checkpoint_paths:
raise ValueError("There were no checkpoints available for model testing.")
if isinstance(config, SegmentationModelBase):
return segmentation_model_test(config, data_split, checkpoint_handler, model_proc)
return segmentation_model_test(config, data_split, checkpoint_paths, model_proc)
if isinstance(config, ScalarModelBase):
return classification_model_test(config, data_split, checkpoint_handler, model_proc,
return classification_model_test(config, data_split, checkpoint_paths, model_proc,
config.cross_validation_split_index)
raise ValueError(f"There is no testing code for models of type {type(config)}")


def segmentation_model_test(config: SegmentationModelBase,
data_split: ModelExecutionMode,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> InferenceMetricsForSegmentation:
"""
The main testing loop for segmentation models.
@@ -88,18 +89,13 @@ def segmentation_model_test(config: SegmentationModelBase,
:param model_proc: whether we are testing an ensemble or single model
:return: InferenceMetric object that contains metrics related for all of the checkpoint epochs.
"""
checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()

if not checkpoints_to_test:
raise ValueError("There were no checkpoints available for model testing.")

epoch_results_folder = config.outputs_folder / get_best_epoch_results_path(data_split, model_proc)
# save the datasets.csv used
config.write_dataset_files(root=epoch_results_folder)
epoch_and_split = f"{data_split.value} set"
epoch_dice_per_image = segmentation_model_test_epoch(config=copy.deepcopy(config),
data_split=data_split,
checkpoint_paths=checkpoints_to_test,
checkpoint_paths=checkpoint_paths,
results_folder=epoch_results_folder,
epoch_and_split=epoch_and_split)
if epoch_dice_per_image is None:
@@ -395,7 +391,7 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \

def classification_model_test(config: ScalarModelBase,
data_split: ModelExecutionMode,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing,
cross_val_split_index: int) -> InferenceMetricsForClassification:
"""
@@ -404,16 +400,12 @@ def classification_model_test(config: ScalarModelBase,
:param config: The model configuration.
:param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
used mainly in model evaluation using different dataset splits.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
:param checkpoint_paths: Checkpoint paths to initialize model
:param model_proc: whether we are testing an ensemble or single model
:return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs.
"""
posthoc_label_transform = config.get_posthoc_label_transform()

checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
if not checkpoint_paths:
raise ValueError("There were no checkpoints available for model testing.")

pipeline = create_inference_pipeline(config=config,
checkpoint_paths=checkpoint_paths)
if pipeline is None:
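
With this change, `model_test` and the two functions it dispatches to receive a plain list of checkpoint paths instead of a `CheckpointHandler`, and the empty-list check moves up into `model_test`. As a rough sketch of what a caller might do to build such a list: the helper `collect_checkpoint_paths` and the `*.ckpt` file pattern are assumptions for illustration, not code from this PR.

```python
from pathlib import Path
from typing import List


def collect_checkpoint_paths(folder: Path, pattern: str = "*.ckpt") -> List[Path]:
    """
    Hypothetical helper (not part of this PR): gathers the checkpoint files
    in a folder, in sorted order, and fails early if there are none.
    The '*.ckpt' pattern is an assumption about how checkpoints are named.
    """
    paths = sorted(folder.glob(pattern))
    if not paths:
        # Same message as the guard that this commit adds to model_test.
        raise ValueError("There were no checkpoints available for model testing.")
    return paths


# A caller could then pass the list on, along the lines of:
#   model_test(config, data_split, checkpoint_paths=collect_checkpoint_paths(folder))
```

Passing explicit paths keeps the testing functions independent of how the checkpoints were obtained, whether from a training run, a URL download, or a registered model.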
7 changes: 2 additions & 5 deletions InnerEye/ML/model_training.py
@@ -28,7 +28,6 @@
from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger
from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \
get_subject_output_file_per_rank
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler

TEMP_PREFIX = "temp/"

@@ -215,23 +214,21 @@ def start_resource_monitor(config: LightningContainer) -> ResourceMonitor:
return resource_monitor


def model_train(checkpoint_handler: CheckpointHandler,
def model_train(checkpoint_path: Path,
container: LightningContainer,
num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
"""
The main training loop. It creates the Pytorch model based on the configuration options passed in,
creates a Pytorch Lightning trainer, and trains the model.
If a checkpoint was specified, then it loads the checkpoint before resuming training.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
:param checkpoint_path: Checkpoint path for model initialization
:param num_nodes: The number of nodes to use in distributed training.
:param container: A container object that holds the training data in PyTorch Lightning format
and the model to train.
:return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
fitting other models.
"""
# Get the path to the checkpoint to recover from
checkpoint_path = checkpoint_handler.get_recovery_path_train()
lightning_model = container.model

resource_monitor: Optional[ResourceMonitor] = None