Allow for arbitrary hyperparameter selection for sweep #198

Closed · wants to merge 6 commits
Changes from 1 commit
Refactor metrics into evaluate_preds
norabelrose committed Apr 17, 2023
commit 6eb18a493512bb33851e6b3fc116c1a84a21a17a
25 changes: 10 additions & 15 deletions elk/evaluation/evaluate.py
@@ -9,9 +9,9 @@

from ..extraction.extraction import Extract
from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..run import Run
from ..training import Reporter
from ..training.supervised import evaluate_supervised
from ..utils import select_usable_devices


@@ -70,30 +70,25 @@ def evaluate_reporter(

row_buf = []
for ds_name, (val_h, val_gt, _) in val_output.items():
val_result = reporter.score(val_gt, val_h)
val_result = evaluate_preds(val_gt, reporter(val_h))

stats_row = pd.Series(
{
"dataset": ds_name,
"layer": layer,
**val_result._asdict(),
}
)
stats_row = {
"dataset": ds_name,
"layer": layer,
**val_result.to_dict(),
}

lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
stats_row["lr_auroc"] = lr_auroc_res.estimate
stats_row["lr_auroc_lower"] = lr_auroc_res.lower
stats_row["lr_auroc_upper"] = lr_auroc_res.upper
stats_row["lr_acc"] = lr_acc
lr_result = evaluate_preds(val_gt, lr_model(val_h))
stats_row.update(lr_result.to_dict(prefix="lr_"))

row_buf.append(stats_row)

return pd.DataFrame(row_buf)
return pd.DataFrame.from_records(row_buf)

def evaluate(self):
"""Evaluate the reporter on all layers."""
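For reference, EvalResult.to_dict (added in elk/metrics/eval.py below) flattens each nested result into prefixed scalar columns. A sketch of the row keys this produces in the binary case, with placeholder values rather than output from a real run:

# Sketch of the flattened columns produced by EvalResult.to_dict(); values are
# placeholders and the dataset/layer entries are purely illustrative.
stats_row = {
    "dataset": "example_dataset",
    "layer": 12,
    "acc_estimate": 0.0, "acc_lower": 0.0, "acc_upper": 0.0,
    "cal_acc_estimate": 0.0, "cal_acc_lower": 0.0, "cal_acc_upper": 0.0,
    "ece": 0.0,
    "auroc_estimate": 0.0, "auroc_lower": 0.0, "auroc_upper": 0.0,
}
# With a supervised baseline, the same keys repeat with an "lr_" prefix,
# e.g. "lr_acc_estimate", via lr_result.to_dict(prefix="lr_").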
13 changes: 13 additions & 0 deletions elk/metrics/__init__.py
@@ -0,0 +1,13 @@
from .calibration import CalibrationError, CalibrationEstimate
from .eval import EvalResult, evaluate_preds, to_one_hot
from .roc_auc import RocAucResult, roc_auc_ci

__all__ = [
"CalibrationError",
"CalibrationEstimate",
"EvalResult",
"evaluate_preds",
"roc_auc_ci",
"to_one_hot",
"RocAucResult",
]
48 changes: 48 additions & 0 deletions elk/metrics/accuracy.py
@@ -0,0 +1,48 @@
from dataclasses import dataclass

import torch
from torch import Tensor
from torch.distributions.normal import Normal


@dataclass(frozen=True)
class AccuracyResult:
"""Accuracy point estimate and confidence interval."""

estimate: float
"""Point estimate of the accuracy computed on this sample."""
lower: float
"""Lower bound of the confidence interval."""
upper: float
"""Upper bound of the confidence interval."""


def accuracy_ci(
y_true: Tensor, y_pred: Tensor, *, level: float = 0.95
) -> AccuracyResult:
"""
Compute the accuracy of a classifier and its confidence interval.

Args:
y_true: Ground truth tensor of shape (N,).
y_pred: Predicted class tensor of shape (N,).
level: Confidence level of the interval. Defaults to 0.95.

Returns:
AccuracyResult: Point estimate of the accuracy and its confidence interval.
"""
# We expect the inputs to be integers
assert not torch.is_floating_point(y_pred) and not torch.is_floating_point(y_true)
assert y_true.shape == y_pred.shape

# Point estimate of the accuracy
acc = y_pred.eq(y_true).float().mean()

# Compute the CI quantiles
alpha = (1 - level) / 2
q = acc.new_tensor([alpha, 1 - alpha])

# Normal approximation to the binomial distribution
stderr = (acc * (1 - acc) / len(y_true)) ** 0.5
lower, upper = Normal(acc, stderr).icdf(q).tolist()

return AccuracyResult(acc.item(), lower, upper)
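A brief usage sketch of accuracy_ci (not part of the diff); the import path assumes the new elk/metrics package layout introduced in this commit:

import torch
from elk.metrics.accuracy import accuracy_ci

y_true = torch.tensor([0, 1, 1, 0, 1, 0, 1, 1])
y_pred = torch.tensor([0, 1, 0, 0, 1, 1, 1, 1])  # integer class predictions

# Normal (Wald) approximation to the binomial CI around the point estimate.
result = accuracy_ci(y_true, y_pred, level=0.95)
print(result.estimate, result.lower, result.upper)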
6 changes: 3 additions & 3 deletions elk/calibration.py → elk/metrics/calibration.py
@@ -1,12 +1,12 @@
import warnings
from dataclasses import dataclass, field
from typing import NamedTuple

import torch
from torch import Tensor


class CalibrationEstimate(NamedTuple):
@dataclass(frozen=True)
class CalibrationEstimate:
ece: float
num_bins: int

@@ -82,7 +82,7 @@ def compute(self, p: int = 2) -> CalibrationEstimate:
# Split into (nearly) equal mass bins. They won't be exactly equal, so we
# still weight the bins by their size.
conf_bins = pred_probs.tensor_split(b_star)
w = torch.tensor([len(c) / n for c in conf_bins])
w = pred_probs.new_tensor([len(c) / n for c in conf_bins])

# See the definition of ECE_sweep in Equation 8 of Roelofs et al. (2020)
mean_confs = torch.stack([c.mean() for c in conf_bins])
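The switch from torch.tensor to pred_probs.new_tensor presumably keeps the bin weights on the same device and dtype as the predictions, avoiding a device mismatch when the probabilities live on the GPU. A minimal illustration of that behavior:

import torch

pred_probs = torch.rand(10, device="cuda" if torch.cuda.is_available() else "cpu")
w = pred_probs.new_tensor([0.5, 0.5])
assert w.device == pred_probs.device and w.dtype == pred_probs.dtype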
89 changes: 89 additions & 0 deletions elk/metrics/eval.py
@@ -0,0 +1,89 @@
from dataclasses import asdict, dataclass

import torch
from einops import rearrange, repeat
from torch import Tensor

from .accuracy import AccuracyResult, accuracy_ci
from .calibration import CalibrationError, CalibrationEstimate
from .roc_auc import RocAucResult, roc_auc_ci


@dataclass(frozen=True)
class EvalResult:
"""The result of evaluating a classifier."""

accuracy: AccuracyResult
"""Top 1 accuracy, implemented for both binary and multi-class classification."""
cal_accuracy: AccuracyResult | None
"""Calibrated accuracy, only implemented for binary classification."""
calibration: CalibrationEstimate | None
"""Expected calibration error, only implemented for binary classification."""
roc_auc: RocAucResult
"""Area under the ROC curve. For multi-class classification, each class is treated
as a one-vs-rest binary classification problem."""

def to_dict(self, prefix: str = "") -> dict[str, float]:
"""Convert the result to a dictionary."""
acc_dict = {f"{prefix}acc_{k}": v for k, v in asdict(self.accuracy).items()}
cal_acc_dict = (
{f"{prefix}cal_acc_{k}": v for k, v in asdict(self.cal_accuracy).items()}
if self.cal_accuracy is not None
else {}
)
cal_dict = (
{f"{prefix}ece": self.calibration.ece}
if self.calibration is not None
else {}
)
auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()}
return {**acc_dict, **cal_acc_dict, **cal_dict, **auroc_dict}


def evaluate_preds(y_true: Tensor, y_pred: Tensor) -> EvalResult:
"""
Evaluate the performance of a classification model.

Args:
y_true: Ground truth tensor of shape (N,).
y_pred: Predicted logits of shape (N, variants, n_classes).

Returns:
EvalResult: The accuracy, calibrated accuracy, calibration error, and AUROC,
each with a confidence interval where applicable.
"""
(n, v, c) = y_pred.shape
assert y_true.shape == (n,)

y_pred = rearrange(y_pred, "n v c -> (n v) c")
y_true = repeat(y_true, "n -> (n v)", v=v)

acc = accuracy_ci(y_true, y_pred.argmax(dim=-1))
auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(), y_pred.flatten())
cal_acc = None
cal_err = None

if c == 2:
pos_probs = y_pred[..., 1].flatten().sigmoid()
cal_err = CalibrationError().update(y_true, pos_probs).compute()

# Calibrated accuracy
cal_thresh = pos_probs.float().quantile(y_true.float().mean())
cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
cal_acc = accuracy_ci(y_true, cal_preds)

return EvalResult(acc, cal_acc, cal_err, auroc)


def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
"""
Convert a tensor of class labels to a one-hot representation.

Args:
labels (Tensor): A tensor of class labels of shape (N,).
n_classes (int): The total number of unique classes.

Returns:
Tensor: A one-hot representation tensor of shape (N, n_classes).
"""
one_hot_labels = labels.new_zeros(labels.size(0), n_classes)
return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1)
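A usage sketch of evaluate_preds on random binary logits (not part of the diff); shapes follow the docstring, and the import is the one exposed by elk/metrics/__init__.py above:

import torch
from elk.metrics import evaluate_preds

n, v, c = 64, 4, 2
y_true = torch.randint(0, c, (n,))   # integer labels of shape (N,)
y_pred = torch.randn(n, v, c)        # logits of shape (N, variants, n_classes)

result = evaluate_preds(y_true, y_pred)
print(result.accuracy.estimate)       # top-1 accuracy point estimate
print(result.roc_auc.estimate)        # AUROC point estimate
print(result.to_dict(prefix="lr_"))   # flattened, prefixed columns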
42 changes: 4 additions & 38 deletions elk/metrics.py → elk/metrics/roc_auc.py
@@ -1,46 +1,12 @@
from typing import NamedTuple
from dataclasses import dataclass

import torch
from torch import Tensor


def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
"""
Convert a tensor of class labels to a one-hot representation.

Args:
labels (Tensor): A tensor of class labels of shape (N,).
n_classes (int): The total number of unique classes.

Returns:
Tensor: A one-hot representation tensor of shape (N, n_classes).
"""
one_hot_labels = labels.new_zeros(labels.size(0), n_classes)
return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1)


def accuracy(y_true: Tensor, y_pred: Tensor) -> float:
"""
Compute the accuracy of a classification model.

Args:
y_true: Ground truth tensor of shape (N,).
y_pred: Predicted class tensor of shape (N,) or (N, n_classes).

Returns:
float: Accuracy of the model.
"""
# Check if binary or multi-class classification
if len(y_pred.shape) == 1:
hard_preds = y_pred > 0.5
else:
hard_preds = y_pred.argmax(-1)

return hard_preds.cpu().eq(y_true.cpu()).float().mean().item()


class RocAucResult(NamedTuple):
"""Named tuple for storing ROC AUC results."""
@dataclass(frozen=True)
class RocAucResult:
"""Dataclass for storing ROC AUC results."""

estimate: float
"""Point estimate of the ROC AUC computed on this sample."""
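One likely reason for moving from NamedTuple to a frozen dataclass (an inference from the diff, not stated by the author): EvalResult.to_dict relies on dataclasses.asdict, which only accepts dataclass instances. A quick sketch of the difference:

from dataclasses import asdict, dataclass
from typing import NamedTuple

@dataclass(frozen=True)
class DataclassResult:
    estimate: float

class TupleResult(NamedTuple):
    estimate: float

print(asdict(DataclassResult(0.9)))   # {'estimate': 0.9}
try:
    asdict(TupleResult(0.9))          # raises TypeError: asdict() should be
except TypeError:                     # called on dataclass instances
    pass
# NamedTuple instances expose _asdict() instead, which is what the old
# pd.Series(**val_result._asdict()) call in evaluate.py relied on.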
4 changes: 2 additions & 2 deletions elk/run.py
@@ -103,7 +103,7 @@ def prepare_data(
lm_preds = split["model_preds"] if has_preds else None

ds_name = get_dataset_name(ds)
out[ds_name] = (val_h, labels, lm_preds)
out[ds_name] = (val_h, labels.to(val_h.device), lm_preds)

return out

@@ -148,6 +148,6 @@ def apply_to_layers(
# Make sure the CSV is written even if we crash or get interrupted
if df_buf:
df = pd.concat(df_buf).sort_values(by="layer")
df.to_csv(f, index=False)
df.round(4).to_csv(f, index=False)
if self.cfg.debug:
save_debug_log(self.datasets, self.out_dir)
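Two small behavior notes on the elk/run.py changes: labels are moved onto the same device as the hidden states before evaluation, and the results CSV is rounded to four decimal places before writing. A tiny sketch of the rounding, using a hypothetical column name:

import pandas as pd

df = pd.DataFrame({"layer": [0], "acc_estimate": [0.123456789]})
df.round(4).to_csv("eval.csv", index=False)  # the value is written as 0.1235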
2 changes: 1 addition & 1 deletion elk/training/classifier.py
@@ -47,7 +47,7 @@ def __init__(
self.linear.weight.data.zero_()

def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
return self.linear(x).squeeze(-1)

@torch.enable_grad()
def fit(
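The squeeze(-1) drops the trailing singleton dimension when the classifier has a single output unit, so binary logits come out with shape (N,) instead of (N, 1). A minimal sketch of that shape change, assuming one output feature:

import torch
from torch import nn

linear = nn.Linear(16, 1)
x = torch.randn(8, 16)
print(linear(x).shape)              # torch.Size([8, 1])
print(linear(x).squeeze(-1).shape)  # torch.Size([8])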
74 changes: 1 addition & 73 deletions elk/training/reporter.py
@@ -3,33 +3,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, NamedTuple, Optional
from typing import Literal, Optional

import torch
import torch.nn as nn
from einops import rearrange, repeat
from simple_parsing.helpers import Serializable
from torch import Tensor

from ..calibration import CalibrationError
from ..metrics import accuracy, roc_auc_ci, to_one_hot


class EvalResult(NamedTuple):
"""The result of evaluating a reporter on a dataset.

The `.score()` function of a reporter returns an instance of this class,
which contains the loss, accuracy, calibrated accuracy, and AUROC.
"""

auroc: float
auroc_lower: float
auroc_upper: float

acc: float
cal_acc: float
ece: float


@dataclass
class ReporterConfig(Serializable):
@@ -83,55 +83,3 @@ def fit(
labels: Optional[Tensor] = None,
) -> float:
...

@torch.no_grad()
def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
"""Score the probe on the contrast set `hiddens`.

Args:
labels: The labels of the contrast pair.
hiddens: Contrast set of shape [n, v, k, d].

Returns:
an instance of EvalResult containing the loss, accuracy, calibrated
accuracy, and AUROC of the probe on `contrast_set`.
Accuracy: top-1 accuracy averaged over questions and variants.
Calibrated accuracy: top-1 accuracy averaged over questions and
variants, calibrated so that x% of the predictions are `True`,
where x is the proportion of examples with ground truth label `True`.
AUROC: averaged over the n * v * c binary questions
ECE: Expected Calibration Error
"""
logits = self(hiddens)
(_, v, c) = logits.shape

# makes `num_variants` copies of each label
logits = rearrange(logits, "n v c -> (n v) c")
Y = repeat(labels, "n -> (n v)", v=v).float()

if c == 2:
pos_probs = logits[..., 1].flatten().sigmoid()
cal_err = CalibrationError().update(Y.cpu(), pos_probs.cpu()).compute().ece

# Calibrated accuracy
cal_thresh = pos_probs.float().quantile(labels.float().mean())
cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
cal_acc = cal_preds.flatten().eq(Y).float().mean().item()
else:
# TODO: Implement calibration error for k > 2?
cal_acc = 0.0
cal_err = 0.0

Y_one_hot = to_one_hot(Y, c).long().flatten()
auroc_result = roc_auc_ci(Y_one_hot, logits.flatten())

raw_preds = logits.argmax(dim=-1).long()
raw_acc = accuracy(Y, raw_preds.flatten())
return EvalResult(
auroc=auroc_result.estimate,
auroc_lower=auroc_result.lower,
auroc_upper=auroc_result.upper,
acc=float(raw_acc),
cal_acc=cal_acc,
ece=cal_err,
)