forked from EleutherAI/elk
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cluster bootstrap for metrics; refactor metric computations into eval…
…uate_preds (EleutherAI#197) * Refactor metrics into evaluate_preds * Fix stupid CCS bug * Cluster bootstrap for AUROC; boost default sample size * Cluster bootstrap for accuracy * Allow for arbitrary hparam selection in sweep * Don't normalize LM probs twice * Fix normalization of LM logits
- Loading branch information
1 parent
3b4592c
commit 4d65f9c
Showing
14 changed files
with
294 additions
and
213 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Public API of the metrics subpackage.
from .accuracy import accuracy_ci
from .calibration import CalibrationError, CalibrationEstimate
from .eval import EvalResult, evaluate_preds, to_one_hot
from .roc_auc import RocAucResult, roc_auc, roc_auc_ci

# Names exported via `from <package> import *`, sorted for consistency
# (previously `RocAucResult` was appended out of order).
__all__ = [
    "CalibrationError",
    "CalibrationEstimate",
    "EvalResult",
    "RocAucResult",
    "accuracy_ci",
    "evaluate_preds",
    "roc_auc",
    "roc_auc_ci",
    "to_one_hot",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from dataclasses import dataclass | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
@dataclass(frozen=True)
class AccuracyResult:
    """Accuracy point estimate and bootstrap confidence interval."""

    estimate: float  # Point estimate of the accuracy computed on this sample.
    lower: float  # Lower bound of the confidence interval.
    upper: float  # Upper bound of the confidence interval.


def accuracy_ci(
    y_true: Tensor,
    y_pred: Tensor,
    *,
    num_samples: int = 1000,
    level: float = 0.95,
    seed: int = 42,
) -> AccuracyResult:
    """Bootstrap confidence interval for accuracy, with optional clustering.

    When the input arguments are 2D, this function performs the cluster bootstrap,
    resampling clusters with replacement instead of individual samples. The first
    axis is assumed to be the cluster axis.

    Args:
        y_true: Ground truth tensor of shape `(N,)` or `(N, cluster_size)`.
        y_pred: Predicted class tensor of shape `(N,)` or `(N, cluster_size)`.
        num_samples: Number of bootstrap samples to use.
        level: Confidence level of the confidence interval.
        seed: Random seed for reproducibility.

    Returns:
        AccuracyResult: Dataclass containing the point estimate along with the
        lower and upper bounds of the confidence interval.

    Raises:
        TypeError: If `y_true` or `y_pred` is a floating point tensor.
        ValueError: If the shapes differ, or are neither 1D nor 2D.
    """
    if torch.is_floating_point(y_pred) or torch.is_floating_point(y_true):
        raise TypeError("y_true and y_pred should be integer tensors")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred should have the same shape; "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.dim() not in (1, 2):
        raise ValueError("y_true and y_pred should be 1D or 2D tensors")

    # Either the number of samples (1D) or the number of clusters (2D)
    N = y_true.shape[0]
    device = y_true.device

    # Generate random indices for bootstrap samples (shape: [num_samples, N]).
    # A dedicated Generator keeps the resampling reproducible for a given seed.
    rng = torch.Generator(device=device).manual_seed(seed)
    indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng)

    # Create bootstrap samples of true labels and predictions. Indexing along
    # the first axis keeps whole clusters together in the 2D case.
    y_true_bootstraps = y_true[indices]
    y_pred_bootstraps = y_pred[indices]

    # Compute accuracy for each bootstrap sample. If the inputs were 2D, the
    # bootstrapped tensors are now 3D [num_samples, N, cluster_size], so we
    # call flatten(1) to get a 2D tensor [num_samples, N * cluster_size].
    bootstrap_hits = y_true_bootstraps.flatten(1).eq(y_pred_bootstraps.flatten(1))
    bootstrap_accs = bootstrap_hits.float().mean(1)

    # Calculate the lower and upper bounds of the confidence interval.
    # nanquantile mirrors the roc_auc implementation for consistency; accuracy
    # itself can never be NaN, so it behaves identically to quantile here.
    alpha = (1 - level) / 2
    q = bootstrap_accs.new_tensor([alpha, 1 - alpha])
    lower, upper = bootstrap_accs.nanquantile(q).tolist()

    # Compute the point estimate. Call flatten to ensure that we get a single
    # number computed across cluster boundaries even if the inputs were clustered.
    estimate = y_true.flatten().eq(y_pred.flatten()).float().mean().item()
    return AccuracyResult(estimate, lower, upper)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from dataclasses import asdict, dataclass | ||
|
||
import torch | ||
from einops import repeat | ||
from torch import Tensor | ||
|
||
from .accuracy import AccuracyResult, accuracy_ci | ||
from .calibration import CalibrationError, CalibrationEstimate | ||
from .roc_auc import RocAucResult, roc_auc_ci | ||
|
||
|
||
@dataclass(frozen=True)
class EvalResult:
    """The result of evaluating a classifier."""

    # Top-1 accuracy, implemented for both binary and multi-class classification.
    accuracy: AccuracyResult
    # Calibrated accuracy, only implemented for binary classification.
    cal_accuracy: AccuracyResult | None
    # Expected calibration error, only implemented for binary classification.
    calibration: CalibrationEstimate | None
    # Area under the ROC curve. For multi-class classification, each class is
    # treated as a one-vs-rest binary classification problem.
    roc_auc: RocAucResult

    def to_dict(self, prefix: str = "") -> dict[str, float]:
        """Flatten the result into a dictionary, prepending `prefix` to every key."""
        out: dict[str, float] = {}
        for key, value in asdict(self.accuracy).items():
            out[f"{prefix}acc_{key}"] = value
        # The calibrated-accuracy and ECE entries only exist for binary tasks.
        if self.cal_accuracy is not None:
            for key, value in asdict(self.cal_accuracy).items():
                out[f"{prefix}cal_acc_{key}"] = value
        if self.calibration is not None:
            out[f"{prefix}ece"] = self.calibration.ece
        for key, value in asdict(self.roc_auc).items():
            out[f"{prefix}auroc_{key}"] = value
        return out
|
||
|
||
def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult:
    """Evaluate the performance of a classification model.

    Args:
        y_true: Ground truth label tensor of shape `(N,)`.
        y_logits: Predicted logits of shape `(N, variants, n_classes)`.

    Returns:
        EvalResult: Accuracy and AUROC (with bootstrap confidence intervals),
        plus calibrated accuracy and expected calibration error for binary
        tasks (`None` otherwise).
    """
    (n, v, c) = y_logits.shape
    assert y_true.shape == (n,)

    # Broadcast the labels across variants so every (example, variant) pair is
    # scored; the variant axis then serves as the cluster axis for the
    # clustered bootstrap confidence intervals below.
    y_true = repeat(y_true, "n -> n v", v=v)
    auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
    acc = accuracy_ci(y_true, y_logits.argmax(dim=-1))

    cal_acc = None
    cal_err = None

    if c == 2:
        pos_probs = y_logits.softmax(-1)[..., 1]

        # Calibrated accuracy: threshold the positive-class probabilities at
        # an empirical quantile tied to the positive base rate.
        # NOTE(review): quantile(mean) makes the predicted-positive fraction
        # roughly 1 - mean; if the intent is to match the base rate exactly,
        # this would be quantile(1 - mean). Confirm intent before changing.
        cal_thresh = pos_probs.float().quantile(y_true.float().mean())
        cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
        cal_acc = accuracy_ci(y_true, cal_preds)

        # Expected calibration error over all (example, variant) pairs.
        cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
        cal_err = cal.compute()

    return EvalResult(acc, cal_acc, cal_err, auroc)
|
||
|
||
def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
    """Convert a tensor of class labels to a one-hot representation.

    Args:
        labels (Tensor): A tensor of class labels of shape (N,).
        n_classes (int): The total number of unique classes.

    Returns:
        Tensor: A one-hot representation tensor of shape (N, n_classes).
    """
    # Indices must be int64 for scatter_; the output keeps labels' dtype.
    targets = labels.unsqueeze(-1).long()
    canvas = labels.new_zeros(*labels.shape, n_classes)
    canvas.scatter_(-1, targets, 1)
    return canvas
Oops, something went wrong.