Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Add subsampling transform and mean pooling (#656)
Browse files Browse the repository at this point in the history
* Add subsampling transform

* Add option to allow_missing_keys for Subsampled

* Add dropout param to BaseMIL

* Add docstring and tests for Subsampled

* Update changelog

* Update to hi-ml with mean pooling

* Enable mean pooling in DeepMIL

* Add/refactor mean pooling tests

* Update changelog

* Update to latest hi-ml with mean pooling
  • Loading branch information
dccastro committed Feb 21, 2022
1 parent 1600ef3 commit e2ec5cc
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs that run in AzureML.
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.

### Changed
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.
Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/Histopathology/models/deepmil.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
with set_grad_enabled(self.is_finetune):
instance_features = self.encoder(instances) # N X L x 1 x 1
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
bag_features = bag_features.view(1, -1)
bag_logit = self.classifier_fn(bag_features)
return bag_logit, attentions

Expand Down
43 changes: 41 additions & 2 deletions InnerEye/ML/Histopathology/models/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import PIL
from monai.config.type_definitions import KeysCollection
from monai.transforms.transform import MapTransform
from monai.transforms.transform import MapTransform, Randomizable
from torchvision.transforms.functional import to_tensor

from InnerEye.ML.Histopathology.models.encoders import TileEncoder
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(self,
allow_missing_keys: bool = False,
chunk_size: int = 0) -> None:
"""
:param keys: Key(s) for the image path(s) in the input dictionary.
:param keys: Key(s) for the image tensor(s) in the input dictionary.
:param encoder: The tile encoder to use for feature extraction.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
Expand Down Expand Up @@ -128,3 +128,42 @@ def __call__(self, data: Mapping) -> Mapping:
for key in self.key_iterator(out_data):
out_data[key] = self._encode_tiles(data[key])
return out_data


def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
if isinstance(data, (np.ndarray, torch.Tensor)):
return data[indices]
elif isinstance(data, Sequence):
return [data[i] for i in indices]
else:
raise ValueError(f"Data of type {type(data)} is not indexable")


class Subsampled(MapTransform, Randomizable):
"""Dictionary transform to randomly subsample the data down to a fixed maximum length"""

def __init__(self, keys: KeysCollection, max_size: int,
allow_missing_keys: bool = False) -> None:
"""
:param keys: Key(s) for all batch elements that must be subsampled.
:param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
random down to `max_size` along their first dimension. If shorter, the elements are merely
shuffled.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
"""
super().__init__(keys, allow_missing_keys=allow_missing_keys)
self.max_size = max_size
self._indices: np.ndarray

def randomize(self, total_size: int) -> None:
subsample_size = min(self.max_size, total_size)
self._indices = self.R.choice(total_size, size=subsample_size)

def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
size = len(data[self.keys[0]])
self.randomize(size)
for key in self.key_iterator(out_data):
out_data[key] = take_indices(data[key], self._indices)
return out_data
4 changes: 3 additions & 1 deletion InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import nn
from torchvision.models.resnet import resnet18

from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
Expand Down Expand Up @@ -90,6 +90,8 @@ def get_pooling_layer(self) -> Type[nn.Module]:
return AttentionLayer
elif self.pooling_type == GatedAttentionLayer.__name__:
return GatedAttentionLayer
elif self.pooling_type == MeanPoolingLayer.__name__:
return MeanPoolingLayer
else:
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")

Expand Down
54 changes: 46 additions & 8 deletions Tests/ML/histopathology/models/test_deepmil.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
GatedAttentionLayer,
MeanPoolingLayer,
)

from InnerEye.ML.lightning_container import LightningContainer
Expand All @@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule(
def _test_lightningmodule(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
Expand Down Expand Up @@ -119,6 +113,50 @@ def test_lightningmodule(
assert torch.all(score <= 1)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=pooling_layer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
dropout_rate=dropout_rate)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_mean_pooling(
n_classes: int,
batch_size: int,
max_bag_size: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=MeanPoolingLayer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=1,
pool_out_dim=1,
dropout_rate=dropout_rate)


def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
device = "cuda" if use_gpu else "cpu"
return {
Expand Down
57 changes: 56 additions & 1 deletion Tests/ML/histopathology/models/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Callable, Sequence, Union

import numpy as np
import pytest
import torch
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
Expand All @@ -19,7 +20,7 @@
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
from Tests.ML.util import assert_dicts_equal


Expand Down Expand Up @@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
bagged_subset,
transform=transform,
cache_subdir="TCGA-CRCk_embed_cache")


@pytest.mark.parametrize('include_non_indexable', [True, False])
@pytest.mark.parametrize('allow_missing_keys', [True, False])
def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
batch_size = 5
max_size = batch_size // 2
data = {
'array_1d': np.random.randn(batch_size),
'array_2d': np.random.randn(batch_size, 4),
'tensor_1d': torch.randn(batch_size),
'tensor_2d': torch.randn(batch_size, 4),
'list': torch.randn(batch_size).tolist(),
'indices': list(range(batch_size)),
'non-indexable': 42,
}

keys_to_subsample = list(data.keys())
if not include_non_indexable:
keys_to_subsample.remove('non-indexable')
keys_to_subsample.append('missing-key')

subsampling = Subsampled(keys_to_subsample, max_size=max_size,
allow_missing_keys=allow_missing_keys)

if include_non_indexable:
with pytest.raises(ValueError):
sub_data = subsampling(data)
return
elif not allow_missing_keys:
with pytest.raises(KeyError):
sub_data = subsampling(data)
return
else:
sub_data = subsampling(data)

assert set(sub_data.keys()) == set(data.keys())

# Check lenghts before and after subsampling
for key in keys_to_subsample:
if key not in data:
continue # Skip missing keys
assert len(data[key]) == batch_size # type: ignore
assert len(sub_data[key]) == min(max_size, batch_size) # type: ignore

# Check contents of subsampled elements
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
for idx, elem in zip(sub_data['indices'], sub_data[key]):
assert np.array_equal(elem, data[key][idx]) # type: ignore

# Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
sub_data2 = subsampling(data)
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
assert not np.array_equal(sub_data[key], sub_data2[key]) # type: ignore
2 changes: 1 addition & 1 deletion hi-ml
Submodule hi-ml updated 164 files

0 comments on commit e2ec5cc

Please sign in to comment.