This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix DeepMIL metrics bug (#674)
* Fix DeepMIL metrics input bug

* Add first version of metrics tests

* Update submodule

* Add test for DeepMIL metrics inputs

* Clean-up and update submodule

* Update changelog

* Upgrade mlflow due to Component Governance warning
dccastro committed Mar 1, 2022
1 parent d7e5d8b commit e984554
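
Note (illustration, not part of this commit): the bug being fixed is that the metrics were updated with hard 0/1 predictions instead of the sigmoid probabilities. A minimal standalone sketch of why that matters, assuming an older torchmetrics API in which binary metrics are constructed without a `task` argument (newer releases need e.g. `AUROC(task="binary")`):

```python
# Illustration only: why updating the metrics with hard labels instead of probabilities
# is a bug. Assumes an older torchmetrics API in which binary metrics are constructed
# without a `task` argument (newer releases need e.g. AUROC(task="binary")).
import torch
from torchmetrics import AUROC, Accuracy

torch.manual_seed(0)
probs = torch.rand(100)                      # continuous scores in [0, 1]
labels = (torch.rand(100) < probs).long()    # binary ground truth correlated with the scores

auroc = AUROC()
print("AUROC from probabilities:", auroc(probs, labels))       # uses the full ranking of scores

auroc.reset()
hard_preds = torch.round(probs)              # what the pre-fix code effectively passed
print("AUROC from hard labels:  ", auroc(hard_preds, labels))  # collapses to a single operating point

# Thresholded metrics also become meaningless: varying `threshold` cannot change
# anything once the inputs are already 0/1.
accuracy = Accuracy(threshold=0.9)
print("Accuracy at threshold 0.9 on hard labels:", accuracy(hard_preds, labels))
```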
Showing 5 changed files with 122 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -128,6 +128,7 @@ in inference-only runs when using lightning containers.
 - ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training
 - ([#652](https://github.com/microsoft/InnerEye-DeepLearning/pull/652)) Run pytest build on Windows after Linux agent version upgrade
 - ([#655](https://github.com/microsoft/InnerEye-DeepLearning/pull/655)) Run pytest on Linux again, but with Ubuntu 20.04
+- ([#674](https://github.com/microsoft/InnerEye-DeepLearning/pull/674)) Fix DeepMIL metrics bug whereby hard labels were used instead of probabilities.
 
 ### Removed
 
25 changes: 13 additions & 12 deletions InnerEye/ML/Histopathology/models/deepmil.py
@@ -179,12 +179,13 @@ def get_metrics(self) -> nn.ModuleDict:
                                   MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
                                   MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
         else:
-            return nn.ModuleDict({MetricsKey.ACC: Accuracy(),
+            threshold = 0.5
+            return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
                                   MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
-                                  MetricsKey.PRECISION: Precision(),
-                                  MetricsKey.RECALL: Recall(),
-                                  MetricsKey.F1: F1(),
-                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes+1)})
+                                  MetricsKey.PRECISION: Precision(threshold=threshold),
+                                  MetricsKey.RECALL: Recall(threshold=threshold),
+                                  MetricsKey.F1: F1(threshold=threshold),
+                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})
 
     def log_metrics(self,
                     stage: str) -> None:
@@ -238,24 +239,24 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK
         else:
             loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())
 
-        probs = self.activation_fn(bag_logits)
+        predicted_probs = self.activation_fn(bag_logits)
         if self.n_classes > 1:
-            preds = argmax(probs, dim=1)
+            predicted_labels = argmax(predicted_probs, dim=1)
         else:
-            preds = round(probs)
+            predicted_labels = round(predicted_probs)
 
         loss = loss.view(-1, 1)
-        preds = preds.view(-1, 1)
-        probs = probs.view(-1, 1)
+        predicted_labels = predicted_labels.view(-1, 1)
+        predicted_probs = predicted_probs.view(-1, 1)
         bag_labels = bag_labels.view(-1, 1)
 
         results = dict()
         for metric_object in self.get_metrics_dict(stage).values():
-            metric_object.update(preds, bag_labels)
+            metric_object.update(predicted_probs, bag_labels)
         results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN],
                         ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN],
                         ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss,
-                        ResultsKey.PROB: probs, ResultsKey.PRED_LABEL: preds,
+                        ResultsKey.PROB: predicted_probs, ResultsKey.PRED_LABEL: predicted_labels,
                         ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
                         ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})
 
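Note (illustration, not part of this commit): with the fix, `_shared_step` feeds the same `predicted_probs` tensor to every metric. The thresholded metrics (Accuracy, Precision, Recall, F1, ConfusionMatrix) binarize the scores internally at the shared `threshold=0.5`, while AUROC consumes the raw scores. A minimal sketch of that pattern, again assuming the older torchmetrics API in which `F1` exists and binary metrics accept a `threshold` argument:

```python
# Minimal sketch (not repo code) of the pattern adopted above: every metric in the
# ModuleDict receives the same probabilistic scores, and the thresholded metrics
# binarize them internally via their shared `threshold`.
import torch
from torch import nn
from torchmetrics import AUROC, Accuracy, ConfusionMatrix, F1, Precision, Recall

threshold = 0.5
metrics = nn.ModuleDict({
    "acc": Accuracy(threshold=threshold),
    "auroc": AUROC(),
    "precision": Precision(threshold=threshold),
    "recall": Recall(threshold=threshold),
    "f1": F1(threshold=threshold),
    "conf_matrix": ConfusionMatrix(num_classes=2, threshold=threshold),  # binary => 2x2 matrix
})

predicted_probs = torch.tensor([0.1, 0.4, 0.6, 0.9])  # bag-level probabilities
bag_labels = torch.tensor([0, 0, 1, 1])

for metric in metrics.values():
    metric.update(predicted_probs, bag_labels)  # one update call for all metrics, as in the fix

print({name: metric.compute() for name, metric in metrics.items()})
```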
109 changes: 106 additions & 3 deletions Tests/ML/histopathology/models/test_deepmil.py
@@ -4,13 +4,16 @@
 #  ------------------------------------------------------------------------------------------
 
 import os
-from typing import Callable, Dict, List, Optional, Type  # noqa
+from typing import Any, Callable, Dict, Iterable, List, Optional, Type
+from unittest.mock import MagicMock
 
 import pytest
+import torch
 from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
+from torch.utils.data._utils.collate import default_collate
-from torchmetrics import Accuracy, Metric  # noqa
 from torchvision.models import resnet18
 
+from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
 from health_ml.networks.layers.attention_layers import (
     AttentionLayer,
     GatedAttentionLayer,
@@ -30,7 +33,7 @@
     PANDA_TILES_DATASET_DIR,
 )
 from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
-from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder, TileEncoder
+from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
 from InnerEye.ML.Histopathology.utils.naming import MetricsKey, ResultsKey
 
 
@@ -157,6 +160,106 @@ def test_lightningmodule_mean_pooling(
                            dropout_rate=dropout_rate)
 
 
+def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
+    def is_integral(x: torch.Tensor) -> bool:
+        return (x == x.long()).all()  # type: ignore
+
+    assert scores.shape == labels.shape
+    assert torch.is_floating_point(scores), "Received scores with integer dtype"
+    assert not is_integral(scores), "Received scores with integral values"
+    assert is_integral(labels), "Received labels with floating-point values"
+
+
+def add_callback(fn: Callable, callback: Callable) -> Callable:
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        callback(*args, **kwargs)
+        return fn(*args, **kwargs)
+    return wrapper
+
+
+def test_metrics() -> None:
+    input_dim = (128,)
+    module = DeepMILModule(
+        encoder=IdentityEncoder(input_dim=input_dim),
+        label_column=TilesDataset.LABEL_COLUMN,
+        n_classes=1,
+        pooling_layer=AttentionLayer,
+    )
+
+    # Patching to enable running the module without a Trainer object
+    module.trainer = MagicMock(world_size=1)  # type: ignore
+    module.log = MagicMock()  # type: ignore
+
+    batch_size = 20
+    bag_size = 5
+    class_weights = torch.tensor([.8, .2])
+    bags: List[Dict] = []
+    for slide_idx in range(batch_size):
+        bag_label = torch.multinomial(class_weights, 1)
+        sample: Dict[str, Iterable] = {
+            TilesDataset.SLIDE_ID_COLUMN: [str(slide_idx)] * bag_size,
+            TilesDataset.TILE_ID_COLUMN: [f"{slide_idx}-{tile_idx}"
+                                          for tile_idx in range(bag_size)],
+            TilesDataset.IMAGE_COLUMN: rand(bag_size, *input_dim),
+            TilesDataset.LABEL_COLUMN: bag_label.expand(bag_size),
+        }
+        sample[TilesDataset.PATH_COLUMN] = [tile_id + '.png'
+                                            for tile_id in sample[TilesDataset.TILE_ID_COLUMN]]
+        bags.append(sample)
+    batch = default_collate(bags)
+
+    # ================
+    # Test that the module metrics match manually computed metrics with the correct inputs
+    module_metrics_dict = module.test_metrics
+    independent_metrics_dict = module.get_metrics()
+
+    # Patch the metrics to check that the inputs are valid. In particular, test that the scores
+    # do not have integral values, which would suggest that hard labels were passed instead.
+    for metric_obj in module_metrics_dict.values():
+        metric_obj.update = add_callback(metric_obj.update, validate_metric_inputs)
+
+    results = module.test_step(batch, 0)
+    predicted_probs = results[ResultsKey.PROB]
+    true_labels = results[ResultsKey.TRUE_LABEL]
+
+    for key, metric_obj in module_metrics_dict.items():
+        value = metric_obj.compute()
+        expected_value = independent_metrics_dict[key](predicted_probs, true_labels)
+        assert torch.allclose(value, expected_value), f"Discrepancy in '{key}' metric"
+
+    # ================
+    # Test that thresholded metrics (e.g. accuracy, precision, etc.) change as the threshold is varied.
+    # If they don't, it suggests the inputs are hard labels instead of continuous scores.
+    thresholded_metrics_keys = [key for key, metric in module_metrics_dict.items()
+                                if hasattr(metric, 'threshold')]
+
+    def set_metrics_threshold(metrics_dict: Any, threshold: float) -> None:
+        for key in thresholded_metrics_keys:
+            metrics_dict[key].threshold = threshold
+
+    def reset_metrics(metrics_dict: Any) -> None:
+        for metric_obj in metrics_dict.values():
+            metric_obj.reset()
+
+    low_threshold, high_threshold = torch.quantile(predicted_probs, torch.tensor([0.1, 0.9]))
+
+    reset_metrics(module_metrics_dict)
+    set_metrics_threshold(module_metrics_dict, threshold=low_threshold)
+    _ = module.test_step(batch, 0)
+    results_low_threshold = {key: module_metrics_dict[key].compute()
+                             for key in thresholded_metrics_keys}
+
+    reset_metrics(module_metrics_dict)
+    set_metrics_threshold(module_metrics_dict, threshold=high_threshold)
+    _ = module.test_step(batch, 0)
+    results_high_threshold = {key: module_metrics_dict[key].compute()
+                              for key in thresholded_metrics_keys}
+
+    for key in thresholded_metrics_keys:
+        assert not torch.allclose(results_low_threshold[key], results_high_threshold[key]), \
+            f"Got same value for '{key}' metric with low and high thresholds"
+
+
 def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
     device = "cuda" if use_gpu else "cpu"
     return {
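Note (illustration, not part of this commit): the new test guards against regressions by monkey-patching each metric's `update` method with `add_callback`, so every call is validated before being forwarded. A standalone sketch of that wrapping pattern; `DummyMetric` and `check_not_hard_labels` are made-up names for the example:

```python
# Standalone sketch (not repo code) of the `add_callback` pattern used in the new test:
# a bound method is wrapped so every call is first validated, then forwarded unchanged.
from typing import Any, Callable, List


def add_callback(fn: Callable, callback: Callable) -> Callable:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        callback(*args, **kwargs)   # e.g. validate_metric_inputs(scores, labels)
        return fn(*args, **kwargs)  # then run the original update()
    return wrapper


class DummyMetric:
    def update(self, preds: List[float], target: List[int]) -> None:
        print("update called with", preds, target)


def check_not_hard_labels(preds: List[float], target: List[int]) -> None:
    # Mirrors the intent of validate_metric_inputs: if every score is exactly 0 or 1,
    # hard labels have probably been passed instead of probabilities.
    assert any(p not in (0.0, 1.0) for p in preds), "Received hard labels instead of scores"


metric = DummyMetric()
metric.update = add_callback(metric.update, check_not_hard_labels)  # monkey-patch the instance
metric.update([0.2, 0.7], [0, 1])    # passes validation and reaches the real update()
# metric.update([0.0, 1.0], [0, 1])  # would raise AssertionError
```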
2 changes: 1 addition & 1 deletion environment.yml
@@ -34,7 +34,7 @@ dependencies:
 - jupyter-client==6.1.5
 - lightning-bolts==0.4.0
 - matplotlib==3.3.0
-- mlflow==1.17.0
+- mlflow==1.23.1
 - monai==0.6.0
 - mypy==0.910
 - mypy-extensions==0.4.3
