-
Notifications
You must be signed in to change notification settings - Fork 0
/
roc_auc.py
57 lines (44 loc) · 2.21 KB
/
roc_auc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch import Tensor
def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor:
    """Area under the receiver operating characteristic curve (ROC AUC).

    Unlike scikit-learn's implementation, this function supports batched inputs
    of shape `(N, n)` where `N` is the number of datasets and `n` is the number
    of samples within each dataset. This is primarily useful for efficiently
    computing bootstrap confidence intervals.

    Tied scores are treated as a single threshold (as scikit-learn does), so the
    result does not depend on the order in which tied samples happen to be
    sorted; e.g. constant predictions yield an AUC of 0.5.

    Args:
        y_true: Ground truth labels (0 or 1; boolean tensors are accepted) of
            shape `(n,)` or `(N, n)`.
        y_pred: Predicted scores of shape `(n,)` or `(N, n)`; higher means
            "more likely positive".

    Returns:
        Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're
        2D, a tensor of shape `(N,)` containing the ROC AUC for each dataset.
        A dataset containing only one class has an undefined AUC and yields NaN.

    Raises:
        ValueError: If the shapes differ, the inputs are not 1D or 2D, or
            `y_true` contains values other than 0 and 1.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred should have the same shape; "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.dim() not in (1, 2):
        raise ValueError("y_true and y_pred should be 1D or 2D tensors")
    if not ((y_true == 1) | (y_true == 0)).all():
        raise ValueError("y_true should contain only 0s and 1s")

    # Bool tensors do not support subtraction (`1 - labels` below), so count
    # with integers instead.
    labels = y_true.long() if y_true.dtype == torch.bool else y_true

    # Sort scores in descending order and reorder the labels accordingly
    scores, order = y_pred.sort(dim=-1, descending=True)
    labels = labels.gather(-1, order)

    # Number of positive and negative samples; keepdim lets the same code
    # broadcast over both 1D and 2D inputs without a final squeeze.
    num_positives = labels.sum(dim=-1, keepdim=True)
    num_negatives = labels.shape[-1] - num_positives

    # True/false positive rates after thresholding at each sorted position.
    # A single-class dataset divides 0 by 0 here and propagates NaN.
    tpr = labels.cumsum(dim=-1) / num_positives
    fpr = (1 - labels).cumsum(dim=-1) / num_negatives

    # Only the last position in each run of equal scores is a real ROC vertex:
    # tied samples cannot be separated by any threshold.
    is_vertex = torch.cat(
        [
            scores[..., 1:] != scores[..., :-1],
            torch.ones_like(scores[..., :1], dtype=torch.bool),
        ],
        dim=-1,
    )

    def previous_vertex(rate: Tensor) -> Tensor:
        """Value of `rate` at the closest vertex strictly before each position."""
        kept = rate.masked_fill(~is_vertex, 0)
        shifted = torch.cat([torch.zeros_like(kept[..., :1]), kept[..., :-1]], dim=-1)
        # Rates are non-decreasing, so a running max recovers the latest vertex
        # (or the origin, 0, if there is none yet).
        return shifted.cummax(dim=-1).values

    # Trapezoidal rule between consecutive vertices, starting from (0, 0)
    widths = fpr - previous_vertex(fpr)
    heights = (tpr + previous_vertex(tpr)) / 2
    areas = (widths * heights).masked_fill(~is_vertex, 0)

    return areas.sum(dim=-1)