Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cluster bootstrap for metrics; refactor metric computations into evaluate_preds #197

Merged
merged 13 commits into from
Apr 19, 2023
Merged
Prev Previous commit
Next Next commit
Cluster bootstrap for AUROC; boost default sample size
  • Loading branch information
norabelrose committed Apr 17, 2023
commit 14d13230529e7c9d9838bf98523e0eb4b3d10945
2 changes: 1 addition & 1 deletion elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PromptConfig(Serializable):
datasets: list[str] = field(positional=True)
data_dirs: list[str] = field(default_factory=list)
label_columns: list[str] = field(default_factory=list)
max_examples: list[int] = field(default_factory=lambda: [750, 250])
max_examples: list[int] = field(default_factory=lambda: [1000, 1000])
num_classes: int = 0
num_shots: int = 0
num_variants: int = -1
Expand Down
11 changes: 7 additions & 4 deletions elk/metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ def evaluate_preds(y_true: Tensor, y_pred: Tensor) -> EvalResult:
(n, v, c) = y_pred.shape
assert y_true.shape == (n,)

# Clustered bootstrap confidence intervals for AUROC
y_true = repeat(y_true, "n -> n v", v=v)
auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_pred.flatten(1))

y_pred = rearrange(y_pred, "n v c -> (n v) c")
y_true = repeat(y_true, "n -> (n v)", v=v)
y_true = y_true.flatten()

acc = accuracy_ci(y_true, y_pred.argmax(dim=-1))
auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(), y_pred.flatten())
cal_acc = None
cal_err = None

Expand All @@ -85,5 +88,5 @@ def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
Returns:
Tensor: A one-hot representation tensor of shape (N, n_classes).
"""
one_hot_labels = labels.new_zeros(labels.size(0), n_classes)
return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1)
one_hot_labels = labels.new_zeros(*labels.shape, n_classes)
return one_hot_labels.scatter_(-1, labels.unsqueeze(-1).long(), 1)
28 changes: 18 additions & 10 deletions elk/metrics/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,15 @@ def roc_auc_ci(
level: float = 0.95,
seed: int = 42,
) -> RocAucResult:
"""Bootstrap confidence interval for the ROC AUC.
"""Bootstrap confidence interval for the ROC AUC, with optional clustering.

When the input arguments are 2D, this function performs the cluster bootstrap,
resampling clusters with replacement instead of individual samples. The first
axis is assumed to be the cluster axis.

Args:
y_true: Ground truth tensor of shape `(N,)`.
y_pred: Predicted class tensor of shape `(N,)`.
y_true: Ground truth tensor of shape `(N,)` or `(N, cluster_size)`.
y_pred: Predicted class tensor of shape `(N,)` or `(N, cluster_size)`.
num_samples (int): Number of bootstrap samples to use.
level (float): Confidence level of the confidence interval.
seed (int): Random seed for reproducibility.
Expand All @@ -95,11 +99,12 @@ def roc_auc_ci(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() != 1:
raise ValueError("y_true and y_pred should be 1D tensors")
if y_true.dim() not in (1, 2):
raise ValueError("y_true and y_pred should be 1D or 2D tensors")

device = y_true.device
# Either the number of samples (1D) or the number of clusters (2D)
N = y_true.shape[0]
device = y_true.device

# Generate random indices for bootstrap samples (shape: [num_bootstraps, N])
rng = torch.Generator(device=device).manual_seed(seed)
Expand All @@ -109,8 +114,10 @@ def roc_auc_ci(
y_true_bootstraps = y_true[indices]
y_pred_bootstraps = y_pred[indices]

# Compute ROC AUC scores for bootstrap samples
bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps)
# Compute ROC AUC scores for bootstrap samples. If the inputs were 2D, the
# bootstrapped tensors are now 3D [num_bootstraps, N, cluster_size], so we
# call flatten(1) to get a 2D tensor [num_bootstraps, N * cluster_size].
bootstrap_aucs = roc_auc(y_true_bootstraps.flatten(1), y_pred_bootstraps.flatten(1))

# Calculate the lower and upper bounds of the confidence interval. We use
# nanquantile instead of quantile because some bootstrap samples may have
Expand All @@ -119,6 +126,7 @@ def roc_auc_ci(
q = y_pred.new_tensor([alpha, 1 - alpha])
lower, upper = bootstrap_aucs.nanquantile(q).tolist()

# Compute the point estimate
estimate = roc_auc(y_true, y_pred).item()
# Compute the point estimate. Call flatten to ensure that we get a single number
# computed across cluster boundaries even if the inputs were clustered.
estimate = roc_auc(y_true.flatten(), y_pred.flatten()).item()
return RocAucResult(estimate, lower, upper)