Blazing fast bootstrap stderrs for AUROC (EleutherAI#190)
norabelrose committed Apr 16, 2023
1 parent b702eda commit 7b4a00c
Showing 10 changed files with 220 additions and 45 deletions.
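
In short, the commit swaps scikit-learn's roc_auc_score for a pure-PyTorch roc_auc that accepts batched inputs, so roc_auc_ci can score every bootstrap resample in one vectorized call instead of looping in Python. A minimal standalone sketch of that idea (toy data and variable names invented here, not taken from the diff):

import torch

# Toy labels and scores, invented purely for illustration.
y_true = torch.randint(0, 2, (500,)).float()
y_pred = torch.rand(500)

# Draw all bootstrap resample indices at once: shape (num_samples, N).
num_samples = 1000
idx = torch.randint(0, y_true.shape[0], (num_samples, y_true.shape[0]))

# Indexing with a 2D index tensor yields one resample per row, so a batched
# AUC implementation can score all 1000 resamples without a Python loop.
boot_true, boot_pred = y_true[idx], y_pred[idx]
print(boot_true.shape)  # torch.Size([1000, 500])
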
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
args: ["--skip=*.yaml"]
args: ["-L fpr", "--skip=*.yaml"]
7 changes: 4 additions & 3 deletions elk/evaluation/evaluate.py
@@ -85,9 +85,10 @@ def evaluate_reporter(
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)

stats_row["lr_auroc"] = lr_auroc
lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
stats_row["lr_auroc"] = lr_auroc_res.estimate
stats_row["lr_auroc_lower"] = lr_auroc_res.lower
stats_row["lr_auroc_upper"] = lr_auroc_res.upper
stats_row["lr_acc"] = lr_acc

row_buf.append(stats_row)
122 changes: 122 additions & 0 deletions elk/metrics.py
@@ -1,3 +1,6 @@
from typing import NamedTuple

import torch
from torch import Tensor


@@ -34,3 +37,122 @@ def accuracy(y_true: Tensor, y_pred: Tensor) -> float:
hard_preds = y_pred.argmax(-1)

return hard_preds.cpu().eq(y_true.cpu()).float().mean().item()


class RocAucResult(NamedTuple):
"""Named tuple for storing ROC AUC results."""

estimate: float
"""Point estimate of the ROC AUC computed on this sample."""
lower: float
"""Lower bound of the bootstrap confidence interval."""
upper: float
"""Upper bound of the bootstrap confidence interval."""


def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor:
"""Area under the receiver operating characteristic curve (ROC AUC).
Unlike scikit-learn's implementation, this function supports batched inputs of
shape `(N, n)` where `N` is the number of datasets and `n` is the number of samples
within each dataset. This is primarily useful for efficiently computing bootstrap
confidence intervals.
Args:
y_true: Ground truth tensor of shape `(N,)` or `(N, n)`.
y_pred: Predicted class tensor of shape `(N,)` or `(N, n)`.
Returns:
Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D,
a tensor of shape (N,) containing the ROC AUC for each dataset.
"""
if y_true.shape != y_pred.shape:
raise ValueError(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() not in (1, 2):
raise ValueError("y_true and y_pred should be 1D or 2D tensors")

# Sort y_pred in descending order and get indices
indices = y_pred.argsort(descending=True, dim=-1)

# Reorder y_true based on sorted y_pred indices
y_true_sorted = y_true.gather(-1, indices)

# Calculate number of positive and negative samples
num_positives = y_true.sum(dim=-1)
num_negatives = y_true.shape[-1] - num_positives

# Calculate cumulative sum of true positive counts (TPs)
tps = torch.cumsum(y_true_sorted, dim=-1)

# Calculate cumulative sum of false positive counts (FPs)
fps = torch.cumsum(1 - y_true_sorted, dim=-1)

# Calculate true positive rate (TPR) and false positive rate (FPR)
tpr = tps / num_positives.view(-1, 1)
fpr = fps / num_negatives.view(-1, 1)

# Calculate differences between consecutive FPR values (widths of trapezoids)
fpr_diffs = torch.cat(
[fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1
)

# Calculate area under the ROC curve for each dataset using trapezoidal rule
return torch.sum(tpr * fpr_diffs, dim=-1).squeeze()


def roc_auc_ci(
y_true: Tensor,
y_pred: Tensor,
*,
num_samples: int = 1000,
level: float = 0.95,
seed: int = 42,
) -> RocAucResult:
"""Bootstrap confidence interval for the ROC AUC.
Args:
y_true: Ground truth tensor of shape `(N,)`.
y_pred: Predicted class tensor of shape `(N,)`.
num_samples (int): Number of bootstrap samples to use.
level (float): Confidence level of the confidence interval.
seed (int): Random seed for reproducibility.
Returns:
RocAucResult: Named tuple containing the lower and upper bounds of the
confidence interval, along with the point estimate.
"""
if y_true.shape != y_pred.shape:
raise ValueError(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() != 1:
raise ValueError("y_true and y_pred should be 1D tensors")

device = y_true.device
N = y_true.shape[0]

# Generate random indices for bootstrap samples (shape: [num_bootstraps, N])
rng = torch.Generator(device=device).manual_seed(seed)
indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng)

# Create bootstrap samples of true labels and predicted probabilities
y_true_bootstraps = y_true[indices]
y_pred_bootstraps = y_pred[indices]

# Compute ROC AUC scores for bootstrap samples
bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps)

# Calculate the lower and upper bounds of the confidence interval. We use
# nanquantile instead of quantile because some bootstrap samples may have
# NaN values due to the fact that they have only one class.
alpha = (1 - level) / 2
q = y_pred.new_tensor([alpha, 1 - alpha])
lower, upper = bootstrap_aucs.nanquantile(q).tolist()

# Compute the point estimate
estimate = roc_auc(y_true, y_pred).item()
return RocAucResult(estimate, lower, upper)
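
Taken together, the new module exposes roc_auc (batched, pure torch) and roc_auc_ci (bootstrap confidence interval). A hedged usage sketch, assuming the module is importable as elk.metrics and using made-up data:

import torch

from elk.metrics import roc_auc, roc_auc_ci

# Invented labels plus noisy scores correlated with them.
y_true = torch.randint(0, 2, (1000,)).float()
y_pred = y_true + 0.5 * torch.randn(1000)

point = roc_auc(y_true, y_pred)   # scalar tensor for 1D inputs
res = roc_auc_ci(y_true, y_pred)  # RocAucResult(estimate, lower, upper)
print(f"auroc={res.estimate:.3f} ({res.lower:.3f}, {res.upper:.3f})")
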
4 changes: 2 additions & 2 deletions elk/run.py
@@ -86,7 +86,7 @@ def get_device(self, devices, world_size: int) -> str:

def prepare_data(
self, device: str, layer: int, split_type: Literal["train", "val"]
) -> dict[str, tuple[Tensor, Tensor, np.ndarray | None]]:
) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
"""Prepare data for the specified layer and split type."""
out = {}

@@ -98,7 +98,7 @@ def prepare_data(
labels = assert_type(Tensor, split["label"])
val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))

with split.formatted_as("numpy"):
with split.formatted_as("torch", device=device):
has_preds = "model_preds" in split.features
lm_preds = split["model_preds"] if has_preds else None

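
The data pipeline now hands back torch tensors on the requested device instead of NumPy arrays. A hedged sketch of the datasets pattern the diff relies on (dataset contents invented; formatted_as is the HuggingFace datasets context manager used above):

from datasets import Dataset

ds = Dataset.from_dict({"label": [0, 1], "model_preds": [[0.2, 0.8], [0.6, 0.4]]})

# Inside the block, column access returns torch tensors already placed on the
# requested device, so labels and predictions skip the NumPy round trip.
with ds.formatted_as("torch", device="cpu"):
    labels = ds["label"]  # torch.Tensor
    lm_preds = ds["model_preds"]
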
6 changes: 3 additions & 3 deletions elk/training/ccs_reporter.py
@@ -7,10 +7,10 @@

import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch import Tensor
from torch.nn.functional import binary_cross_entropy as bce

from ..metrics import roc_auc
from ..parsing import parse_loss
from ..utils.typing import assert_type
from .classifier import Classifier
@@ -175,8 +175,8 @@ def check_separability(
pseudo_preds = pseudo_clf(
# b v d -> (b v) d
torch.cat([val_x0, val_x1]).flatten(0, 1)
)
return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu()))
).squeeze(-1)
return roc_auc(pseudo_val_labels, pseudo_preds).item()

def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
loss = sum(
15 changes: 9 additions & 6 deletions elk/training/reporter.py
@@ -9,11 +9,10 @@
import torch.nn as nn
from einops import rearrange, repeat
from simple_parsing.helpers import Serializable
from sklearn.metrics import roc_auc_score
from torch import Tensor

from ..calibration import CalibrationError
from ..metrics import accuracy, to_one_hot
from ..metrics import accuracy, roc_auc_ci, to_one_hot


class EvalResult(NamedTuple):
@@ -23,9 +22,12 @@ class EvalResult(NamedTuple):
which contains the loss, accuracy, calibrated accuracy, and AUROC.
"""

auroc: float
auroc_lower: float
auroc_upper: float

acc: float
cal_acc: float
auroc: float
ece: float


@@ -117,12 +119,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
raw_preds = to_one_hot(logits.argmax(dim=-1), c).long()
Y = to_one_hot(Y, c).long().flatten()

auroc = roc_auc_score(Y.cpu(), logits.cpu().flatten())
raw_acc = accuracy(Y, raw_preds.flatten())

auroc_result = roc_auc_ci(Y, logits.flatten())
return EvalResult(
auroc=auroc_result.estimate,
auroc_lower=auroc_result.lower,
auroc_upper=auroc_result.upper,
acc=float(raw_acc),
cal_acc=cal_acc,
auroc=float(auroc),
ece=cal_err,
)
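
For anything consuming the per-layer results, the reordered EvalResult now leads with the AUROC point estimate and its confidence bounds. A small illustrative sketch (the import path follows the file location above; dataset name and numbers are invented):

from elk.training.reporter import EvalResult

res = EvalResult(
    auroc=0.87, auroc_lower=0.84, auroc_upper=0.90,
    acc=0.81, cal_acc=0.83, ece=0.05,
)

# _asdict() preserves field order, so the new CI columns land right next to
# the point estimate in the per-layer results row that train.py builds.
row = {"dataset": "imdb", "layer": 12, **res._asdict()}
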
13 changes: 6 additions & 7 deletions elk/training/supervised.py
@@ -1,29 +1,28 @@
import torch
from einops import rearrange, repeat
from sklearn.metrics import roc_auc_score
from torch import Tensor

from ..metrics import accuracy, to_one_hot
from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot
from ..utils import assert_type
from .classifier import Classifier


def evaluate_supervised(
lr_model: Classifier, val_h: Tensor, val_labels: Tensor
) -> tuple[float, float]:
(n, v, k, d) = val_h.shape
) -> tuple[RocAucResult, float]:
(_, v, k, _) = val_h.shape

with torch.no_grad():
logits = rearrange(lr_model(val_h).cpu().squeeze(), "n v k -> (n v) k")
logits = rearrange(lr_model(val_h).squeeze(), "n v k -> (n v) k")
raw_preds = to_one_hot(logits.argmax(dim=-1), k).long()

labels = repeat(val_labels, "n -> (n v)", v=v)
labels = to_one_hot(labels, k).flatten()

lr_acc = accuracy(labels, raw_preds.flatten())
lr_auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten())
lr_auroc = roc_auc_ci(labels, logits.flatten())

return assert_type(float, lr_auroc), assert_type(float, lr_acc)
return lr_auroc, assert_type(float, lr_acc)


def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier:
38 changes: 18 additions & 20 deletions elk/training/train.py
@@ -9,10 +9,9 @@
import torch
from einops import rearrange, repeat
from simple_parsing import Serializable, field, subgroups
from sklearn.metrics import roc_auc_score

from ..extraction.extraction import Extract
from ..metrics import accuracy, to_one_hot
from ..metrics import accuracy, roc_auc_ci, to_one_hot
from ..run import Run
from ..training.supervised import evaluate_supervised, train_supervised
from ..utils import select_usable_devices
@@ -148,37 +147,36 @@ def train_reporter(
row_buf = []
for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
val_result = reporter.score(val_gt, val_h)

if val_lm_preds is not None:
(_, v, k, _) = val_h.shape

val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu()
val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
val_lm_auroc = roc_auc_score(
to_one_hot(val_gt_cpu, k).long().flatten(), val_lm_preds.flatten()
)

val_lm_acc = accuracy(val_gt_cpu, torch.from_numpy(val_lm_preds))
else:
val_lm_auroc = None
val_lm_acc = None

row = pd.Series(
{
"dataset": ds_name,
"layer": layer,
"pseudo_auroc": pseudo_auroc,
"train_loss": train_loss,
**val_result._asdict(),
"lm_auroc": val_lm_auroc,
"lm_acc": val_lm_acc,
}
)

if val_lm_preds is not None:
(_, v, k, _) = val_h.shape

val_gt_rep = repeat(val_gt, "n -> (n v)", v=v)
val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
val_lm_auroc_res = roc_auc_ci(
to_one_hot(val_gt_rep, k).long().flatten(), val_lm_preds.flatten()
)
row["lm_auroc"] = val_lm_auroc_res.estimate
row["lm_auroc_lower"] = val_lm_auroc_res.lower
row["lm_auroc_upper"] = val_lm_auroc_res.upper
row["lm_acc"] = accuracy(val_gt_rep, val_lm_preds)

if lr_model is not None:
row["lr_auroc"], row["lr_acc"] = evaluate_supervised(
lr_auroc_res, row["lr_acc"] = evaluate_supervised(
lr_model, val_h, val_gt
)
row["lr_auroc"] = lr_auroc_res.estimate
row["lr_auroc_lower"] = lr_auroc_res.lower
row["lr_auroc_upper"] = lr_auroc_res.upper

row_buf.append(row)

5 changes: 2 additions & 3 deletions pyproject.toml
@@ -21,8 +21,6 @@ dependencies = [
"pandas",
# Basically any version should work as long as it supports the user's CUDA version
"pynvml",
# Doesn't really matter but before 1.0.0 there might be weird breaking changes
"scikit-learn>=1.0.0",
# Needed for certain HF tokenizers
"sentencepiece==0.1.97",
# We upstreamed bugfixes for Literal types in 0.1.1
@@ -43,7 +41,8 @@ dev = [
"hypothesis",
"pre-commit",
"pytest",
"pyright"
"pyright",
"scikit-learn",
]

[project.scripts]