This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Fixing bug in case of NaN labels (#185)
melanibe committed Sep 1, 2020
1 parent 58bbf3d commit 90eac82
Showing 2 changed files with 9 additions and 8 deletions.
6 changes: 4 additions & 2 deletions InnerEye/ML/pipelines/scalar_inference.py
@@ -177,9 +177,11 @@ def predict(self, sample: Dict[str, Any]) -> ScalarInferencePipelineBase.Result:
         if len(set(map(tuple, [result.subject_ids for result in results]))) > 1:  # type: ignore
             raise ValueError("Trying to aggregate results for different subject ids.")
         subject_ids = results[0].subject_ids
-        # check that we have the same subject ids
+        # check that we have the same labels
         for result in results:
-            if not torch.equal(results[0].labels, result.labels):
+            # Using allclose() instead of equal() because we can have NaN in the labels (in which case
+            # equal() would return False).
+            if not torch.allclose(results[0].labels, result.labels, atol=0, rtol=0, equal_nan=True):
                 raise ValueError("Trying to aggregate results but ground truth does not match across samples.")
         labels = results[0].labels

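Note on the check above (a minimal standalone sketch, not code from this repository): torch.equal() follows IEEE semantics where NaN != NaN, so two identical label tensors containing NaN compare as unequal; torch.allclose() with zero tolerances stays an exact comparison for real values, while equal_nan=True accepts NaNs that sit in the same positions.

import torch

labels_a = torch.tensor([[0.0], [float("nan")], [1.0]])
labels_b = labels_a.clone()

# NaN != NaN, so equal() reports a mismatch even for identical tensors.
assert not torch.equal(labels_a, labels_b)

# Zero tolerances keep the comparison exact for the non-NaN entries,
# while equal_nan=True treats NaNs in matching positions as equal.
assert torch.allclose(labels_a, labels_b, atol=0, rtol=0, equal_nan=True)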
11 changes: 5 additions & 6 deletions Tests/ML/pipelines/test_scalar_inference.py
@@ -64,7 +64,6 @@ def test_create_from_checkpoint_ensemble() -> None:


def test_create_result_dataclass() -> None:
-
     # invalid instances: these try to instantiate with inconsistent length lists/tensors
     with pytest.raises(ValueError):
         # one sample, but labels has length 2
@@ -103,16 +102,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore


@pytest.mark.parametrize('batch_size', [1, 3])
-def test_predict_non_ensemble(batch_size: int) -> None:
-
+@pytest.mark.parametrize("empty_labels", [True, False])
+def test_predict_non_ensemble(batch_size: int, empty_labels: bool) -> None:
     config = ClassificationModelForTesting()
     model: Any = ScalarOnesModel(config.expected_image_size_zyx, 1.)
     update_model_for_mixed_precision_and_parallel(ModelAndInfo(model),
                                                   args=config,
                                                   execution_mode=ModelExecutionMode.TEST)
     pipeline = ScalarInferencePipeline(model, config, 0, 0)
+    actual_labels = torch.zeros((batch_size, 1)) * np.nan if empty_labels else torch.zeros((batch_size, 1))
     data = {"metadata": [GeneralSampleMetadata(id='2')] * batch_size,
-            "label": torch.zeros((batch_size, 1)),
+            "label": actual_labels,
             "images": torch.zeros(((batch_size, 1) + config.expected_image_size_zyx)),
             "numerical_non_image_features": torch.tensor([]),
             "categorical_non_image_features": torch.tensor([]),
@@ -121,14 +121,13 @@ def test_predict_non_ensemble(batch_size: int) -> None:
     results = pipeline.predict(data)
     ids, labels, predicted = results.subject_ids, results.labels, results.model_outputs
     assert ids == ['2'] * batch_size
-    assert torch.equal(labels, torch.zeros((batch_size, 1)))
+    assert torch.allclose(labels, actual_labels, equal_nan=True)
     # The model always returns 1, so predicted should be sigmoid(1)
     assert torch.allclose(predicted, torch.full((batch_size, 1), 0.731058578))


@pytest.mark.parametrize('batch_size', [1, 3])
def test_predict_ensemble(batch_size: int) -> None:
-
     config = ClassificationModelForTesting()
     model_returns_0: Any = ScalarOnesModel(config.expected_image_size_zyx, 0.)
     model_returns_1: Any = ScalarOnesModel(config.expected_image_size_zyx, 1.)

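For reference, a minimal standalone sketch (plain torch/numpy only, no InnerEye imports; returned_labels is a hypothetical stand-in for the labels the pipeline echoes back) of the pattern the updated test relies on: multiplying a zero tensor by np.nan yields an all-NaN label tensor, and an allclose() assertion with equal_nan=True passes for both the NaN and the all-zero case, whereas torch.equal() would fail whenever NaN is present.

import numpy as np
import torch

batch_size = 3
for empty_labels in (True, False):
    actual_labels = torch.zeros((batch_size, 1)) * np.nan if empty_labels else torch.zeros((batch_size, 1))
    # Hypothetical stand-in for results.labels, assumed to be the input labels passed through unchanged.
    returned_labels = actual_labels.clone()
    assert torch.allclose(returned_labels, actual_labels, equal_nan=True)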