Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Register all models after training, not only Segmentation models. (#455)
Browse files Browse the repository at this point in the history
This PR changes the codepath so all models trained on AzureML are registered. The codepath previously allowed only segmentation models (subclasses of `SegmentationModelBase`) to be registered. Models are registered after a training run or if the `only_register_model` flag is set. Models may be legacy InnerEye config-based models or may be defined using the LightningContainer class.

The PR also removes the AzureRunner conda environment. As a result, the full InnerEye conda environment is now needed to submit a training job to AzureML.

It splits the `TrainHelloWorldAndHelloContainer` job in the PR build into two jobs, `TrainHelloWorld` and `TrainHelloContainer`. It adds a pytest marker `after_training_hello_container` for tests that can be run after training is finished in the `TrainHelloContainer` job.

This will solve the issue of model registration in #377 and #398.
  • Loading branch information
Shruthi42 committed May 12, 2021
1 parent 7b5b414 commit aa09b9d
Show file tree
Hide file tree
Showing 20 changed files with 260 additions and 283 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
Detection Challenge datasets. See
[SSL doc](https://github.com/microsoft/InnerEye-DeepLearning/blob/main/docs/self_supervised_models.md) for more
details.
- ([#455](https://github.com/microsoft/InnerEye-DeepLearning/pull/455)) All models trained on AzureML are registered.
The codepath previously allowed only segmentation models (subclasses of `SegmentationModelBase`) to be registered.
Models are registered after a training run or if the `only_register_model` flag is set. Models may be legacy InnerEye
config-based models or may be defined using the LightningContainer class.
Additionally, the `TrainHelloWorldAndHelloContainer` job in the PR build has been split into two jobs, `TrainHelloWorld` and
`TrainHelloContainer`. A pytest marker `after_training_hello_container` has been added to run tests after training is
finished in the `TrainHelloContainer` job.

### Changed

Expand Down Expand Up @@ -105,6 +112,8 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
### Removed
- ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Deprecated `start_epoch` config argument.
- ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Delete unused `classification_report.ipynb`.
- ([#455](https://github.com/microsoft/InnerEye-DeepLearning/pull/455)) Removed the AzureRunner conda environment.
The full InnerEye conda environment is needed to submit a training job to AzureML.

### Deprecated

Expand Down
2 changes: 0 additions & 2 deletions InnerEye/Common/fixed_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
SETTINGS_YAML_FILE = INNEREYE_PACKAGE_ROOT / SETTINGS_YAML_FILE_NAME

MODEL_INFERENCE_JSON_FILE_NAME = 'model_inference_config.json'
AZURE_RUNNER_ENVIRONMENT_YAML_FILE_NAME = "azure_runner.yml"
AZURE_RUNNER_ENVIRONMENT_YAML = repository_root_directory(AZURE_RUNNER_ENVIRONMENT_YAML_FILE_NAME)

# The names of files at the repository root that are required for running the inference pipeline.
SCORE_SCRIPT = "score.py"
Expand Down
10 changes: 5 additions & 5 deletions InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from typing import Any, List, Optional

import torch
from pytorch_lightning.metrics import Metric
from pl_bolts.models.self_supervised import SSLEvaluator
from torch.nn import functional as F

from InnerEye.ML.SSL.encoders import get_encoder_output_dim
from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.lightning_container import LightningModuleWithOptimizer
from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve, \
ScalarMetricsBase
from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule


Expand All @@ -37,9 +37,9 @@ def __init__(self,
n_classes=num_classes,
p=0.20)
if self.num_classes == 2:
self.train_metrics: List[ScalarMetricsBase] = \
self.train_metrics: List[Metric] = \
[AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(), Accuracy05()]
self.val_metrics: List[ScalarMetricsBase] = \
self.val_metrics: List[Metric] = \
[AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(), Accuracy05()]
else:
# Note that for multi-class, Accuracy05 is the standard multi-class accuracy.
Expand All @@ -48,7 +48,7 @@ def __init__(self,

def on_train_start(self) -> None:
for metric in [*self.train_metrics, *self.val_metrics]:
metric.to(device=self.device)
metric.to(device=self.device) # type: ignore

def train(self, mode: bool = True) -> Any:
self.classifier_head.train(mode)
Expand Down
8 changes: 4 additions & 4 deletions InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics import Metric
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from torch import Tensor as T
from torch.nn import functional as F

from InnerEye.ML.SSL.utils import SSLDataModuleType
from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve, \
ScalarMetricsBase
from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve

BatchType = Union[Dict[SSLDataModuleType, Any], Any]

Expand All @@ -36,10 +36,10 @@ def __init__(self,
self.weight_decay = 1e-4
self.learning_rate = learning_rate

self.train_metrics: List[ScalarMetricsBase] = [AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(),
self.train_metrics: List[Metric] = [AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(),
Accuracy05()] \
if self.num_classes == 2 else [Accuracy05()]
self.val_metrics: List[ScalarMetricsBase] = [AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(),
self.val_metrics: List[Metric] = [AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(),
Accuracy05()] \
if self.num_classes == 2 else [Accuracy05()]
self.class_weights = class_weights
Expand Down
11 changes: 3 additions & 8 deletions InnerEye/ML/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
from enum import Enum, unique
from math import isclose
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import param
from azureml.core import Model, ScriptRunConfig
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import HyperDriveConfig
from pandas import DataFrame

from InnerEye.Azure.azure_config import AzureConfig
from InnerEye.Common.common_util import ModelProcessing, any_pairwise_larger, any_smaller_or_equal_than, check_is_any_of
from InnerEye.Common.common_util import any_pairwise_larger, any_smaller_or_equal_than, check_is_any_of
from InnerEye.Common.generic_parsing import IntTuple
from InnerEye.Common.type_annotations import TupleFloat2, TupleFloat3, TupleInt3, TupleStringOptionalFloat
from InnerEye.ML.common import ModelExecutionMode
Expand Down Expand Up @@ -796,7 +795,3 @@ def get_cropped_image_sample_transforms(self) -> ModelTransformsPerExecutionMode
By default no transformation is performed.
"""
return ModelTransformsPerExecutionMode()


PostCrossValidationHookSignature = Callable[[ModelConfigBase, Path], None]
ModelDeploymentHookSignature = Callable[[SegmentationModelBase, AzureConfig, Model, ModelProcessing], Any]
10 changes: 10 additions & 0 deletions InnerEye/ML/model_inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import List
from pathlib import Path

from dataclasses_json import dataclass_json

Expand All @@ -24,3 +25,12 @@ def __post_init__(self) -> None:
long_paths = list(filter(is_long_path, self.checkpoint_paths))
if long_paths:
raise ValueError(f"Following paths: {long_paths} are greater than {MAX_PATH_LENGTH}")


def read_model_inference_config(path_to_model_inference_config: Path) -> ModelInferenceConfig:
    """
    Load a ModelInferenceConfig object from a json file.

    :param path_to_model_inference_config: Path to the json file that holds the model inference configuration.
    :return: The ModelInferenceConfig object created from the contents of the file.
    """
    # The file is decoded explicitly as UTF-8, independent of the platform default encoding.
    json_text = path_to_model_inference_config.read_text(encoding='utf-8')
    return ModelInferenceConfig.from_json(json_text)  # type: ignore
Loading

0 comments on commit aa09b9d

Please sign in to comment.