This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 141
/
lightning_metrics.py
334 lines (270 loc) · 13.6 KB
/
lightning_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
from typing import Any, Iterator, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import metrics
from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics.functional import accuracy, auc, auroc, precision_recall_curve, roc
from torch.nn import ModuleList
from InnerEye.Common.metrics_constants import AVERAGE_DICE_SUFFIX, MetricType, TRAIN_PREFIX, VALIDATION_PREFIX
def nanmean(values: torch.Tensor) -> torch.Tensor:
    """
    Computes the average of all values in the tensor, skipping those entries that are NaN (not a number).
    If all values are NaN, the result is also NaN.

    :param values: The values to average. May have any shape; all elements are pooled.
    :return: A scalar (0-dimensional) tensor containing the average.
    """
    # Flatten first, then mask: the previous code indexed the *unflattened* tensor with a
    # flattened mask, which raises an IndexError for inputs with more than 1 dimension.
    flat = values.view((-1,))
    valid = flat[~torch.isnan(flat)]
    if valid.numel() == 0:
        # Return a 0-dim tensor so the shape matches `Tensor.mean()` and the documented
        # scalar return (previously this was a 1-element tensor).
        return torch.tensor(math.nan).type_as(values)
    return valid.mean()
class MeanAbsoluteError(metrics.MeanAbsoluteError):
    """Mean absolute error metric, reported under the InnerEye metric name."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.name = MetricType.MEAN_ABSOLUTE_ERROR.value

    @property
    def has_predictions(self) -> bool:
        """
        True if at least one prediction has been stored (i.e. `update` was called at least once),
        False otherwise.
        """
        stored_any = self.total > 0  # type: ignore
        return stored_any
class MeanSquaredError(metrics.MeanSquaredError):
    """Mean squared error metric, reported under the InnerEye metric name."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.name = MetricType.MEAN_SQUARED_ERROR.value

    @property
    def has_predictions(self) -> bool:
        """
        True if at least one prediction has been stored (i.e. `update` was called at least once),
        False otherwise.
        """
        stored_any = self.total > 0  # type: ignore
        return stored_any
class ExplainedVariance(metrics.ExplainedVariance):
    """Explained-variance metric, reported under the InnerEye metric name."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.name = MetricType.EXPLAINED_VAR.value

    @property
    def has_predictions(self) -> bool:
        """
        True if at least one prediction has been stored (i.e. `update` was called at least once),
        False otherwise.
        """
        stored = self.y_pred  # type: ignore
        return len(stored) > 0
class Accuracy05(metrics.Accuracy):
    """Binary accuracy at the fixed probability threshold 0.5, under the InnerEye metric name."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.name = MetricType.ACCURACY_AT_THRESHOLD_05.value

    @property
    def has_predictions(self) -> bool:
        """
        True if at least one prediction has been stored (i.e. `update` was called at least once),
        False otherwise.
        """
        stored_any = self.total > 0  # type: ignore
        return stored_any
class AverageWithoutNan(Metric):
    """
    A generic metric computer that keeps track of the average of all values fed into it,
    excluding those entries that are NaN.
    """

    def __init__(self, dist_sync_on_step: bool = False, name: str = ""):
        """
        :param dist_sync_on_step: If True, synchronize the metric state across processes on each `update` call.
        :param name: The name under which the metric value should be reported.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.name = name

    def update(self, value: torch.Tensor) -> None:  # type: ignore
        """
        Stores all the individual elements of the given tensor in the present object,
        skipping entries that are NaN.
        """
        # Vectorized update: one mask + two reductions, instead of a Python-level loop
        # over every element of the tensor. Result is identical to element-wise accumulation.
        flat = value.view((-1,))
        valid = ~torch.isnan(flat)
        self.sum = self.sum + flat[valid].sum()  # type: ignore
        self.count = self.count + valid.sum()  # type: ignore

    def compute(self) -> torch.Tensor:
        """
        Returns the average of all non-NaN values stored so far.

        :raises ValueError: If no values, or only NaN values, have been stored.
        """
        if self.count == 0.0:
            raise ValueError("No values stored, or only NaN values have so far been fed into this object.")
        return self.sum / self.count
class ScalarMetricsBase(Metric):
    """
    A base class for all metrics that can only be computed once the complete set of model predictions and labels
    is available. The base class provides an `update` method, and synchronized storage for predictions (field `preds`)
    and labels (field `targets`). Derived classes need to override the `compute` method.
    """

    def __init__(self, name: str = "", compute_from_logits: bool = False):
        # dist_reduce_fx=None: in distributed settings the per-process lists of tensors are
        # gathered rather than reduced, so `compute` sees the full set of predictions.
        super().__init__(dist_sync_on_step=False)
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("targets", default=[], dist_reduce_fx=None)
        # The name under which the computed metric value is reported.
        self.name = name
        # If True, callers should feed raw model logits into `update` rather than probabilities.
        self.compute_from_logits = compute_from_logits

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:  # type: ignore
        """
        Appends a batch of predictions and the matching targets to the stored state.
        """
        self.preds.append(preds)  # type: ignore
        self.targets.append(targets)  # type: ignore

    def compute(self) -> torch.Tensor:
        """
        Computes a metric from the stored predictions and targets.
        """
        raise NotImplementedError("Should be implemented in the child classes")

    def _get_preds_and_targets(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Gets a tuple of (concatenated predictions, concatenated targets).
        """
        return torch.cat(self.preds), torch.cat(self.targets)  # type: ignore

    @property
    def has_predictions(self) -> bool:
        """
        Returns True if the present object stores at least 1 prediction (self.update has been called at least once),
        or False if no predictions are stored.
        """
        return len(self.preds) > 0  # type: ignore

    def _get_metrics_at_optimal_cutoff(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Computes the ROC to find the optimal cut-off i.e. the probability threshold for which the
        difference between true positive rate and false positive rate is smallest. Then, computes
        the false positive rate, false negative rate and accuracy at this threshold (i.e. when the
        predicted probability is higher than the threshold the predicted label is 1 otherwise 0).

        :returns: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
        """
        preds, targets = self._get_preds_and_targets()
        # The ROC (and hence the optimal threshold) is undefined when only one class is present
        # in the targets; return NaN for all four values in that case.
        if torch.unique(targets).numel() == 1:
            return torch.tensor(np.nan), torch.tensor(np.nan), torch.tensor(np.nan), torch.tensor(np.nan)
        fpr, tpr, thresholds = roc(preds, targets)
        assert isinstance(fpr, torch.Tensor)
        assert isinstance(tpr, torch.Tensor)
        assert isinstance(thresholds, torch.Tensor)
        # Youden-style choice: pick the threshold that maximizes (TPR - FPR).
        optimal_idx = torch.argmax(tpr - fpr)
        optimal_threshold = thresholds[optimal_idx]
        # Accuracy of the hard predictions obtained by thresholding at the optimal cut-off.
        acc = accuracy(preds > optimal_threshold, targets)
        false_negative_optimal = 1 - tpr[optimal_idx]
        false_positive_optimal = fpr[optimal_idx]
        return optimal_threshold, false_positive_optimal, false_negative_optimal, acc
class AccuracyAtOptimalThreshold(ScalarMetricsBase):
    """
    Computes the binary classification accuracy at an optimal cut-off point.
    """

    def __init__(self) -> None:
        super().__init__(name=MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD.value)

    def compute(self) -> torch.Tensor:
        """Returns the accuracy obtained when thresholding at the optimal cut-off."""
        _, _, _, acc = self._get_metrics_at_optimal_cutoff()
        return acc
class OptimalThreshold(ScalarMetricsBase):
    """
    Computes the optimal cut-off point for a binary classifier.
    """

    def __init__(self) -> None:
        super().__init__(name=MetricType.OPTIMAL_THRESHOLD.value)

    def compute(self) -> torch.Tensor:
        """Returns the probability threshold at which (TPR - FPR) is largest."""
        threshold, _, _, _ = self._get_metrics_at_optimal_cutoff()
        return threshold
class FalsePositiveRateOptimalThreshold(ScalarMetricsBase):
    """
    Computes the false positive rate when choosing the optimal cut-off point for a binary classifier.
    """

    def __init__(self) -> None:
        super().__init__(name=MetricType.FALSE_POSITIVE_RATE_AT_OPTIMAL_THRESHOLD.value)

    def compute(self) -> torch.Tensor:
        """Returns the false positive rate at the optimal cut-off."""
        _, false_positive_rate, _, _ = self._get_metrics_at_optimal_cutoff()
        return false_positive_rate
class FalseNegativeRateOptimalThreshold(ScalarMetricsBase):
    """
    Computes the false negative rate when choosing the optimal cut-off point for a binary classifier.
    """

    def __init__(self) -> None:
        super().__init__(name=MetricType.FALSE_NEGATIVE_RATE_AT_OPTIMAL_THRESHOLD.value)

    def compute(self) -> torch.Tensor:
        """Returns the false negative rate at the optimal cut-off."""
        _, _, false_negative_rate, _ = self._get_metrics_at_optimal_cutoff()
        return false_negative_rate
class AreaUnderRocCurve(ScalarMetricsBase):
    """
    Computes the area under the receiver operating curve (ROC).
    """

    def __init__(self) -> None:
        super().__init__(name=MetricType.AREA_UNDER_ROC_CURVE.value)

    def compute(self) -> torch.Tensor:
        """Returns the AUROC over all stored predictions, or NaN if only one class is present."""
        predictions, labels = self._get_preds_and_targets()
        # AUROC is undefined when the labels contain a single class only.
        if torch.unique(labels).numel() == 1:
            return torch.tensor(np.nan)
        return auroc(predictions, labels)
class AreaUnderPrecisionRecallCurve(ScalarMetricsBase):
    """
    Computes the area under the precision-recall-curve.
    """

    def __init__(self) -> None:
        super().__init__(name=MetricType.AREA_UNDER_PR_CURVE.value)

    def compute(self) -> torch.Tensor:
        """Returns the area under the PR curve, or NaN if only one class is present."""
        predictions, labels = self._get_preds_and_targets()
        # The PR curve is undefined when the labels contain a single class only.
        if torch.unique(labels).numel() == 1:
            return torch.tensor(np.nan)
        precisions, recalls, _ = precision_recall_curve(predictions, labels)
        return auc(recalls, precisions)  # type: ignore
class BinaryCrossEntropyWithLogits(ScalarMetricsBase):
    """
    Computes the cross entropy for binary classification.
    This metric must be computed off the model output logits.
    """

    def __init__(self) -> None:
        super().__init__(name=MetricType.CROSS_ENTROPY.value, compute_from_logits=True)

    def compute(self) -> torch.Tensor:
        """Returns the binary cross entropy between the stored logits and targets."""
        logits, labels = self._get_preds_and_targets()
        # Classification metrics generally work with integer targets, but BCE needs them
        # as floating point values of the same dtype as the logits.
        float_labels = labels.to(dtype=logits.dtype)
        return F.binary_cross_entropy_with_logits(input=logits, target=float_labels)
class MetricForMultipleStructures(torch.nn.Module):
    """
    Stores a metric for multiple structures, and an average Dice score across all structures.
    The class consumes pre-computed metric values, and only keeps an aggregate for later computing the
    averages. When averaging, metric values that are NaN are skipped.
    """

    def __init__(self, ground_truth_ids: List[str], is_training: bool,
                 metric_name: str = MetricType.DICE.value,
                 use_average_across_structures: bool = True) -> None:
        """
        Creates a new MetricForMultipleStructures object.

        :param ground_truth_ids: The list of anatomical structures that should be stored.
        :param metric_name: The name of the metric that should be stored. This is used in the names of the individual
            metrics.
        :param is_training: If true, use "train/" as the prefix for all metric names, otherwise "val/"
        :param use_average_across_structures: If True, keep track of the average metric value across structures,
            while skipping NaNs. If false, only store the per-structure metric values.
        """
        super().__init__()
        prefix = (TRAIN_PREFIX if is_training else VALIDATION_PREFIX) + metric_name + "/"
        # Keeping the per-structure metrics in a ModuleList registers them as child modules.
        self.average_per_structure = ModuleList([AverageWithoutNan(name=prefix + g) for g in ground_truth_ids])
        self.use_average_across_structures = use_average_across_structures
        if use_average_across_structures:
            self.average_all = AverageWithoutNan(name=prefix + AVERAGE_DICE_SUFFIX)
        self.count = len(ground_truth_ids)

    def update(self, values_per_structure: torch.Tensor) -> None:
        """
        Stores a vector of per-structure metric values in the present object: each entry updates the
        matching per-structure aggregate, and (if enabled) their NaN-skipping mean updates the
        across-structure aggregate.

        :param values_per_structure: A rank-1 tensor with as many entries as there are ground truth IDs.
        """
        if values_per_structure.dim() != 1 or values_per_structure.numel() != self.count:
            raise ValueError(f"Expected a tensor with {self.count} elements, but "
                             f"got shape {values_per_structure.shape}")
        for structure_metric, value in zip(self.average_per_structure, values_per_structure.view((-1,))):
            structure_metric.update(value)
        if self.use_average_across_structures:
            self.average_all.update(nanmean(values_per_structure))

    def __iter__(self) -> Iterator[Metric]:
        """
        Iterates over all the metrics that the present object holds: first the average across all
        structures (if enabled), then the per-structure metrics.
        """
        if self.use_average_across_structures:
            yield self.average_all
        yield from self.average_per_structure

    def compute_all(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Calls the .compute() method on all the metrics that the present object holds, and yields
        (metric name, metric value) tuples. This will automatically also call .reset() on the metrics.
        The first yielded metric is the average across all structures, then come the per-structure values.
        """
        for metric in self:
            yield metric.name, metric.compute()  # type: ignore

    def reset(self) -> None:
        """
        Calls the .reset() method on all the metrics that the present object holds.
        """
        for metric in self:
            metric.reset()