Switch to PyTorch AUROC impl
norabelrose committed May 11, 2024
1 parent c99e534 commit 08a03ff
Showing 3 changed files with 59 additions and 5 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -11,7 +11,6 @@ keywords = ["ai", "interpretability", "generalization"]
license = {text = "MIT License"}
dependencies = [
    "datasets",
-   "sklearn",
    "torch",
    # 4.0 introduced the breaking change of using return_dict=True by default
    "transformers>=4.0.0",
55 changes: 55 additions & 0 deletions w2s/roc_auc.py
@@ -0,0 +1,55 @@
import torch
from torch import Tensor


def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor:
"""Area under the receiver operating characteristic curve (ROC AUC).
Unlike scikit-learn's implementation, this function supports batched inputs of
shape `(N, n)` where `N` is the number of datasets and `n` is the number of samples
within each dataset. This is primarily useful for efficiently computing bootstrap
confidence intervals.
Args:
y_true: Ground truth tensor of shape `(N,)` or `(N, n)`.
y_pred: Predicted class tensor of shape `(N,)` or `(N, n)`.
Returns:
Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D,
a tensor of shape (N,) containing the ROC AUC for each dataset.
"""
if y_true.shape != y_pred.shape:
raise ValueError(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() not in (1, 2):
raise ValueError("y_true and y_pred should be 1D or 2D tensors")

# Sort y_pred in descending order and get indices
indices = y_pred.argsort(descending=True, dim=-1)

# Reorder y_true based on sorted y_pred indices
y_true_sorted = y_true.gather(-1, indices)

# Calculate number of positive and negative samples
num_positives = y_true.sum(dim=-1)
num_negatives = y_true.shape[-1] - num_positives

# Calculate cumulative sum of true positive counts (TPs)
tps = torch.cumsum(y_true_sorted, dim=-1)

# Calculate cumulative sum of false positive counts (FPs)
fps = torch.cumsum(1 - y_true_sorted, dim=-1)

# Calculate true positive rate (TPR) and false positive rate (FPR)
tpr = tps / num_positives.view(-1, 1)
fpr = fps / num_negatives.view(-1, 1)

# Calculate differences between consecutive FPR values (widths of trapezoids)
fpr_diffs = torch.cat(
[fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1
)

# Calculate area under the ROC curve for each dataset using trapezoidal rule
return torch.sum(tpr * fpr_diffs, dim=-1).squeeze()
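
The docstring motivates the batched `(N, n)` input shape with bootstrap confidence intervals. A minimal sketch of how that could look (not part of this commit; the helper name, defaults, and percentile-bootstrap choice are assumptions):

import torch
from torch import Tensor

from w2s.roc_auc import roc_auc


def roc_auc_bootstrap_ci(
    y_true: Tensor, y_pred: Tensor, n_boot: int = 1000, alpha: float = 0.05, seed: int = 0
) -> tuple[Tensor, Tensor]:
    """Percentile bootstrap confidence interval for the ROC AUC (illustrative sketch)."""
    n = y_true.shape[0]
    gen = torch.Generator().manual_seed(seed)

    # (n_boot, n) indices sampled with replacement; each row is one bootstrap resample
    idx = torch.randint(0, n, (n_boot, n), generator=gen)

    # A single batched call computes one AUC per resample; resamples that happen to
    # contain only one class yield NaN, which nanquantile ignores below
    aucs = roc_auc(y_true[idx], y_pred[idx])

    qs = torch.tensor([alpha / 2, 1 - alpha / 2], dtype=aucs.dtype)
    lo, hi = aucs.nanquantile(qs)
    return lo, hi

With the defaults above, roc_auc_bootstrap_ci(labels, scores) would return the 2.5th and 97.5th percentiles of the resampled AUCs, i.e. a 95% interval.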
8 changes: 4 additions & 4 deletions w2s/train.py
@@ -9,7 +9,6 @@
    get_peft_model,
)
from simple_parsing import Serializable, field, parse
-from sklearn.metrics import roc_auc_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
@@ -20,6 +19,7 @@

from .ds_registry import load_and_process_dataset
from .knn import gather_hiddens, knn_average
+from .roc_auc import roc_auc


@dataclass
@@ -79,10 +79,10 @@ def weak_processor(examples):
    return out

def compute_metrics(eval_pred):
-    predictions, labels = eval_pred
+    predictions, labels = map(torch.from_numpy, eval_pred)
    return dict(
-        accuracy=np.mean(np.argmax(predictions, axis=1) == labels),
-        auroc=roc_auc_score(labels, predictions[:, 1]),
+        accuracy=predictions.argmax(dim=1).eq(labels).float().mean(),
+        auroc=roc_auc(labels, predictions[:, 1]),
    )

splits = load_and_process_dataset(
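
Since this commit drops sklearn from the project dependencies in favor of the PyTorch implementation, a quick parity check against sklearn.metrics.roc_auc_score can be reassuring. A hedged sketch (not part of the commit; it assumes sklearn is still available in the environment for a one-off comparison, and uses continuous float64 scores so ties, where the two implementations could disagree, are vanishingly unlikely):

import torch
from sklearn.metrics import roc_auc_score  # not a project dependency after this commit

from w2s.roc_auc import roc_auc

torch.manual_seed(0)
labels = torch.randint(0, 2, (1000,)).double()
scores = torch.rand(1000, dtype=torch.float64)

ours = roc_auc(labels, scores).item()
theirs = roc_auc_score(labels.numpy(), scores.numpy())
assert abs(ours - theirs) < 1e-6, (ours, theirs)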
