Cluster bootstrap for metrics; refactor metric computations into evaluate_preds (EleutherAI#197)

* Refactor metrics into evaluate_preds

* Fix stupid CCS bug

* Cluster bootstrap for AUROC; boost default sample size

* Cluster bootstrap for accuracy

* Allow for arbitrary hparam selection in sweep

* Don't normalize LM probs twice

* Fix normalization of LM logits
norabelrose committed Apr 19, 2023
1 parent 3b4592c commit 4d65f9c
Showing 14 changed files with 294 additions and 213 deletions.
25 changes: 10 additions & 15 deletions elk/evaluation/evaluate.py
@@ -9,9 +9,9 @@

from ..extraction.extraction import Extract
from ..files import elk_reporter_dir
+from ..metrics import evaluate_preds
from ..run import Run
from ..training import Reporter
-from ..training.supervised import evaluate_supervised
from ..utils import select_usable_devices


@@ -70,30 +70,25 @@ def evaluate_reporter(

        row_buf = []
        for ds_name, (val_h, val_gt, _) in val_output.items():
-            val_result = reporter.score(val_gt, val_h)
+            val_result = evaluate_preds(val_gt, reporter(val_h))

-            stats_row = pd.Series(
-                {
-                    "dataset": ds_name,
-                    "layer": layer,
-                    **val_result._asdict(),
-                }
-            )
+            stats_row = {
+                "dataset": ds_name,
+                "layer": layer,
+                **val_result.to_dict(),
+            }

            lr_dir = experiment_dir / "lr_models"
            if not self.cfg.skip_supervised and lr_dir.exists():
                with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
                    lr_model = torch.load(f, map_location=device).eval()

-                lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
-                stats_row["lr_auroc"] = lr_auroc_res.estimate
-                stats_row["lr_auroc_lower"] = lr_auroc_res.lower
-                stats_row["lr_auroc_upper"] = lr_auroc_res.upper
-                stats_row["lr_acc"] = lr_acc
+                lr_result = evaluate_preds(val_gt, lr_model(val_h))
+                stats_row.update(lr_result.to_dict(prefix="lr_"))

            row_buf.append(stats_row)

-        return pd.DataFrame(row_buf)
+        return pd.DataFrame.from_records(row_buf)

def evaluate(self):
"""Evaluate the reporter on all layers."""
12 changes: 6 additions & 6 deletions elk/extraction/extraction.py
@@ -149,7 +149,7 @@ def extract_hiddens(
)
for layer_idx in layer_indices
}
-lm_preds = torch.empty(
+lm_logits = torch.empty(
num_variants,
num_choices,
device=device,
@@ -205,13 +205,13 @@ def extract_hiddens(
dim=-1
)
tokens = inputs.input_ids[..., start:end, None]
-lm_preds[i, j] = log_p.gather(-1, tokens).sum()
+lm_logits[i, j] = log_p.gather(-1, tokens).sum()

elif isinstance(outputs, Seq2SeqLMOutput):
# The cross entropy loss is averaged over tokens, so we need to
# multiply by the length to get the total log probability.
length = inputs.labels.shape[-1]
-lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length
+lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

hiddens = (
outputs.get("decoder_hidden_states") or outputs["hidden_states"]
@@ -244,7 +244,7 @@ def extract_hiddens(
**hidden_dict,
)
if has_lm_preds:
out_record["model_preds"] = lm_preds.softmax(dim=-1)
out_record["model_logits"] = lm_logits

yield out_record

@@ -319,9 +319,9 @@ def get_splits() -> SplitDict:
),
}

-# Only add model_preds if the model is an autoregressive model
+# Only add model_logits if the model is an autoregressive model
if is_autoregressive(model_cfg):
other_cols["model_preds"] = Array2D(
other_cols["model_logits"] = Array2D(
shape=(num_variants, num_classes),
dtype="float32",
)
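Storing the raw summed log-probabilities here, rather than the softmaxed model_preds of the old code, lets downstream consumers normalize exactly once; this seems to be what the "Don't normalize LM probs twice" bullet refers to. Below is a minimal sketch of that downstream normalization, assuming a model_logits tensor of shape (num_variants, num_choices) as stored above; the helper name is hypothetical and not part of the repo.

import torch


def normalize_lm_logits(model_logits: torch.Tensor) -> torch.Tensor:
    """Turn per-choice total log-probabilities into normalized log-probs over choices.

    model_logits[i, j] is assumed to hold the summed log-probability of choice j
    under prompt variant i, as written by extract_hiddens above (hypothetical helper).
    """
    # One log_softmax over the choice axis is the only normalization needed;
    # normalizing values that were already softmaxed would distort the distribution.
    return model_logits.log_softmax(dim=-1)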
2 changes: 1 addition & 1 deletion elk/extraction/prompt_loading.py
@@ -49,7 +49,7 @@ class PromptConfig(Serializable):
datasets: list[str] = field(positional=True)
data_dirs: list[str] = field(default_factory=list)
label_columns: list[str] = field(default_factory=list)
-max_examples: list[int] = field(default_factory=lambda: [750, 250])
+max_examples: list[int] = field(default_factory=lambda: [1000, 1000])
num_classes: int = 0
num_shots: int = 0
num_variants: int = -1
16 changes: 16 additions & 0 deletions elk/metrics/__init__.py
@@ -0,0 +1,16 @@
from .accuracy import accuracy_ci
from .calibration import CalibrationError, CalibrationEstimate
from .eval import EvalResult, evaluate_preds, to_one_hot
from .roc_auc import RocAucResult, roc_auc, roc_auc_ci

__all__ = [
"accuracy_ci",
"CalibrationError",
"CalibrationEstimate",
"EvalResult",
"evaluate_preds",
"roc_auc",
"roc_auc_ci",
"to_one_hot",
"RocAucResult",
]
82 changes: 82 additions & 0 deletions elk/metrics/accuracy.py
@@ -0,0 +1,82 @@
from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass(frozen=True)
class AccuracyResult:
    """Accuracy point estimate and confidence interval."""

    estimate: float
    """Point estimate of the accuracy computed on this sample."""
    lower: float
    """Lower bound of the confidence interval."""
    upper: float
    """Upper bound of the confidence interval."""


def accuracy_ci(
    y_true: Tensor,
    y_pred: Tensor,
    *,
    num_samples: int = 1000,
    level: float = 0.95,
    seed: int = 42,
) -> AccuracyResult:
    """Bootstrap confidence interval for accuracy, with optional clustering.

    When the input arguments are 2D, this function performs the cluster bootstrap,
    resampling clusters with replacement instead of individual samples. The first
    axis is assumed to be the cluster axis.

    Args:
        y_true: Ground truth tensor of shape `(N,)` or `(N, cluster_size)`.
        y_pred: Predicted class tensor of shape `(N,)` or `(N, cluster_size)`.
        num_samples (int): Number of bootstrap samples to use.
        level (float): Confidence level of the confidence interval.
        seed (int): Random seed for reproducibility.

    Returns:
        AccuracyResult: Dataclass containing the lower and upper bounds of the
        confidence interval, along with the point estimate.
    """
    if torch.is_floating_point(y_pred) or torch.is_floating_point(y_true):
        raise TypeError("y_true and y_pred should be integer tensors")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred should have the same shape; "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.dim() not in (1, 2):
        raise ValueError("y_true and y_pred should be 1D or 2D tensors")

    # Either the number of samples (1D) or the number of clusters (2D)
    N = y_true.shape[0]
    device = y_true.device

    # Generate random indices for bootstrap samples (shape: [num_bootstraps, N])
    rng = torch.Generator(device=device).manual_seed(seed)
    indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng)

    # Create bootstrap samples of the true and predicted labels
    y_true_bootstraps = y_true[indices]
    y_pred_bootstraps = y_pred[indices]

    # Compute the accuracy of each bootstrap sample. If the inputs were 2D, the
    # bootstrapped tensors are now 3D [num_bootstraps, N, cluster_size], so we
    # call flatten(1) to get a 2D tensor [num_bootstraps, N * cluster_size].
    bootstrap_hits = y_true_bootstraps.flatten(1).eq(y_pred_bootstraps.flatten(1))
    bootstrap_accs = bootstrap_hits.float().mean(1)

    # Calculate the lower and upper bounds of the confidence interval
    alpha = (1 - level) / 2
    q = bootstrap_accs.new_tensor([alpha, 1 - alpha])
    lower, upper = bootstrap_accs.nanquantile(q).tolist()

    # Compute the point estimate. Call flatten to ensure that we get a single number
    # computed across cluster boundaries even if the inputs were clustered.
    estimate = y_true.flatten().eq(y_pred.flatten()).float().mean().item()
    return AccuracyResult(estimate, lower, upper)
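As a quick illustration of the cluster bootstrap described in the docstring above, here is a hypothetical toy usage (not from the repo); each row of the 2D inputs is one cluster, for example all prompt variants of a single dataset example, and rows are resampled as units.

import torch

from elk.metrics import accuracy_ci

# Four clusters (e.g. dataset examples) with three variants each
y_true = torch.tensor([[1, 1, 1], [0, 0, 0], [1, 1, 1], [0, 0, 0]])
y_pred = torch.tensor([[1, 1, 0], [0, 0, 0], [1, 0, 1], [1, 0, 0]])

result = accuracy_ci(y_true, y_pred, num_samples=2000, level=0.95, seed=0)
print(result.estimate, result.lower, result.upper)  # point estimate and CI bounds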
6 changes: 3 additions & 3 deletions elk/calibration.py → elk/metrics/calibration.py
@@ -1,12 +1,12 @@
import warnings
from dataclasses import dataclass, field
-from typing import NamedTuple

import torch
from torch import Tensor


-class CalibrationEstimate(NamedTuple):
+@dataclass(frozen=True)
+class CalibrationEstimate:
ece: float
num_bins: int

@@ -82,7 +82,7 @@ def compute(self, p: int = 2) -> CalibrationEstimate:
# Split into (nearly) equal mass bins. They won't be exactly equal, so we
# still weight the bins by their size.
conf_bins = pred_probs.tensor_split(b_star)
-w = torch.tensor([len(c) / n for c in conf_bins])
+w = pred_probs.new_tensor([len(c) / n for c in conf_bins])

# See the definition of ECE_sweep in Equation 8 of Roelofs et al. (2020)
mean_confs = torch.stack([c.mean() for c in conf_bins])
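For context on the w = pred_probs.new_tensor(...) change above, which keeps the bin weights on the same device and dtype as the probabilities, here is a rough standalone sketch of the l_p calibration error over (nearly) equal-mass bins that this code computes. The function name and signature are hypothetical, and the bin-count sweep from Roelofs et al. (2020) is omitted.

import torch


def ece_equal_mass(labels: torch.Tensor, probs: torch.Tensor, num_bins: int, p: int = 2) -> float:
    """l_p calibration error over (nearly) equal-mass confidence bins (sketch)."""
    order = probs.argsort()
    probs, labels = probs[order], labels[order].float()
    conf_bins = probs.tensor_split(num_bins)
    label_bins = labels.tensor_split(num_bins)
    # The bins are only nearly equal in size, so weight each one by its share of the data
    w = probs.new_tensor([len(b) / len(probs) for b in conf_bins])
    mean_confs = torch.stack([b.mean() for b in conf_bins])
    mean_accs = torch.stack([b.mean() for b in label_bins])
    return (w @ (mean_confs - mean_accs).abs().pow(p)).pow(1.0 / p).item()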
91 changes: 91 additions & 0 deletions elk/metrics/eval.py
@@ -0,0 +1,91 @@
from dataclasses import asdict, dataclass

import torch
from einops import repeat
from torch import Tensor

from .accuracy import AccuracyResult, accuracy_ci
from .calibration import CalibrationError, CalibrationEstimate
from .roc_auc import RocAucResult, roc_auc_ci


@dataclass(frozen=True)
class EvalResult:
    """The result of evaluating a classifier."""

    accuracy: AccuracyResult
    """Top 1 accuracy, implemented for both binary and multi-class classification."""
    cal_accuracy: AccuracyResult | None
    """Calibrated accuracy, only implemented for binary classification."""
    calibration: CalibrationEstimate | None
    """Expected calibration error, only implemented for binary classification."""
    roc_auc: RocAucResult
    """Area under the ROC curve. For multi-class classification, each class is treated
    as a one-vs-rest binary classification problem."""

    def to_dict(self, prefix: str = "") -> dict[str, float]:
        """Convert the result to a dictionary."""
        acc_dict = {f"{prefix}acc_{k}": v for k, v in asdict(self.accuracy).items()}
        cal_acc_dict = (
            {f"{prefix}cal_acc_{k}": v for k, v in asdict(self.cal_accuracy).items()}
            if self.cal_accuracy is not None
            else {}
        )
        cal_dict = (
            {f"{prefix}ece": self.calibration.ece}
            if self.calibration is not None
            else {}
        )
        auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()}
        return {**acc_dict, **cal_acc_dict, **cal_dict, **auroc_dict}


def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult:
    """Evaluate the performance of a classification model.

    Args:
        y_true: Ground truth tensor of shape (N,).
        y_logits: Predicted logits of shape (N, variants, n_classes).

    Returns:
        EvalResult: The accuracy, calibrated accuracy, calibration error, and AUROC,
        each with a bootstrap confidence interval where applicable.
    """
    (n, v, c) = y_logits.shape
    assert y_true.shape == (n,)

    # Clustered bootstrap confidence intervals for AUROC
    y_true = repeat(y_true, "n -> n v", v=v)
    auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
    acc = accuracy_ci(y_true, y_logits.argmax(dim=-1))

    cal_acc = None
    cal_err = None

    if c == 2:
        pos_probs = y_logits.softmax(-1)[..., 1]

        # Calibrated accuracy
        cal_thresh = pos_probs.float().quantile(y_true.float().mean())
        cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
        cal_acc = accuracy_ci(y_true, cal_preds)

        cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
        cal_err = cal.compute()

    return EvalResult(acc, cal_acc, cal_err, auroc)


def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
    """Convert a tensor of class labels to a one-hot representation.

    Args:
        labels (Tensor): A tensor of class labels of shape (N,).
        n_classes (int): The total number of unique classes.

    Returns:
        Tensor: A one-hot representation tensor of shape (N, n_classes).
    """
    one_hot_labels = labels.new_zeros(*labels.shape, n_classes)
    return one_hot_labels.scatter_(-1, labels.unsqueeze(-1).long(), 1)
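A hypothetical end-to-end usage sketch for evaluate_preds and EvalResult.to_dict, with toy shapes and random data (not from the repo), assuming elk.metrics exposes them as in the __init__.py above.

import torch

from elk.metrics import evaluate_preds

n, variants, n_classes = 128, 3, 2
y_true = torch.randint(0, n_classes, (n,))
y_logits = torch.randn(n, variants, n_classes)

result = evaluate_preds(y_true, y_logits)
print(result.accuracy.estimate, result.roc_auc.estimate)

# Flatten into a stats row; the evaluate loop above uses prefix="lr_" to namespace
# the supervised baseline's metrics in the same row.
row = {"dataset": "toy", "layer": 0, **result.to_dict()}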
