This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Throw error if invalid argument (#215)
Add a check to generic parsing that throws an error if we feed an invalid argument into the config, and correct the invalid configs in existing tests.
melanibe committed Sep 14, 2020
1 parent 13fa984 commit 71cf63e
Showing 10 changed files with 114 additions and 102 deletions.
19 changes: 13 additions & 6 deletions InnerEye/Common/generic_parsing.py
@@ -95,23 +95,30 @@ class GenericConfig(param.Parameterized):
     Base class for all configuration classes provides helper functionality to create argparser.
     """

-    def __init__(self, should_validate: bool = True, **params: Any):
+    def __init__(self, should_validate: bool = True, throw_if_unknown_param: bool = False, **params: Any):
         """
         Instantiates the config class, ignoring parameters that are not overridable.
         :param should_validate: If True, the validate() method is called directly after init.
+        :param throw_if_unknown_param: If True, raise an error if the provided "params" contains any key that does not
+        correspond to an attribute of the class.
         :param params: Parameters to set.
         """
         # check if illegal arguments are passed in
         legal_params = self.get_overridable_parameters()
         illegal = [k for k, v in params.items() if (k in self.params().keys()) and (k not in legal_params)]

         if illegal:
             raise ValueError(f"The following parameters cannot be overriden as they are either "
                              f"readonly, constant, or private members : {illegal}")
-        else:
-            # set known arguments
-            super().__init__(**{k: v for k, v in params.items() if k in legal_params.keys()})
-            if should_validate:
-                self.validate()
+        if throw_if_unknown_param:
+            # check if parameters not defined by the config class are passed in
+            unknown = [k for k, v in params.items() if (k not in self.params().keys())]
+            if unknown:
+                raise ValueError(f"The following parameters do not exist: {unknown}")
+        # set known arguments
+        super().__init__(**{k: v for k, v in params.items() if k in legal_params.keys()})
+        if should_validate:
+            self.validate()

     def validate(self) -> None:
         """
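
The interplay of the two checks can be seen in a minimal sketch. The subclass below is hypothetical; it only assumes, as the code above shows, that GenericConfig builds on param.Parameterized:

    import param

    from InnerEye.Common.generic_parsing import GenericConfig

    class SketchConfig(GenericConfig):
        # One overridable field, standing in for real config parameters.
        num_epochs = param.Integer(default=10)

    SketchConfig(num_epochs=5)   # known parameter: set as usual
    SketchConfig(num_epochsi=5)  # unknown parameter: silently ignored by default
    SketchConfig(num_epochsi=5, throw_if_unknown_param=True)  # raises ValueError

With the default throw_if_unknown_param=False, a misspelled argument is dropped without any signal, which is how the invalid arguments fixed in the test configs below crept in unnoticed.
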
4 changes: 2 additions & 2 deletions InnerEye/ML/configs/classification/DummyClassification.py
@@ -29,9 +29,9 @@ def __init__(self) -> None:
             num_dataload_workers=0,
             test_start_epoch=num_epochs,
             use_mixed_precision=True,
-            subject_column="subjectID",
-            conv_in_3d=True
+            subject_column="subjectID"
         )
+        self.conv_in_3d = True
         self.expected_image_size_zyx = (4, 5, 7)

     def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
2 changes: 1 addition & 1 deletion InnerEye/ML/deep_learning_config.py
@@ -361,7 +361,7 @@ def __init__(self, **params: Any) -> None:
         # This should be annotated as torch.utils.data.Dataset, but we don't want to import torch here.
         self._datasets_for_training: Optional[Dict[ModelExecutionMode, Any]] = None
         self._datasets_for_inference: Optional[Dict[ModelExecutionMode, Any]] = None
-        super().__init__(**params)
+        super().__init__(throw_if_unknown_param=True, **params)
         logging.info("Creating the default output folder structure.")
         self.create_filesystem(fixed_paths.repository_root_directory())

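
The constructor changed here (presumably DeepLearningConfig's __init__) now opts into the strict behavior, so every model config derived from it fails fast on a misspelled argument. A short sketch of the effect, mirroring the new test_config_with_typo test added below:

    from InnerEye.ML.model_config_base import ModelConfigBase

    try:
        ModelConfigBase(num_epochsi=100)  # typo for num_epochs
    except ValueError as ex:
        print(ex.args[0])  # The following parameters do not exist: ['num_epochsi']
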
179 changes: 92 additions & 87 deletions Tests/Common/test_build_config.py
@@ -1,87 +1,92 @@
 # ------------------------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 from pathlib import Path

 import pytest

 from InnerEye.Azure.azure_config import AzureConfig
 from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
 from InnerEye.Common.build_config import BUILDINFORMATION_JSON, ExperimentResultLocation, \
     build_information_to_dot_net_json, build_information_to_dot_net_json_file
 from InnerEye.Common.output_directories import TestOutputDirectories
 from InnerEye.ML.config import SegmentationModelBase
 from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.scalar_config import ScalarModelBase


 def test_build_config(test_output_dirs: TestOutputDirectories) -> None:
     """
     Test that json with build information is created correctly.
     """
     config = AzureConfig(
         build_number=42,
         build_user="user",
         build_branch="branch",
         build_source_id="00deadbeef",
         build_source_author="author",
         tag="tag",
         model="model")
     result_location = ExperimentResultLocation(azure_job_name="job")
     net_json = build_information_to_dot_net_json(config, result_location)
     expected = '{"BuildNumber": 42, "BuildRequestedFor": "user", "BuildSourceBranchName": "branch", ' \
                '"BuildSourceVersion": "00deadbeef", "BuildSourceAuthor": "author", "ModelName": "model", ' \
                '"ResultsContainerName": null, "ResultsUri": null, "DatasetFolder": null, "DatasetFolderUri": null, ' \
                '"AzureBatchJobName": "job"}'
     assert expected == net_json
     result_folder = Path(test_output_dirs.root_dir) / "buildinfo"
     build_information_to_dot_net_json_file(config, result_location, folder=result_folder)
     result_file = result_folder / BUILDINFORMATION_JSON
     assert result_file.exists()
     assert result_file.read_text() == expected


 def test_fields_are_set() -> None:
     """
     Tests that expected fields are set when creating config classes.
     """
     expected = [("hello", None), ("world", None)]
     config = SegmentationModelBase(
         should_validate=False,
         ground_truth_ids=[x[0] for x in expected],
         largest_connected_component_foreground_classes=expected
     )
     assert hasattr(config, CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY)
     assert config.largest_connected_component_foreground_classes == expected


 def test_config_non_overridable_params() -> None:
     """
     Check error raised if attempt to override non overridable configs
     """
     non_overridable_params = {k: v.default for k, v in ModelConfigBase.params().items()
                               if k not in ModelConfigBase.get_overridable_parameters()}
     with pytest.raises(ValueError) as ex:
         ModelConfigBase(
             should_validate=False,
             **non_overridable_params
         )
     assert "The following parameters cannot be overriden" in ex.value.args[0]


+@pytest.mark.gpu
+def test_config_with_typo() -> None:
+    with pytest.raises(ValueError) as ex:
+        ModelConfigBase(num_epochsi=100)
+    assert "The following parameters do not exist: ['num_epochsi']" in ex.value.args[0]
+
+
 @pytest.mark.gpu
 def test_dataset_reader_workers() -> None:
     """
     Test to make sure the number of dataset reader workers are set correctly
     """
     config = ScalarModelBase(
         should_validate=False,
         num_dataset_reader_workers=-1
     )
     if config.is_offline_run:
         assert config.num_dataset_reader_workers == -1
     else:
         assert config.num_dataset_reader_workers == 0
3 changes: 2 additions & 1 deletion Tests/ML/configs/ClassificationModelForTesting.py
@@ -16,6 +16,7 @@
 class ClassificationModelForTesting(ScalarModelBase):
     def __init__(self, conv_in_3d: bool = True, mean_teacher_model: bool = False) -> None:
         num_epochs = 4
+        mean_teacher_alpha = 0.99 if mean_teacher_model else None
         super().__init__(
             local_dataset=full_ml_test_data_path("classification_data"),
             image_channels=["image"],
@@ -29,7 +30,7 @@ def __init__(self, conv_in_3d: bool = True, mean_teacher_model: bool = False) -> None:
             num_dataload_workers=0,
             test_start_epoch=num_epochs,
             subject_column="subjectID",
-            compute_mean_teacher_model=mean_teacher_model
+            mean_teacher_alpha=mean_teacher_alpha
         )
         self.expected_image_size_zyx = (4, 5, 7)
         self.conv_in_3d = conv_in_3d
3 changes: 2 additions & 1 deletion Tests/ML/configs/ClassificationModelForTesting2D.py
@@ -16,6 +16,7 @@
 class ClassificationModelForTesting2D(ScalarModelBase):
     def __init__(self, conv_in_3d: bool = True, mean_teacher_model: bool = False) -> None:
         num_epochs = 4
+        mean_teacher_alpha = 0.99 if mean_teacher_model else None
         super().__init__(
             local_dataset=full_ml_test_data_path("classification_data_2d"),
             image_channels=["image"],
@@ -29,7 +30,7 @@ def __init__(self, conv_in_3d: bool = True, mean_teacher_model: bool = False) -> None:
             num_dataload_workers=0,
             test_start_epoch=num_epochs,
             subject_column="subjectID",
-            compute_mean_teacher_model=mean_teacher_model
+            mean_teacher_alpha=mean_teacher_alpha
         )
         self.expected_image_size_zyx = (5, 7)
         self.conv_in_3d = conv_in_3d
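
In both test configs the fix replaces compute_mean_teacher_model, which is presumably not an actual parameter of ScalarModelBase and would now trip the new check, with mean_teacher_alpha: a non-None smoothing coefficient (0.99 here) enables the mean teacher model, while None leaves it disabled.
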
1 change: 0 additions & 1 deletion Tests/ML/datasets/test_sequence_dataset.py
@@ -549,7 +549,6 @@ def test_sequence_dataset_all(test_output_dirs: TestOutputDirectories) -> None:
         categorical_columns=["META", "BETA"],
         sequence_column="seq",
         num_dataload_workers=0,
-        num_datsource=0,
         train_batch_size=2,
         should_validate=False,
         shuffle=False
2 changes: 1 addition & 1 deletion Tests/ML/models/test_model_training.py
@@ -91,7 +91,7 @@ def test_get_total_number_of_validation_epochs() -> None:
                           temperature_scaling_config=TemperatureScalingConfig())
     assert c.get_total_number_of_validation_epochs() == 3
     c = SequenceModelBase(num_epochs=2, sequence_target_positions=[1], temperature_scaling_config=None,
-                          save_start_epoch=1, save_step_epoch=1, should_validate=False)
+                          save_start_epoch=1, save_step_epochs=1, should_validate=False)
     assert c.get_total_number_of_validation_epochs() == 2
     c = SequenceModelBase(num_epochs=2, sequence_target_positions=[1],
                           save_start_epoch=1, save_step_epochs=1, should_validate=False,
1 change: 0 additions & 1 deletion Tests/ML/pipelines/test_inference.py
@@ -53,7 +53,6 @@ def test_inference_identity(image_size: Any,
         crop_size=crop_size,
         image_channels=list(map(str, range(num_channels))),
         ground_truth_ids=ground_truth_ids,
-        crop_size_multiple=1,
         should_validate=False,
         posterior_smoothing_mm=posterior_smoothing_mm
     )
2 changes: 1 addition & 1 deletion TestsOutsidePackage/test_register_model.py
@@ -182,7 +182,7 @@ def test_get_child_paths(is_ensemble: bool, extra_code_directory: str) -> None:
     checkpoints = checkpoint_paths * 2 if is_ensemble else checkpoint_paths
     path_to_root = tests_root_directory().parent
     azure_config = AzureConfig(extra_code_directory=extra_code_directory)
-    fake_model = ModelConfigBase(model_name="FakeModelToTestRegistration", azure_dataset_id="fake_dataset_id")
+    fake_model = ModelConfigBase(azure_dataset_id="fake_dataset_id")
     ml_runner = MLRunner(model_config=fake_model, azure_config=azure_config, project_root=path_to_root)
     child_paths = ml_runner.get_child_paths(checkpoints)
     assert fixed_paths.ENVIRONMENT_YAML_FILE_NAME in child_paths
