Enable building an ensemble model from the cross validation checkpoints of a BYO Lightning model #529
Conversation
…dels. But only the check, not the inference loop yet.
Trying to work out how checkpoints are passed between training and inference (or where they should be) by tracing through the new model in run_ml.
Into two for unit testing, with ensemble from xval checkpoints.
I do not know why the instantiated model_config cannot live in deep_learning_config
Need to fix documentation and check it works for real on AML
This is a good first step, but I feel the interface is not yet as easy as it should be. Anybody who wants to use ensemble models now has to define two models - and I think that's simply not going to fly. Can you think of a simpler solution?
As for testing - you have added a lot of complicated switching logic, that I was not fully able to digest. Is there a way to test that too?
@@ -250,7 +251,8 @@ def __init__(self) -> None:
# This method must be overridden by any subclass of LightningContainer. It returns the model that you wish to
# train, as a LightningModule
def create_model(self) -> LightningModule:
    return HelloRegression()
    self._model = HelloRegression()  # TODO: why does LightningContainer need all three, _model, model, and create_model?
Not sure I get the question? This method here should only return a freshly generated model, the _model properties are populated elsewhere.
ensemble_model_name: str = param.String(
    doc=("The class name of the ensemble model to build from the cross validation checkpoints of a Lightning model."))
I feel that this extra requirement is a key bottleneck in your design. Anybody who wants to implement an ensemble model now needs to define the model itself (ensemble member), and then another class that is the ensemble logic. For the other models, we have been able to achieve that without extra overhead - for example, we can define a Prostate model, and get Prostate ensemble models for free.
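One way to drop the two-class requirement could be a generic wrapper that is built automatically from the cross validation checkpoints, so users only ever define the member model. The sketch below is a hypothetical illustration, not the repository's API: the `GenericEnsemble` name and its mean-averaging strategy are assumptions.

```python
import torch
from torch import nn


class GenericEnsemble(nn.Module):
    """Hypothetical sketch: wraps trained copies of a single member model and
    averages their outputs, so no separate ensemble class is needed."""

    def __init__(self, members):
        super().__init__()
        self.members = nn.ModuleList(members)
        for member in self.members:
            member.eval()  # ensemble members are inference-only

    @torch.no_grad()
    def forward(self, *args, **kwargs):
        # Run every member on the same inputs and average the predictions.
        outputs = [member(*args, **kwargs) for member in self.members]
        return torch.stack(outputs).mean(dim=0)
```

With a wrapper like this, a "Prostate ensemble for free" style could be kept: the framework instantiates one member per cross validation fold and hands back the wrapper.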
def __init__(self, outputs_folder: Optional[Path] = None) -> None:
    """
    Sets up the list of models that forms the ensemble. These models should inherit from both InnerEyeInference and
    LightningModule, but since mypy does not have support for specifying intersection types we specify just one,
I don't get this?
"""
for model in self.ensemble_models:
    assert isinstance(model, LightningModule)  # mypy
    model.eval()
Calling model.eval() is critical. If users forget to call the superclass method, they will run their model in a wrong way. Can we bake that into local_checkpoint?
self.model_config_loader = model_config_loader
self.ensemble_model: Optional[InnerEyeEnsembleInference] = None
If you find a way of not requiring a separate class for ensemble models, this would become obsolete, and hence simplify the code.
# 0, then wait for the sibling runs, build the ensemble model, and write a report for that.
if not self.is_offline_run and PARENT_RUN_CONTEXT is not None:
    sibling_runs_checkpoint_handler = self.wait_and_collect_sibling_runs_if_required()
    logging.info("DEBUGGING: about to create_ensemble_model_and_run_inference_from_lightningmodule_checkpoints")
use logging.debug instead?
:param lightning_model: The LightningContainer to be used.
:param checkpoint_paths: The path to the checkpoint that should be used for inference.
"""
lightning_model = lightning_container.create_model()
I don't think you should create the model again here - it should already be accessible via lightning_container.model?
# Register the model, and then run inference as required. No models should be registered when running outside
# AzureML.
if not self.is_offline_run:
    if self.should_register_model():
        self.register_model(checkpoint_paths, ModelProcessing.ENSEMBLE_CREATION)
This fragment also exists somewhere else?
Replaced by #549
Closing #527
The method MLRunner.run_inference_for_lightning_models takes a list of checkpoint paths as an argument, but then asserts that only one of them is used.
We want to change this so that the checkpoints gleaned from a BYO Lightning cross validation run can be used as an ensemble model.
AB#4219