This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable building an ensemble model from the cross validation checkpoints of a BYO Lightning model #529

Closed
wants to merge 101 commits
Commits (101)
1250de0
Use registered model for inference
Shruthi42 Jun 23, 2021
efc34fb
Merge branch 'main' into shbannur/load_registered_models
Shruthi42 Jun 23, 2021
f94b2e9
Bug fix
Shruthi42 Jun 23, 2021
7430785
Fix tests
Shruthi42 Jun 23, 2021
229b7ea
Fix tests
Shruthi42 Jun 23, 2021
16a09c3
mypy
Shruthi42 Jun 23, 2021
ee957d8
Fix tests
Shruthi42 Jun 23, 2021
596efee
Add tests
Shruthi42 Jun 24, 2021
2c2160f
Fix tests
Shruthi42 Jun 24, 2021
fe6fa93
Fix tests
Shruthi42 Jun 24, 2021
45895e7
Add test
Shruthi42 Jun 24, 2021
41f1b48
Fix tests
Shruthi42 Jun 24, 2021
f6bb7a2
Fix test
Shruthi42 Jul 5, 2021
3e3b069
Remove unnecessary function
Shruthi42 Jul 5, 2021
c0de1e6
Update tests
Shruthi42 Jul 5, 2021
0889a88
Flake8
Shruthi42 Jul 5, 2021
f98b22e
Fix tests
Shruthi42 Jul 5, 2021
d73f769
mypy
Shruthi42 Jul 5, 2021
f4dfbe7
Merge branch 'main' into shbannur/load_registered_models
Shruthi42 Jul 5, 2021
2767e18
Loosening multiple checkpoint check in run_inference_for_lightning_mo…
Jul 6, 2021
8a63e0a
WiP very scrappy!
Jul 7, 2021
116e566
WiP more mess
Jul 8, 2021
7107e27
WiP: bones of test class
Jul 8, 2021
d0c7724
Refactoring run_inference_for_lightning_models
Jul 9, 2021
6735d2c
mypy fixes
Jul 9, 2021
585511b
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
Jul 9, 2021
4df8c09
WiP annotations and test
Jul 10, 2021
05a67dc
WiP: saving mid task for lunch
dumbledad Jul 10, 2021
8d93907
Correcting GPU -> CPU typo in comment
dumbledad Jul 11, 2021
e7ba7f1
Method can be static
dumbledad Jul 11, 2021
e4ebfcb
WiP fiddling
dumbledad Jul 11, 2021
7625a48
Example ensemble from InnerEyeInference
dumbledad Jul 11, 2021
f5b288c
flake8 and mypy fixes
dumbledad Jul 11, 2021
a79673f
mypy fixes
dumbledad Jul 11, 2021
3e9ec7e
tidying unused parameters
Jul 11, 2021
559a717
WiP simple temp test for train/test ensemble
Jul 11, 2021
f7be221
Renaming InnerEyeInference methods
Jul 12, 2021
a3179a8
tidy up
Jul 12, 2021
f7bdc7f
Matching new naming
Jul 12, 2021
939ec5a
renaming params
Jul 12, 2021
f206057
naming
Jul 12, 2021
d71f498
Unit test WiP
Jul 12, 2021
6987f99
don't be strict with state_dict
dumbledad Jul 12, 2021
f6e025f
first unit test takes shape
dumbledad Jul 12, 2021
500bf11
Unit test works, but doesn't check much
dumbledad Jul 12, 2021
0d97fe9
renaming unit test
Jul 13, 2021
b204f08
Merge branch 'main' into shbannur/load_registered_models
Shruthi42 Jul 13, 2021
dd17d78
Change docstring
Shruthi42 Jul 13, 2021
388e0a8
Update CHANGELOG.md
Shruthi42 Jul 13, 2021
d727fe1
Rename
Shruthi42 Jul 13, 2021
fa7a6e4
Fix test
Shruthi42 Jul 13, 2021
581c6a9
WiP ensemble unit test with value check
Jul 13, 2021
1134c0f
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
Jul 14, 2021
e35db5b
Address PR comments
Shruthi42 Jul 14, 2021
9093b7e
Use list of pytest markers
Shruthi42 Jul 14, 2021
bf072c0
Move model_id to WorkflowParams
Shruthi42 Jul 14, 2021
5a76cc1
missed some name changes
Jul 14, 2021
2d75d24
WiP swapping back to checkpoints not accruing child runs
Jul 14, 2021
e064483
Refactor extra_downloaded_run_id
Shruthi42 Jul 14, 2021
168eb29
unit test working
Jul 14, 2021
838bb48
Update documentation and argparser
Shruthi42 Jul 14, 2021
0f3690c
flake & mypy
Jul 14, 2021
9c7f6b4
Revert changes to generic_parsing
Shruthi42 Jul 14, 2021
d612c6e
Update documentation
Shruthi42 Jul 14, 2021
6861178
Flake8 and mypy
Shruthi42 Jul 14, 2021
98c7683
Shruthi's changes to run_ml
Jul 14, 2021
54c845e
Merge branch 'shbannur/load_registered_models' into timregan/527-ense…
Jul 14, 2021
48aca37
WiP abstracting ensemble inference
Jul 14, 2021
d3d8477
WiP
dumbledad Jul 15, 2021
a4ea25a
Ensemble inference base
dumbledad Jul 15, 2021
5e7f0b1
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
dumbledad Jul 15, 2021
a71fa90
Ended up with changes to 2 files I did not touch!
dumbledad Jul 15, 2021
0874a74
Restoring (and fixing) run_ml changes
dumbledad Jul 16, 2021
88bb96d
mypy
dumbledad Jul 16, 2021
424b2e7
WiP
dumbledad Jul 16, 2021
4903b96
run_ml unit test v1
Jul 16, 2021
547ed7d
WiP pre pruning ensemble stuff
Jul 16, 2021
538a7a4
refactored to avoid recursion blow-up
Jul 16, 2021
002886c
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
dumbledad Jul 17, 2021
c6ca034
additional comments and remove inheritance
dumbledad Jul 17, 2021
484fe4a
removing duplicated unit test
dumbledad Jul 17, 2021
b771137
more comments
dumbledad Jul 17, 2021
6d57d08
test tidy
dumbledad Jul 17, 2021
6a9fd75
flake fixes
dumbledad Jul 17, 2021
cea51e4
WiP
dumbledad Jul 17, 2021
9fc696a
on_ensemble_inference_start needn't call down
dumbledad Jul 17, 2021
b947452
Old WiP changes
dumbledad Jul 17, 2021
f7d22f4
Adding HelloEnsembleInference
dumbledad Jul 17, 2021
af079ae
run_ml changes with parameter
dumbledad Jul 17, 2021
38b12b7
WiP on mypy and tidy pre unit test fix
dumbledad Jul 18, 2021
c255774
mypy fixes
dumbledad Jul 18, 2021
2db1857
import fix so test discovery works
dumbledad Jul 18, 2021
21044e6
Fixing test discovery
dumbledad Jul 18, 2021
4f10169
WiP fixing unit test
dumbledad Jul 18, 2021
e0e31e6
Unit test works
dumbledad Jul 18, 2021
1d23735
file system test
dumbledad Jul 18, 2021
1486bff
Adding register and actually building ensemble
Jul 19, 2021
ff8fef1
Removed call to innnereye_config
Jul 19, 2021
ddb3c9f
unit test fix
dumbledad Jul 22, 2021
e71d49b
fix for the reference error on AzureML
dumbledad Jul 23, 2021
db18fdd
Merge branch 'main' into timregan/527-ensembles-for-BYOL-xval
dumbledad Jul 23, 2021
30 changes: 19 additions & 11 deletions InnerEye/ML/common.py
@@ -129,20 +129,28 @@ def find_recovery_checkpoint_and_epoch(path: Path) -> Optional[PathAndEpoch]:

def create_best_checkpoint(path: Path) -> Path:
"""
Creates the best checkpoint file. "Best" is at the moment defined as being the last checkpoint, but could be
based on some defined policy.
The best checkpoint will be renamed to `best_checkpoint.ckpt`.
Creates the best checkpoint file. "Best" is at the moment defined as being the checkpoint whose name matches
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX or the only available checkpoint, but it could be based on some defined
policy.
The best checkpoint will be renamed to BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX.
:param path: The folder that contains all checkpoint files.
"""
logging.debug(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
last_ckpt = path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
all_files = f"Existing files: {' '.join(p.name for p in path.glob('*'))}"
if not last_ckpt.is_file():
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found. {all_files}")
logging.info(f"Using {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} as the best checkpoint: Renaming to "
f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}")
candidate_checkpoint: Optional[Path] = None
checkpoint_files = list(path.glob('*.ckpt'))
if (path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX) in checkpoint_files:
candidate_checkpoint = path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
elif len(checkpoint_files) == 1:
candidate_checkpoint = checkpoint_files[0]
else:
raise FileNotFoundError(
f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found in "
f"{' '.join(p.name for p in checkpoint_files)}, and there were "
f"{len(checkpoint_files)} checkpoint files, so the policy of falling back to the only checkpoint could not be applied.")
assert candidate_checkpoint # mypy
logging.info(
f"Using {candidate_checkpoint.name} as best checkpoint. Renaming it to {BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}")
best = path / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last_ckpt.rename(best)
candidate_checkpoint.rename(best)
return best


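For illustration, a minimal sketch of the new fallback policy (hypothetical usage, not part of this diff; it assumes the function and constant are importable from InnerEye.ML.common as shown in the hunk above):

```python
import tempfile
from pathlib import Path

from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, create_best_checkpoint

with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    # Only one checkpoint, and it is not the "last" checkpoint: the new code falls back to it.
    (folder / "epoch=3.ckpt").touch()
    best = create_best_checkpoint(folder)
    assert best == folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    # With several checkpoints and no "last" checkpoint, FileNotFoundError is raised instead.
```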
53 changes: 51 additions & 2 deletions InnerEye/ML/configs/other/HelloContainer.py
@@ -15,7 +15,8 @@
from sklearn.model_selection import KFold

from InnerEye.Common import fixed_paths
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.lightning_container import InnerEyeEnsembleInference, LightningContainer


class HelloDataset(Dataset):
@@ -250,7 +251,8 @@ def __init__(self) -> None:
# This method must be overridden by any subclass of LightningContainer. It returns the model that you wish to
# train, as a LightningModule
def create_model(self) -> LightningModule:
return HelloRegression()
self._model = HelloRegression() # TODO: why does LightningContainer need all three, _model, model, and create_model?
Reviewer comment (Contributor):

Not sure I get the question? This method here should only return a freshly generated model, the _model properties are populated elsewhere.

return self._model

# This method must be overridden by any subclass of LightningContainer. It returns a data module, which
# in turn contains 3 data loaders for training, validation, and test set.
@@ -276,3 +278,50 @@ def create_report(self) -> None:
report = f"Performance on test set: MSE = {test_mse}, MAE = {test_mae}"
print(report)
Path("report.txt").write_text(report)


class HelloEnsembleInference(InnerEyeEnsembleInference):
"""
Ensemble class intended to run inference over a collection of HelloRegression models hydrated from the
checkpoints of a HelloContainer cross validation training run.
"""
def __init__(self, outputs_folder: Optional[Path] = None) -> None:
super().__init__(outputs_folder)
self.test_mse: List[torch.Tensor] = []
self.test_mae = MeanAbsoluteError()
self.execution_mode: Optional[ModelExecutionMode] = None

def on_ensemble_inference_start(self) -> None:
"""
Initialize before any inference.
"""
super().on_ensemble_inference_start()
self.execution_mode = None

def on_ensemble_inference_start_dataset(self, execution_mode: ModelExecutionMode) -> None:
"""
Runs initialization for inference, when starting inference on a new dataset split (train/val/test).
:param execution_mode: Indicates whether the item comes from the training, validation or test set.
"""
super().on_ensemble_inference_start_dataset(execution_mode)
self.test_mse = []
self.test_mae.reset()
self.execution_mode = execution_mode

def record_ensemble_posterior(self, batch_y: torch.Tensor, batch_idx: int, posterior: torch.Tensor) -> None:
"""
Called when the model has finished making a prediction to compute metrics and store them.
"""
self.test_mse.append(torch.nn.functional.mse_loss(posterior, batch_y))
self.test_mae.update(preds=posterior, target=batch_y)

def on_ensemble_inference_end_dataset(self) -> None:
"""
Append the metrics from this dataset's inference run to the metrics files.
"""
if self.outputs_folder:
average_mse = torch.mean(torch.stack(self.test_mse))
with (self.outputs_folder / "test_mse.txt").open("a") as test_mse_file:
test_mse_file.write(f"{self.execution_mode.name}: {average_mse.item()}\n") # type: ignore
with (self.outputs_folder / "test_mae.txt").open("a") as test_mae_file:
test_mae_file.write(f"{self.execution_mode.name}: {self.test_mae.compute().item()}\n") # type: ignore
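A hedged driver sketch showing how this class plugs into the ensemble protocol documented on InnerEyeEnsembleInference further down; `checkpoint_paths` and `dataloaders` are assumed names, not code from this PR:

```python
from pathlib import Path

from InnerEye.ML.common import ModelExecutionMode

# Assumptions: checkpoint_paths is a List[Path] from a cross validation run, and
# dataloaders is a dict of data loaders keyed by execution mode.
ensemble = HelloEnsembleInference(outputs_folder=Path("outputs"))
ensemble.load_checkpoints_into_ensemble(HelloRegression(), checkpoint_paths, use_gpu=False)
ensemble.on_ensemble_inference_start()
for execution_mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL, ModelExecutionMode.TEST]:
    ensemble.on_ensemble_inference_start_dataset(execution_mode)
    for batch_idx, batch in enumerate(dataloaders[execution_mode]):
        posterior = ensemble.ensemble_forward(batch["x"])
        ensemble.record_ensemble_posterior(batch["y"], batch_idx, posterior)
    ensemble.on_ensemble_inference_end_dataset()
ensemble.on_ensemble_inference_end()
```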
8 changes: 5 additions & 3 deletions InnerEye/ML/deep_learning_config.py
@@ -19,8 +19,8 @@
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, DEFAULT_LOGS_DIR_NAME
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.Common.type_annotations import PathOrString, TupleFloat2
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode, create_unique_timestamp_id, \
get_best_checkpoint_path, get_recovery_checkpoint_path
from InnerEye.ML.common import (DATASET_CSV_FILE_NAME, ModelExecutionMode, create_unique_timestamp_id,
get_best_checkpoint_path, get_recovery_checkpoint_path)

# A folder inside of the outputs folder that will contain all information for running the model in inference mode

@@ -198,6 +198,8 @@ class WorkflowParams(param.Parameterized):
cross_validation_split_index: int = param.Integer(DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, bounds=(-1, None),
doc="The index of the cross validation fold this model is "
"associated with when performing k-fold cross validation")
ensemble_model_name: str = param.String(
doc=("The class name of the ensemble model to build from the cross validation checkpoints of a Lightning model."))
Comment on lines +201 to +202
Reviewer comment (Contributor):

I feel that this extra requirement is a key bottleneck in your design. Anybody who wants to implement an ensemble model now needs to define the model itself (ensemble member), and then another class that is the ensemble logic. For the other models, we have been able to achieve that without extra overhead - for example, we can define a Prostate model, and get Prostate ensemble models for free.
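For context, a small hypothetical sketch of the new field in use; ensemble_model_name is an ordinary param.String, so it can be set programmatically, and via InnerEye's argparser integration it would presumably arrive as a command-line override (an assumption, not verified against the runner):

```python
from InnerEye.ML.deep_learning_config import WorkflowParams

# Hypothetical usage: set the ensemble class name on the workflow parameters.
params = WorkflowParams()
params.ensemble_model_name = "HelloEnsembleInference"
assert params.ensemble_model_name == "HelloEnsembleInference"
```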

inference_on_train_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on training set after training.")
@@ -829,7 +831,7 @@ def load_checkpoint_and_modify(self, path_to_checkpoint: Path) -> Dict[str, Any]

def load_checkpoint(path_to_checkpoint: Path, use_gpu: bool = True) -> Dict[str, Any]:
"""
Loads a Torch checkpoint from the given file. If use_gpu==False, map all parameters to the GPU, otherwise
Loads a Torch checkpoint from the given file. If use_gpu==False, map all parameters to the CPU, otherwise
leave the device of all parameters unchanged.
"""
import torch
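The body of load_checkpoint is collapsed in this view; a minimal sketch of the mapping behaviour the corrected docstring describes (an assumption based on the docstring, not the verbatim body):

```python
from pathlib import Path
from typing import Any, Dict

import torch

def load_checkpoint_sketch(path_to_checkpoint: Path, use_gpu: bool = True) -> Dict[str, Any]:
    # map_location='cpu' pulls every tensor onto the CPU; None leaves devices as saved.
    map_location = None if use_gpu else 'cpu'
    return torch.load(path_to_checkpoint, map_location=map_location)
```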
2 changes: 1 addition & 1 deletion InnerEye/ML/lightning_base.py
@@ -500,5 +500,5 @@ def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
assert isinstance(self.trainer, Trainer)
self.log_on_epoch(MetricType.LOSS, loss, is_training)
if is_training:
learning_rate = self.trainer.lr_schedulers[0]['scheduler'].get_last_lr()[0]
learning_rate = self.trainer.lr_schedulers[0]['scheduler'].get_last_lr()[0] # type: ignore
self.log_on_epoch(MetricType.LEARNING_RATE, learning_rate, is_training)
168 changes: 148 additions & 20 deletions InnerEye/ML/lightning_container.py
@@ -3,23 +3,25 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import abc
from typing import Any, Dict, Iterator, List, Optional, Tuple
import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple

import param
import torch
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import GridParameterSampling, HyperDriveConfig, PrimaryMetricGoal, choice
from pytorch_lightning import LightningDataModule, LightningModule
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import GridParameterSampling, HyperDriveConfig, PrimaryMetricGoal, choice

from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
from InnerEye.Common.generic_parsing import GenericConfig, create_from_matching_params
from InnerEye.Common.metrics_constants import TrackedMetrics
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, \
WorkflowParams
from InnerEye.ML.deep_learning_config import (DatasetParams, OptimizerParams, OutputParams, TrainerParams,
WorkflowParams, load_checkpoint)
from InnerEye.ML.utils import model_util
from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
from InnerEye.ML.utils.run_recovery import RunRecovery
Expand All @@ -35,11 +37,11 @@ class InnerEyeInference(abc.ABC):

model.on_inference_start()
for dataset_split in [Train, Val, Test]
model.on_inference_epoch_start(dataset_split, is_ensemble_model=False)
for batch_idx, item in enumerate(dataloader[dataset_split])):
model_outputs = model.forward(item)
model.inference_step(item, batch_idx, model_outputs)
model.on_inference_epoch_end()
model.on_inference_start_dataset(dataset_split)
for batch_idx, batch in enumerate(dataloader[dataset_split]):
posteriors = model.forward(batch)
model.record_posteriors(batch, batch_idx, posteriors)
model.on_inference_end_dataset()
model.on_inference_end()
"""

@@ -50,29 +52,28 @@ def on_inference_start(self) -> None:
"""
pass

def on_inference_epoch_start(self, dataset_split: ModelExecutionMode, is_ensemble_model: bool) -> None:
def on_inference_start_dataset(self, execution_mode: ModelExecutionMode) -> None:
"""
Runs initialization for inference, when starting inference on a new dataset split (train/val/test).
Depending on the settings, this can be called anywhere between 0 (no inference at all) and 3 times (inference
on all of the train/val/test splits).
:param dataset_split: Indicates whether the item comes from the training, validation or test set.
:param is_ensemble_model: If False, the model_outputs come from an individual model. If True, the model
outputs come from multiple models.
:param execution_mode: Indicates whether the item comes from the training, validation or test set.
"""
pass

def inference_step(self, batch: Any, batch_idx: int, model_output: torch.Tensor) -> None:
def record_posteriors(self, batch: Any, batch_idx: int, posteriors: torch.Tensor) -> None:
"""
This hook is called when the model has finished making a prediction. It can write the results to a file,
or compute metrics and store them.
:param batch: The batch of data for which the model made a prediction.
:param model_output: The model outputs. This would usually be a torch.Tensor, but can be any datatype.
:param batch_idx: The index of the batch.
:param posteriors: The posteriors output by the model.
"""
# We don't want abstract methods here, it avoids class creation for unit tests, and we also want this
# method to be left optional (it should be possible to also use Lightning's native test_step method)
raise NotImplementedError("Method record_posteriors must be overwritten in a derived class.")

def on_inference_epoch_end(self) -> None:
def on_inference_end_dataset(self) -> None:
"""
Called when the inference on one of the dataset splits (train/val/test) has finished.
Depending on the settings, this can be called anywhere between 0 (no inference at all) and 3 times (inference
@@ -87,10 +88,137 @@ def on_inference_end(self) -> None:
"""
pass

def aggregate_ensemble_model_outputs(self, model_outputs: Iterator[torch.Tensor]) -> torch.Tensor:

class InnerEyeEnsembleInference:
"""
InnerEyeEnsembleInference provides help for building ensemble models from cross validation runs of LightningModules,
and for doing inference on the ensemble.

To set up an ensemble and then do inference, call
model.load_checkpoints_into_ensemble(exemplar, checkpoint_paths, use_gpu)
model.on_ensemble_inference_start()
for dataset_split in [Train, Val, Test]
model.on_ensemble_inference_start_dataset(dataset_split)
for batch_idx, batch in enumerate(dataloader[dataset_split]):
posterior = model.ensemble_forward(batch)
model.record_ensemble_posterior(batch, batch_idx, posterior)
model.on_ensemble_inference_end_dataset()
model.on_ensemble_inference_end()

We have not duplicated the method documentation from InnerEyeInference, where you can find further explanation of
the role of each method.
"""
def __init__(self, outputs_folder: Optional[Path] = None) -> None:
"""
Sets up the list of models that forms the ensemble. These models should inherit from both InnerEyeInference and
LightningModule, but since mypy does not support specifying intersection types, we specify just one,
Reviewer comment (Contributor):

I don't get this?

InnerEyeInference, here, and assert the other when needed.
:param outputs_folder: The root directory where we store metric files and reports. This is marked optional so
that subclasses may be instantiated without constructor arguments.
"""
self.outputs_folder = outputs_folder
self.ensemble_models: List[LightningModule] = []

def load_checkpoints_into_ensemble( # type: ignore
self,
exemplar: LightningModule,
checkpoint_paths: List[Path],
use_gpu: bool) -> None:
"""
Convenience method that loads each checkpoint from a list of checkpoint paths and adds the resulting model as an
additional member of the ensemble.
:param exemplar: Each model in the ensemble will be based on a copy of this exemplar.
:param checkpoint_paths: A list of paths to checkpoints for loading into new models in the ensemble.
:param use_gpu: Passed on eventually to deep_learning_config.load_checkpoint.
"""
for checkpoint_path in checkpoint_paths:
self._load_checkpoint_into_ensemble(exemplar, checkpoint_path, use_gpu)

def _load_checkpoint_into_ensemble( # type: ignore
self,
exemplar: LightningModule,
checkpoint_path: Path,
use_gpu: bool) -> None:
"""
Load a single checkpoint path as an additional member of the ensemble.
:param exemplar: Each model in the ensemble will be based on a copy of this exemplar.
:param checkpoint_path: The path to the to checkpoint file to load into a new model in the ensemble.
:param use_gpu: Passed on to deep_learning_config.load_checkpoint.
"""
checkpoint = load_checkpoint(checkpoint_path, use_gpu)
# new_model = deepcopy(exemplar) # In AML, this line triggers "ReferenceError: weakly-referenced object no
# longer exists" so we will make a new instance instead
logging.info(f"Adding a {type(exemplar)} to the ensemble and loading its state from a checkpoint.")
new_model = type(exemplar)()
assert isinstance(new_model, LightningModule) # mypy
new_model.load_state_dict(checkpoint['state_dict'], strict=False)
self.ensemble_models.append(new_model)

def on_ensemble_inference_start(self) -> None:
"""
Set each model in the ensemble into evaluation mode.
"""
for model in self.ensemble_models:
assert isinstance(model, LightningModule) # mypy
model.eval()
Reviewer comment (Contributor):

calling model.eval() is critical. If users forget to call the superclass method, they will run their model in a wrong way. Can we bake that into local_checkpoint?


def on_ensemble_inference_start_dataset(self, execution_mode: ModelExecutionMode) -> None:
"""
Runs initialization for inference, when starting inference on a new dataset split (train/val/test).
:param execution_mode: Passed on to each InnerEyeInference model in the ensemble.
"""
pass

def ensemble_forward(self, batch_x: torch.Tensor) -> torch.Tensor:
"""
Aggregate the posteriors from each model in the ensemble using the static method
`aggregate_ensemble_model_outputs`.
:param batch_x: The batch of input data to run through each model in the ensemble.
:returns: The aggregated ensemble outputs.
"""
model_outputs: List[torch.Tensor] = []
for model in self.ensemble_models:
assert isinstance(model, LightningModule) # mypy
model_outputs.append(model.forward(batch_x))
posterior = InnerEyeEnsembleInference.aggregate_ensemble_model_outputs(iter(model_outputs))
return posterior

def record_ensemble_posterior(self, batch_y: torch.Tensor, batch_idx: int, posterior: torch.Tensor) -> None:
"""
This hook is called when the model has finished making a prediction, and can write the results to a file, or
compute metrics and store them. In this base class we cannot know which metrics are needed so you will need to
override this method. Here are two ideas:

self.test_mse.append(torch.nn.functional.mse_loss(posterior, batch_y))
self.test_mae.update(preds=posterior, target=batch_y)

(where `self.test_mse: List[torch.Tensor]` and `self.test_mae: MeanAbsoluteError` would be initialised in the
constructor).

:param batch_y: The labels for the batch on which the model made its prediction.
:param batch_idx: The index of the batch.
:param posterior: The posterior, typically from ensemble_forward.
"""
pass

def on_ensemble_inference_end_dataset(self) -> None:
"""
Run after the inference run over a single test/val/train dataset.
"""
pass

def on_ensemble_inference_end(self) -> None:
"""
Run after all the test/val/train datasets' inference runs. Override to save metrics to disk and generate reports.
"""
pass

@staticmethod
def aggregate_ensemble_model_outputs(model_outputs: Iterator[torch.Tensor]) -> torch.Tensor:
"""
Aggregates the outputs of multiple models when using an ensemble model. In the default implementation,
this averages the tensors coming from all the models.
Aggregates the outputs of multiple models when using an ensemble model. In the default implementation, this
example averages the tensors coming from all of the models in the ensemble. Override this static method if you
need a different aggregation from the mean we use here.
:param model_outputs: An iterator over the model outputs for all ensemble members.
:return: The aggregate model outputs.
"""
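The method body is collapsed in this view; a minimal sketch of the default mean aggregation the docstring describes (an assumption based on the docstring, not the verbatim body):

```python
from typing import Iterator

import torch

def aggregate_by_mean(model_outputs: Iterator[torch.Tensor]) -> torch.Tensor:
    # Stack the per-model posteriors along a new leading dimension, then average it away.
    return torch.stack(list(model_outputs), dim=0).mean(dim=0)
```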