Blazing fast bootstrap stderrs for AUROC #190

Merged · 58 commits · Apr 16, 2023

Commits
d292c7c  LM output evaluation for autoregressive models (norabelrose, Apr 4, 2023)
7ed5ccd  move to own baseline file (lauritowal, Apr 4, 2023)
ba1d3b2  cleanup (lauritowal, Apr 4, 2023)
a20d4ca  Support encoder-decoder model LM output (norabelrose, Apr 5, 2023)
088758e  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 5, 2023)
77d7418  isort (norabelrose, Apr 5, 2023)
5bf63f4  Bug fixes (norabelrose, Apr 5, 2023)
819cfed  Merge branch 'main' into lm-output (norabelrose, Apr 5, 2023)
d3d9a8d  Merge branch 'main' into lm-output (norabelrose, Apr 5, 2023)
b89e23c  Remove test_log_csv_elements (norabelrose, Apr 5, 2023)
9aef842  Remove Python 3.9 support (norabelrose, Apr 5, 2023)
0851d4f  Add Pandas to pyproject.toml (norabelrose, Apr 5, 2023)
207a375  add code (contains still same device cuda error) (lauritowal, Apr 5, 2023)
e7efcce  fix multiple cuda error, save evals to right folder + cleanup (lauritowal, Apr 7, 2023)
b5fa54c  Merge branch 'main' into eval_lr (lauritowal, Apr 7, 2023)
4f8bdc5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2023)
9ca72ba  Fix bug noticed by Waree (norabelrose, Apr 7, 2023)
d7e4893  Merge remote-tracking branch 'origin/eval_lr' into lm-output (norabelrose, Apr 7, 2023)
bcdca8a  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 7, 2023)
713a251  Add sanity check to load_prompts and refactor binarize (norabelrose, Apr 7, 2023)
0c35bc7  Changing a ton of stuff (norabelrose, Apr 8, 2023)
f6a762a  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 10, 2023)
f547744  Revert changes to binarize (norabelrose, Apr 10, 2023)
ab1909f  Stupid prompt_counter bug (norabelrose, Apr 10, 2023)
f58290f  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 10, 2023)
f912ee6  Remove stupid second set_start_method call (norabelrose, Apr 10, 2023)
606dcad  Merge remote-tracking branch 'origin/lm-output' into multiclass (norabelrose, Apr 10, 2023)
0038792  Merge remote-tracking branch 'origin/main' into multiclass (norabelrose, Apr 10, 2023)
83b480b  Fix bugs in binary case (norabelrose, Apr 11, 2023)
3e66262  Various little refactors (norabelrose, Apr 11, 2023)
a8c21a6  Remove .predict and .predict_prob on Reporter; trying to get SciQ to … (norabelrose, Apr 11, 2023)
5f478b1  Bugfix for Reporter.score on binary tasks (norabelrose, Apr 11, 2023)
97b26ac  Fix bug where cached hidden states aren’t used when num_gpus is diffe… (norabelrose, Apr 12, 2023)
11fda87  Actually works now (norabelrose, Apr 12, 2023)
da4c72f  Refactor handling of multiple datasets (norabelrose, Apr 13, 2023)
e1675f7  Various fixes (norabelrose, Apr 13, 2023)
8cc325b  Merge remote-tracking branch 'origin/main' into multi-ds-eval (norabelrose, Apr 13, 2023)
14987e1  Fix math tests (norabelrose, Apr 13, 2023)
88683fa  Fix smoke tests (norabelrose, Apr 13, 2023)
a6c382e  All tests working ostensibly (norabelrose, Apr 13, 2023)
ecc53cb  Make CCS normalization customizable (norabelrose, Apr 13, 2023)
18c7f4c  log each dataset individually (AlexTMallen, Apr 13, 2023)
94a900c  Merge branch 'multi-ds-eval' into multiclass (norabelrose, Apr 13, 2023)
5173649  Fix label_column bug (norabelrose, Apr 13, 2023)
3e6c39c  GLUE MNLI works on Deberta (norabelrose, Apr 14, 2023)
1e9ce06  Move pseudo AUROC stuff to CcsReporter (norabelrose, Apr 14, 2023)
35a8f34  Make 'datasets' and 'label_columns' config options more opinionated (norabelrose, Apr 14, 2023)
615bbb1  tiny spacing change (norabelrose, Apr 14, 2023)
f021404  Allow for toggling CV (norabelrose, Apr 14, 2023)
f6629ec  Merge branch 'multi-ds-eval' into multiclass (norabelrose, Apr 14, 2023)
99f01c3  Remove duplicate dbpedia template (norabelrose, Apr 14, 2023)
f415f8d  Merge branch 'main' into multiclass (norabelrose, Apr 14, 2023)
d16c96b  Training on datasets with different numbers of classes now works (norabelrose, Apr 15, 2023)
044774e  Efficient bootstrap CIs for AUROCs (norabelrose, Apr 15, 2023)
a7f1ea0  Fix CCS smoke test failure (norabelrose, Apr 15, 2023)
3abeb60  Update extraction.py (lauritowal, Apr 16, 2023)
1e4a6b9  Merge branch 'main' into roc_auc (lauritowal, Apr 16, 2023)
4c60061  Update extraction.py (lauritowal, Apr 16, 2023)
Files changed
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
args: ["--skip=*.yaml"]
args: ["-L fpr", "--skip=*.yaml"]
7 changes: 4 additions & 3 deletions elk/evaluation/evaluate.py
@@ -85,9 +85,10 @@ def evaluate_reporter(
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

-lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
-
-stats_row["lr_auroc"] = lr_auroc
+lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
+stats_row["lr_auroc"] = lr_auroc_res.estimate
+stats_row["lr_auroc_lower"] = lr_auroc_res.lower
+stats_row["lr_auroc_upper"] = lr_auroc_res.upper
stats_row["lr_acc"] = lr_acc

row_buf.append(stats_row)
122 changes: 122 additions & 0 deletions elk/metrics.py
@@ -1,3 +1,6 @@
from typing import NamedTuple

import torch
from torch import Tensor


@@ -34,3 +37,122 @@ def accuracy(y_true: Tensor, y_pred: Tensor) -> float:
hard_preds = y_pred.argmax(-1)

return hard_preds.cpu().eq(y_true.cpu()).float().mean().item()


class RocAucResult(NamedTuple):
"""Named tuple for storing ROC AUC results."""

estimate: float
"""Point estimate of the ROC AUC computed on this sample."""
lower: float
"""Lower bound of the bootstrap confidence interval."""
upper: float
"""Upper bound of the bootstrap confidence interval."""


def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor:
"""Area under the receiver operating characteristic curve (ROC AUC).

Unlike scikit-learn's implementation, this function supports batched inputs of
shape `(N, n)` where `N` is the number of datasets and `n` is the number of samples
within each dataset. This is primarily useful for efficiently computing bootstrap
confidence intervals.

Args:
y_true: Ground truth tensor of shape `(N,)` or `(N, n)`.
y_pred: Predicted score tensor of shape `(N,)` or `(N, n)`.

Returns:
Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D,
a tensor of shape (N,) containing the ROC AUC for each dataset.
"""
if y_true.shape != y_pred.shape:
raise ValueError(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() not in (1, 2):
raise ValueError("y_true and y_pred should be 1D or 2D tensors")

# Sort y_pred in descending order and get indices
indices = y_pred.argsort(descending=True, dim=-1)

# Reorder y_true based on sorted y_pred indices
y_true_sorted = y_true.gather(-1, indices)

# Calculate number of positive and negative samples
num_positives = y_true.sum(dim=-1)
num_negatives = y_true.shape[-1] - num_positives

# Calculate cumulative sum of true positive counts (TPs)
tps = torch.cumsum(y_true_sorted, dim=-1)

# Calculate cumulative sum of false positive counts (FPs)
fps = torch.cumsum(1 - y_true_sorted, dim=-1)

# Calculate true positive rate (TPR) and false positive rate (FPR)
tpr = tps / num_positives.view(-1, 1)
fpr = fps / num_negatives.view(-1, 1)

# Calculate differences between consecutive FPR values (widths of trapezoids)
fpr_diffs = torch.cat(
[fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1
)

# Calculate area under the ROC curve for each dataset using trapezoidal rule
return torch.sum(tpr * fpr_diffs, dim=-1).squeeze()


def roc_auc_ci(
y_true: Tensor,
y_pred: Tensor,
*,
num_samples: int = 1000,
level: float = 0.95,
seed: int = 42,
) -> RocAucResult:
"""Bootstrap confidence interval for the ROC AUC.

Args:
y_true: Ground truth tensor of shape `(N,)`.
y_pred: Predicted score tensor of shape `(N,)`.
num_samples (int): Number of bootstrap samples to use.
level (float): Confidence level of the confidence interval.
seed (int): Random seed for reproducibility.

Returns:
RocAucResult: Named tuple containing the lower and upper bounds of the
confidence interval, along with the point estimate.
"""
if y_true.shape != y_pred.shape:
raise ValueError(
f"y_true and y_pred should have the same shape; "
f"got {y_true.shape} and {y_pred.shape}"
)
if y_true.dim() != 1:
raise ValueError("y_true and y_pred should be 1D tensors")

device = y_true.device
N = y_true.shape[0]

# Generate random indices for bootstrap samples (shape: [num_bootstraps, N])
rng = torch.Generator(device=device).manual_seed(seed)
indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng)

# Create bootstrap samples of true labels and predicted probabilities
y_true_bootstraps = y_true[indices]
y_pred_bootstraps = y_pred[indices]

# Compute ROC AUC scores for bootstrap samples
bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps)

# Calculate the lower and upper bounds of the confidence interval. We use
# nanquantile instead of quantile because some bootstrap samples may contain
# only one class, which makes their AUROC NaN.
alpha = (1 - level) / 2
q = y_pred.new_tensor([alpha, 1 - alpha])
lower, upper = bootstrap_aucs.nanquantile(q).tolist()

# Compute the point estimate
estimate = roc_auc(y_true, y_pred).item()
return RocAucResult(estimate, lower, upper)
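As a usage sketch of the new API (assuming the `elk.metrics` module layout above; scikit-learn is pulled in only as a reference implementation, matching its dev-only status after this PR): with `level=0.95`, `alpha = 0.025`, so the reported interval is the [2.5%, 97.5%] quantile band of the 1000 bootstrap AUROCs.

import torch
from sklearn.metrics import roc_auc_score  # dev-only dependency, used here just as a cross-check

from elk.metrics import roc_auc, roc_auc_ci

torch.manual_seed(0)
y_true = torch.randint(0, 2, (2000,))        # binary ground-truth labels
y_pred = torch.randn(2000) + y_true.float()  # scores correlated with the labels

# The batched implementation should match scikit-learn's point estimate
ours = roc_auc(y_true, y_pred).item()
ref = roc_auc_score(y_true.numpy(), y_pred.numpy())
assert abs(ours - ref) < 1e-4

# One vectorized roc_auc call over a (1000, 2000) batch of resampled datasets
result = roc_auc_ci(y_true, y_pred, num_samples=1000, level=0.95)
print(f"AUROC = {result.estimate:.3f} [{result.lower:.3f}, {result.upper:.3f}]")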
4 changes: 2 additions & 2 deletions elk/run.py
@@ -86,7 +86,7 @@ def get_device(self, devices, world_size: int) -> str:

def prepare_data(
self, device: str, layer: int, split_type: Literal["train", "val"]
-) -> dict[str, tuple[Tensor, Tensor, np.ndarray | None]]:
+) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
"""Prepare data for the specified layer and split type."""
out = {}

@@ -98,7 +98,7 @@ def prepare_data(
labels = assert_type(Tensor, split["label"])
val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))

with split.formatted_as("numpy"):
with split.formatted_as("torch", device=device):
has_preds = "model_preds" in split.features
lm_preds = split["model_preds"] if has_preds else None

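Note the knock-on effect here: with the split formatted as "torch" on the target device, `model_preds` comes back as a `Tensor` on that device rather than a NumPy array, which is exactly what the updated `prepare_data` annotation reflects.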
6 changes: 3 additions & 3 deletions elk/training/ccs_reporter.py
@@ -7,10 +7,10 @@

import torch
import torch.nn as nn
-from sklearn.metrics import roc_auc_score
from torch import Tensor
from torch.nn.functional import binary_cross_entropy as bce

+from ..metrics import roc_auc
from ..parsing import parse_loss
from ..utils.typing import assert_type
from .classifier import Classifier
@@ -175,8 +175,8 @@ def check_separability(
pseudo_preds = pseudo_clf(
# b v d -> (b v) d
torch.cat([val_x0, val_x1]).flatten(0, 1)
-)
-return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu()))
+).squeeze(-1)
+return roc_auc(pseudo_val_labels, pseudo_preds).item()

def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
loss = sum(
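The switch from scikit-learn's `roc_auc_score` to the in-house `roc_auc` drops the forced `.cpu()` copies, so the pseudo-AUROC sanity check can stay on whatever device the tensors already live on.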
15 changes: 9 additions & 6 deletions elk/training/reporter.py
@@ -9,11 +9,10 @@
import torch.nn as nn
from einops import rearrange, repeat
from simple_parsing.helpers import Serializable
-from sklearn.metrics import roc_auc_score
from torch import Tensor

from ..calibration import CalibrationError
-from ..metrics import accuracy, to_one_hot
+from ..metrics import accuracy, roc_auc_ci, to_one_hot


class EvalResult(NamedTuple):
@@ -23,9 +22,12 @@ class EvalResult(NamedTuple):
which contains the loss, accuracy, calibrated accuracy, and AUROC.
"""

+auroc: float
+auroc_lower: float
+auroc_upper: float

acc: float
cal_acc: float
-auroc: float
ece: float


@@ -117,12 +119,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
raw_preds = to_one_hot(logits.argmax(dim=-1), c).long()
Y = to_one_hot(Y, c).long().flatten()

-auroc = roc_auc_score(Y.cpu(), logits.cpu().flatten())
raw_acc = accuracy(Y, raw_preds.flatten())

+auroc_result = roc_auc_ci(Y, logits.flatten())
return EvalResult(
+auroc=auroc_result.estimate,
+auroc_lower=auroc_result.lower,
+auroc_upper=auroc_result.upper,
acc=float(raw_acc),
cal_acc=cal_acc,
-auroc=float(auroc),
ece=cal_err,
)
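Since callers splat the result with `_asdict()` (see train.py below), the new `auroc_lower` and `auroc_upper` fields become extra columns in the output rows with no further plumbing.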
13 changes: 6 additions & 7 deletions elk/training/supervised.py
@@ -1,29 +1,28 @@
import torch
from einops import rearrange, repeat
-from sklearn.metrics import roc_auc_score
from torch import Tensor

-from ..metrics import accuracy, to_one_hot
+from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot
from ..utils import assert_type
from .classifier import Classifier


def evaluate_supervised(
lr_model: Classifier, val_h: Tensor, val_labels: Tensor
-) -> tuple[float, float]:
-(n, v, k, d) = val_h.shape
+) -> tuple[RocAucResult, float]:
+(_, v, k, _) = val_h.shape

with torch.no_grad():
-logits = rearrange(lr_model(val_h).cpu().squeeze(), "n v k -> (n v) k")
+logits = rearrange(lr_model(val_h).squeeze(), "n v k -> (n v) k")
raw_preds = to_one_hot(logits.argmax(dim=-1), k).long()

labels = repeat(val_labels, "n -> (n v)", v=v)
labels = to_one_hot(labels, k).flatten()

lr_acc = accuracy(labels, raw_preds.flatten())
-lr_auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten())
+lr_auroc = roc_auc_ci(labels, logits.flatten())

-return assert_type(float, lr_auroc), assert_type(float, lr_acc)
+return lr_auroc, assert_type(float, lr_acc)


def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier:
38 changes: 18 additions & 20 deletions elk/training/train.py
@@ -9,10 +9,9 @@
import torch
from einops import rearrange, repeat
from simple_parsing import Serializable, field, subgroups
-from sklearn.metrics import roc_auc_score

from ..extraction.extraction import Extract
-from ..metrics import accuracy, to_one_hot
+from ..metrics import accuracy, roc_auc_ci, to_one_hot
from ..run import Run
from ..training.supervised import evaluate_supervised, train_supervised
from ..utils import select_usable_devices
@@ -148,37 +147,36 @@ def train_reporter(
row_buf = []
for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
val_result = reporter.score(val_gt, val_h)

-if val_lm_preds is not None:
-(_, v, k, _) = val_h.shape
-
-val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu()
-val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
-val_lm_auroc = roc_auc_score(
-to_one_hot(val_gt_cpu, k).long().flatten(), val_lm_preds.flatten()
-)
-
-val_lm_acc = accuracy(val_gt_cpu, torch.from_numpy(val_lm_preds))
-else:
-val_lm_auroc = None
-val_lm_acc = None

row = pd.Series(
{
"dataset": ds_name,
"layer": layer,
"pseudo_auroc": pseudo_auroc,
"train_loss": train_loss,
**val_result._asdict(),
"lm_auroc": val_lm_auroc,
"lm_acc": val_lm_acc,
}
)

+if val_lm_preds is not None:
+(_, v, k, _) = val_h.shape
+
+val_gt_rep = repeat(val_gt, "n -> (n v)", v=v)
+val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
+val_lm_auroc_res = roc_auc_ci(
+to_one_hot(val_gt_rep, k).long().flatten(), val_lm_preds.flatten()
+)
+row["lm_auroc"] = val_lm_auroc_res.estimate
+row["lm_auroc_lower"] = val_lm_auroc_res.lower
+row["lm_auroc_upper"] = val_lm_auroc_res.upper
+row["lm_acc"] = accuracy(val_gt_rep, val_lm_preds)

if lr_model is not None:
row["lr_auroc"], row["lr_acc"] = evaluate_supervised(
lr_auroc_res, row["lr_acc"] = evaluate_supervised(
lr_model, val_h, val_gt
)
row["lr_auroc"] = lr_auroc_res.estimate
row["lr_auroc_lower"] = lr_auroc_res.lower
row["lr_auroc_upper"] = lr_auroc_res.upper

row_buf.append(row)

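Two things change shape here: the LM metrics are computed with `roc_auc_ci` and attached after the base row is built, so the `lm_auroc*` and `lm_acc` columns exist only when `model_preds` are present instead of being written as `None`; and because `val_lm_preds` is now a tensor (see the run.py change above), the `torch.from_numpy` call is gone.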
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -21,8 +21,6 @@ dependencies = [
"pandas",
# Basically any version should work as long as it supports the user's CUDA version
"pynvml",
-# Doesn't really matter but before 1.0.0 there might be weird breaking changes
-"scikit-learn>=1.0.0",
# Needed for certain HF tokenizers
"sentencepiece==0.1.97",
# We upstreamed bugfixes for Literal types in 0.1.1
@@ -43,7 +41,8 @@ dev = [
"hypothesis",
"pre-commit",
"pytest",
"pyright"
"pyright",
"scikit-learn",
]

[project.scripts]
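Net effect on dependencies: scikit-learn is no longer required at runtime, since all AUROC computation now lives in `elk/metrics.py`; it remains in the `dev` extra for tests and cross-checks.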