This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 142
/
test_deepmil.py
379 lines (321 loc) · 14.3 KB
/
test_deepmil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
from unittest.mock import MagicMock
import pytest
import torch
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
from torchmetrics import Accuracy, Metric # noqa
from torchvision.models import resnet18
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
GatedAttentionLayer,
MeanPoolingLayer,
)
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.configs.histo_configs.classification.DeepSMILECrck import (
DeepSMILECrck,
)
from InnerEye.ML.configs.histo_configs.classification.DeepSMILEPanda import (
DeepSMILEPanda,
)
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.default_paths import (
TCGA_CRCK_DATASET_DIR,
PANDA_TILES_DATASET_DIR,
)
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
from InnerEye.ML.Histopathology.utils.naming import MetricsKey, ResultsKey
def get_supervised_imagenet_encoder() -> TileEncoder:
    """Build an ``ImageNetEncoder`` backed by torchvision's ``resnet18`` for 224-pixel tiles."""
    tile_size = 224
    return ImageNetEncoder(tile_size=tile_size, feature_extraction_model=resnet18)
def _test_lightningmodule(
    n_classes: int,
    pooling_layer: Callable[[int, int, int], nn.Module],
    batch_size: int,
    max_bag_size: int,
    pool_hidden_dim: int,
    pool_out_dim: int,
    dropout_rate: Optional[float],
) -> None:
    """Push random bags through a ``DeepMILModule`` one at a time and sanity-check the outputs.

    Checks the per-bag logit and attention shapes, the batch loss (positive scalar), the
    activation range ([0, 1]) and shape, and that every training metric except the confusion
    matrix scores within [0, 1] on the resulting hard predictions.
    """
    assert n_classes > 0
    is_multiclass = n_classes > 1
    # Binary models (n_classes == 1) still need labels drawn from {0, 1}.
    label_range = n_classes if is_multiclass else n_classes + 1

    # Hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere.
    module = DeepMILModule(
        encoder=get_supervised_imagenet_encoder(),
        label_column="label",
        n_classes=n_classes,
        pooling_layer=pooling_layer,
        pool_hidden_dim=pool_hidden_dim,
        pool_out_dim=pool_out_dim,
        dropout_rate=dropout_rate,
    )

    bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim])
    labels_per_bag = []
    logits_per_bag = []
    for tiles in bag_images:
        tile_labels = randint(label_range, size=(max_bag_size,))
        labels_per_bag.append(module.get_bag_label(tile_labels))
        logit, attention = module(tiles)
        assert logit.shape == (1, n_classes)
        assert attention.shape == (module.pool_out_dim, max_bag_size)
        logits_per_bag.append(logit.view(-1))

    bag_logits = stack(logits_per_bag)
    bag_labels = stack(labels_per_bag).view(-1)
    assert bag_logits.shape[0] == batch_size
    assert bag_labels.shape[0] == batch_size

    # The binary loss takes 1-D logits and float targets; the multiclass loss takes the raw ones.
    loss = (module.loss_fn(bag_logits, bag_labels) if module.n_classes > 1
            else module.loss_fn(bag_logits.squeeze(1), bag_labels.float()))
    assert loss > 0
    assert loss.shape == ()

    probs = module.activation_fn(bag_logits)
    assert ((probs >= 0) & (probs <= 1)).all()
    if is_multiclass:
        assert probs.shape == (batch_size, n_classes)
        preds = argmax(probs, dim=1)
    else:
        assert probs.shape[0] == batch_size
        preds = round(probs)
    assert preds.shape[0] == batch_size

    for name, metric in module.train_metrics.items():
        if name != MetricsKey.CONF_MATRIX:
            score = metric(preds.view(-1, 1), bag_labels.view(-1, 1))
            assert torch.all(score >= 0)
            assert torch.all(score <= 1)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
    n_classes: int,
    pooling_layer: Callable[[int, int, int], nn.Module],
    batch_size: int,
    max_bag_size: int,
    pool_hidden_dim: int,
    pool_out_dim: int,
    dropout_rate: Optional[float],
) -> None:
    """Run the shared module checks over the full grid of attention-pooling configurations."""
    # Arguments are forwarded positionally, in the same order as the helper's signature.
    _test_lightningmodule(
        n_classes,
        pooling_layer,
        batch_size,
        max_bag_size,
        pool_hidden_dim,
        pool_out_dim,
        dropout_rate,
    )
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_mean_pooling(
    n_classes: int,
    batch_size: int,
    max_bag_size: int,
    dropout_rate: Optional[float],
) -> None:
    """Run the shared module checks with ``MeanPoolingLayer`` and both pooling dims pinned to 1."""
    fixed_pool_dims = dict(pool_hidden_dim=1, pool_out_dim=1)
    _test_lightningmodule(
        n_classes=n_classes,
        pooling_layer=MeanPoolingLayer,
        batch_size=batch_size,
        max_bag_size=max_bag_size,
        dropout_rate=dropout_rate,
        **fixed_pool_dims,
    )
def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
    """Assert that a metric is being fed soft scores and hard labels of matching shape.

    Meant to run just before a metric's ``update`` (via ``add_callback``). ``scores`` must be a
    floating-point tensor that is not entirely integer-valued, since integral scores suggest
    hard predictions were passed. ``labels`` must be integer-valued.

    :param scores: Scores given to the metric.
    :param labels: Target labels given to the metric; must have the same shape as ``scores``.
    :raises AssertionError: If any of the conditions above is violated.
    """
    def is_integral(x: torch.Tensor) -> bool:
        # Every element equals its value cast to int64; works for integer and float dtypes alike.
        return bool((x == x.long()).all())

    assert scores.shape == labels.shape, \
        f"Shape mismatch: scores {tuple(scores.shape)} vs labels {tuple(labels.shape)}"
    assert torch.is_floating_point(scores), "Received scores with integer dtype"
    assert not is_integral(scores), "Received scores with integral values"
    assert is_integral(labels), "Received labels with floating-point values"
def add_callback(fn: Callable, callback: Callable) -> Callable:
    """Wrap ``fn`` so that ``callback`` runs first with the same arguments.

    The callback's return value is ignored; the wrapper returns whatever ``fn`` returns.
    ``functools.wraps`` copies ``fn``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...)
    onto the wrapper, so the patched callable still introspects like the original.

    :param fn: Callable to wrap.
    :param callback: Callable invoked with ``*args, **kwargs`` before ``fn``.
    :return: The wrapping callable.
    """
    from functools import wraps

    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        callback(*args, **kwargs)
        return fn(*args, **kwargs)

    return wrapper
def test_metrics() -> None:
    """Check that ``DeepMILModule`` feeds soft scores, not hard labels, to its test metrics.

    Uses a synthetic batch of bags and an ``IdentityEncoder``. It checks that:
    1. every ``update`` call on the module's test metrics receives valid inputs
       (see ``validate_metric_inputs``);
    2. each metric matches an independently built one applied to the probabilities and
       labels returned by ``test_step``;
    3. every metric exposing a ``threshold`` attribute gives different values at a low
       and a high threshold.
    """
    input_dim = (128,)
    module = DeepMILModule(
        encoder=IdentityEncoder(input_dim=input_dim),
        label_column=TilesDataset.LABEL_COLUMN,
        n_classes=1,
        pooling_layer=AttentionLayer,
    )
    # Patching to enable running the module without a Trainer object
    module.trainer = MagicMock(world_size=1)  # type: ignore
    module.log = MagicMock()  # type: ignore

    batch_size = 20
    bag_size = 5
    class_weights = torch.tensor([.8, .2])
    # Draw one label per slide from ``class_weights``, then pin the first two slides to
    # classes 0 and 1. Purely random labels give an all-negative batch with probability
    # 0.8 ** 20 (about 1%), for which metrics needing both classes (e.g. an AUROC, if the
    # module computes one) are ill-defined and the test would be flaky.
    slide_labels = torch.multinomial(class_weights, batch_size, replacement=True)
    slide_labels[:2] = torch.tensor([0, 1])
    bags: List[Dict] = []
    for slide_idx in range(batch_size):
        # Slice rather than index, to keep a length-1 tensor that ``expand`` broadcasts per tile.
        bag_label = slide_labels[slide_idx:slide_idx + 1]
        tile_ids = [f"{slide_idx}-{tile_idx}" for tile_idx in range(bag_size)]
        sample: Dict[str, Iterable] = {
            TilesDataset.SLIDE_ID_COLUMN: [str(slide_idx)] * bag_size,
            TilesDataset.TILE_ID_COLUMN: tile_ids,
            TilesDataset.IMAGE_COLUMN: rand(bag_size, *input_dim),
            TilesDataset.LABEL_COLUMN: bag_label.expand(bag_size),
            TilesDataset.PATH_COLUMN: [tile_id + '.png' for tile_id in tile_ids],
        }
        bags.append(sample)
    batch = default_collate(bags)

    # ================
    # Test that the module metrics match manually computed metrics with the correct inputs
    module_metrics_dict = module.test_metrics
    independent_metrics_dict = module.get_metrics()

    # Patch the metrics to check that the inputs are valid. In particular, test that the scores
    # do not have integral values, which would suggest that hard labels were passed instead.
    for metric_obj in module_metrics_dict.values():
        metric_obj.update = add_callback(metric_obj.update, validate_metric_inputs)

    results = module.test_step(batch, 0)
    predicted_probs = results[ResultsKey.PROB]
    true_labels = results[ResultsKey.TRUE_LABEL]
    for key, metric_obj in module_metrics_dict.items():
        value = metric_obj.compute()
        expected_value = independent_metrics_dict[key](predicted_probs, true_labels)
        assert torch.allclose(value, expected_value), f"Discrepancy in '{key}' metric"

    # ================
    # Test that thresholded metrics (e.g. accuracy, precision, etc.) change as the threshold is varied.
    # If they don't, it suggests the inputs are hard labels instead of continuous scores.
    thresholded_metrics_keys = [key for key, metric in module_metrics_dict.items()
                                if hasattr(metric, 'threshold')]

    def set_metrics_threshold(metrics_dict: Any, threshold: float) -> None:
        for key in thresholded_metrics_keys:
            metrics_dict[key].threshold = threshold

    def reset_metrics(metrics_dict: Any) -> None:
        for metric_obj in metrics_dict.values():
            metric_obj.reset()

    # ``tolist`` turns the two quantiles into plain floats, matching ``set_metrics_threshold``.
    low_threshold, high_threshold = torch.quantile(predicted_probs, torch.tensor([0.1, 0.9])).tolist()

    reset_metrics(module_metrics_dict)
    set_metrics_threshold(module_metrics_dict, threshold=low_threshold)
    _ = module.test_step(batch, 0)
    results_low_threshold = {key: module_metrics_dict[key].compute()
                             for key in thresholded_metrics_keys}

    reset_metrics(module_metrics_dict)
    set_metrics_threshold(module_metrics_dict, threshold=high_threshold)
    _ = module.test_step(batch, 0)
    results_high_threshold = {key: module_metrics_dict[key].compute()
                              for key in thresholded_metrics_keys}

    for key in thresholded_metrics_keys:
        assert not torch.allclose(results_low_threshold[key], results_high_threshold[key]), \
            f"Got same value for '{key}' metric with low and high thresholds"
def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
    """Return a copy of ``batch`` with every tensor element moved to the requested device.

    Each value of ``batch`` is iterated element by element. Tensors are sent to ``"cuda"``
    when ``use_gpu`` is truthy and to ``"cpu"`` otherwise; other elements are kept as-is.
    Each value of the result is a new list.
    """
    target = "cuda" if use_gpu else "cpu"

    def _to_target(item: Any) -> Any:
        return item.to(target) if isinstance(item, Tensor) else item

    moved: Dict[str, List] = {}
    for name, items in batch.items():
        moved[name] = list(map(_to_target, items))
    return moved
# Dataset root that each container class exercised by ``test_container`` reads from;
# the test is skipped when the corresponding directory does not exist.
CONTAINER_DATASET_DIR = dict(
    [
        (DeepSMILEPanda, PANDA_TILES_DATASET_DIR),
        (DeepSMILECrck, TCGA_CRCK_DATASET_DIR),
    ]
)
@pytest.mark.parametrize("container_type", [DeepSMILEPanda,
                                            DeepSMILECrck])
@pytest.mark.parametrize("use_gpu", [True, False])
def test_container(container_type: Type[LightningContainer], use_gpu: bool) -> None:
    """Run one training, validation and test step of a DeepSMILE container on its dataset.

    Skipped when the GPU variant is requested on a machine without CUDA, or when the
    container's dataset directory (from ``CONTAINER_DATASET_DIR``) does not exist.
    """
    # ``module.cuda()`` below cannot succeed without a CUDA device: skip instead of erroring.
    if use_gpu and not torch.cuda.is_available():
        pytest.skip("CUDA is not available, skipping the GPU variant of this test")

    dataset_dir = CONTAINER_DATASET_DIR[container_type]
    if not os.path.isdir(dataset_dir):
        pytest.skip(
            f"Dataset for container {container_type.__name__} "
            f"is unavailable: {dataset_dir}"
        )
    if container_type is DeepSMILECrck:
        container = DeepSMILECrck(encoder_type=ImageNetEncoder.__name__)
    elif container_type is DeepSMILEPanda:
        container = DeepSMILEPanda(encoder_type=ImageNetEncoder.__name__)
    else:
        container = container_type()

    container.setup()

    data_module: TilesDataModule = container.get_data_module()  # type: ignore
    data_module.max_bag_size = 10
    module = container.create_model()
    if use_gpu:
        module.cuda()

    train_data_loader = data_module.train_dataloader()
    for batch_idx, batch in enumerate(train_data_loader):
        batch = move_batch_to_expected_device(batch, use_gpu)
        loss = module.training_step(batch, batch_idx)
        # retain_grad lets us confirm the loss is part of an autograd graph after backward().
        loss.retain_grad()
        loss.backward()
        assert loss.grad is not None
        assert loss.shape == ()
        assert isinstance(loss, Tensor)
        break

    val_data_loader = data_module.val_dataloader()
    for batch_idx, batch in enumerate(val_data_loader):
        batch = move_batch_to_expected_device(batch, use_gpu)
        loss = module.validation_step(batch, batch_idx)
        assert loss.shape == ()  # noqa
        assert isinstance(loss, Tensor)
        break

    test_data_loader = data_module.test_dataloader()
    for batch_idx, batch in enumerate(test_data_loader):
        batch = move_batch_to_expected_device(batch, use_gpu)
        outputs_dict = module.test_step(batch, batch_idx)
        loss = outputs_dict[ResultsKey.LOSS]  # noqa
        assert loss.shape == ()
        assert isinstance(loss, Tensor)
        break
def test_class_weights_binary() -> None:
    """Check how binary class weights enter the module's BCE-with-logits loss.

    Both bag labels are checked deterministically, so each branch runs on every run. For
    label 1 the weighted loss must equal the unweighted BCE loss times
    ``class_weights[1] / (class_weights[0] + 1e-5)``; for label 0 it must equal the
    unweighted loss.
    """
    class_weights = Tensor([0.5, 3.5])
    n_classes = 1
    module = DeepMILModule(
        encoder=get_supervised_imagenet_encoder(),
        label_column="label",
        n_classes=n_classes,
        pooling_layer=AttentionLayer,
        pool_hidden_dim=5,
        pool_out_dim=1,
        class_weights=class_weights,
    )
    logits = randn(1, n_classes)
    pos_weight = Tensor([class_weights[1] / (class_weights[0] + 1e-5)])
    criterion_unweighted = nn.BCEWithLogitsLoss()

    for label in (0, 1):
        bag_label = torch.tensor([label])
        loss_weighted = module.loss_fn(logits.squeeze(1), bag_label.float())
        loss_unweighted = criterion_unweighted(logits.squeeze(1), bag_label.float())
        # Only positive targets are scaled by the positive-class weight.
        expected = pos_weight * loss_unweighted if label == 1 else loss_unweighted
        assert allclose(loss_weighted, expected), f"Unexpected weighted loss for label {label}"
def test_class_weights_multiclass() -> None:
    """Compare the module's multiclass loss with an unweighted cross-entropy on a single bag."""
    n_classes = 3
    class_weights = Tensor([0.33, 0.33, 0.33])
    module = DeepMILModule(
        encoder=get_supervised_imagenet_encoder(),
        label_column="label",
        n_classes=n_classes,
        pooling_layer=AttentionLayer,
        pool_hidden_dim=5,
        pool_out_dim=1,
        class_weights=class_weights,
    )

    logits = Tensor(randn(1, n_classes))
    bag_label = randint(n_classes, size=(1,))
    weighted = module.loss_fn(logits, bag_label)
    unweighted = nn.CrossEntropyLoss()(logits, bag_label)

    # The weighted and unweighted loss functions give the same loss values for batch_size = 1.
    # https://stackoverflow.com/questions/67639540/pytorch-cross-entropy-loss-weights-not-working
    # TODO: the test should reflect actual weighted loss operation for the class weights after batch_size > 1 is implemented.
    assert allclose(weighted, unweighted)