Blazing fast bootstrap stderrs for AUROC #190

Merged
merged 58 commits into from
Apr 16, 2023

Changes from 1 commit
Commits (58)
d292c7c
LM output evaluation for autoregressive models
norabelrose Apr 4, 2023
7ed5ccd
move to own baseline file
lauritowal Apr 4, 2023
ba1d3b2
cleanup
lauritowal Apr 4, 2023
a20d4ca
Support encoder-decoder model LM output
norabelrose Apr 5, 2023
088758e
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 5, 2023
77d7418
isort
norabelrose Apr 5, 2023
5bf63f4
Bug fixes
norabelrose Apr 5, 2023
819cfed
Merge branch 'main' into lm-output
norabelrose Apr 5, 2023
d3d9a8d
Merge branch 'main' into lm-output
norabelrose Apr 5, 2023
b89e23c
Remove test_log_csv_elements
norabelrose Apr 5, 2023
9aef842
Remove Python 3.9 support
norabelrose Apr 5, 2023
0851d4f
Add Pandas to pyproject.toml
norabelrose Apr 5, 2023
207a375
add code (contains still same device cuda error)
lauritowal Apr 5, 2023
e7efcce
fix multiple cuda error, save evals to right folder + cleanup
lauritowal Apr 7, 2023
b5fa54c
Merge branch 'main' into eval_lr
lauritowal Apr 7, 2023
4f8bdc5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
9ca72ba
Fix bug noticed by Waree
norabelrose Apr 7, 2023
d7e4893
Merge remote-tracking branch 'origin/eval_lr' into lm-output
norabelrose Apr 7, 2023
bcdca8a
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 7, 2023
713a251
Add sanity check to load_prompts and refactor binarize
norabelrose Apr 7, 2023
0c35bc7
Changing a ton of stuff
norabelrose Apr 8, 2023
f6a762a
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 10, 2023
f547744
Revert changes to binarize
norabelrose Apr 10, 2023
ab1909f
Stupid prompt_counter bug
norabelrose Apr 10, 2023
f58290f
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 10, 2023
f912ee6
Remove stupid second set_start_method call
norabelrose Apr 10, 2023
606dcad
Merge remote-tracking branch 'origin/lm-output' into multiclass
norabelrose Apr 10, 2023
0038792
Merge remote-tracking branch 'origin/main' into multiclass
norabelrose Apr 10, 2023
83b480b
Fix bugs in binary case
norabelrose Apr 11, 2023
3e66262
Various little refactors
norabelrose Apr 11, 2023
a8c21a6
Remove .predict and .predict_prob on Reporter; trying to get SciQ to …
norabelrose Apr 11, 2023
5f478b1
Bugfix for Reporter.score on binary tasks
norabelrose Apr 11, 2023
97b26ac
Fix bug where cached hidden states aren’t used when num_gpus is diffe…
norabelrose Apr 12, 2023
11fda87
Actually works now
norabelrose Apr 12, 2023
da4c72f
Refactor handling of multiple datasets
norabelrose Apr 13, 2023
e1675f7
Various fixes
norabelrose Apr 13, 2023
8cc325b
Merge remote-tracking branch 'origin/main' into multi-ds-eval
norabelrose Apr 13, 2023
14987e1
Fix math tests
norabelrose Apr 13, 2023
88683fa
Fix smoke tests
norabelrose Apr 13, 2023
a6c382e
All tests working ostensibly
norabelrose Apr 13, 2023
ecc53cb
Make CCS normalization customizable
norabelrose Apr 13, 2023
18c7f4c
log each dataset individually
AlexTMallen Apr 13, 2023
94a900c
Merge branch 'multi-ds-eval' into multiclass
norabelrose Apr 13, 2023
5173649
Fix label_column bug
norabelrose Apr 13, 2023
3e6c39c
GLUE MNLI works on Deberta
norabelrose Apr 14, 2023
1e9ce06
Move pseudo AUROC stuff to CcsReporter
norabelrose Apr 14, 2023
35a8f34
Make 'datasets' and 'label_columns' config options more opinionated
norabelrose Apr 14, 2023
615bbb1
tiny spacing change
norabelrose Apr 14, 2023
f021404
Allow for toggling CV
norabelrose Apr 14, 2023
f6629ec
Merge branch 'multi-ds-eval' into multiclass
norabelrose Apr 14, 2023
99f01c3
Remove duplicate dbpedia template
norabelrose Apr 14, 2023
f415f8d
Merge branch 'main' into multiclass
norabelrose Apr 14, 2023
d16c96b
Training on datasets with different numbers of classes now works
norabelrose Apr 15, 2023
044774e
Efficient bootstrap CIs for AUROCs
norabelrose Apr 15, 2023
a7f1ea0
Fix CCS smoke test failure
norabelrose Apr 15, 2023
3abeb60
Update extraction.py
lauritowal Apr 16, 2023
1e4a6b9
Merge branch 'main' into roc_auc
lauritowal Apr 16, 2023
4c60061
Update extraction.py
lauritowal Apr 16, 2023
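
Of the 58 commits, 044774e ("Efficient bootstrap CIs for AUROCs") is the one behind the PR title. A minimal sketch of the idea — draw every bootstrap resample at once and compute each replicate's AUROC from ranks, entirely in torch — follows; the function name, defaults, and tie handling are illustrative assumptions, not the PR's actual API.

import torch

def bootstrap_auroc_ci(
    labels: torch.Tensor,   # (n,) tensor of {0, 1} ground truth
    scores: torch.Tensor,   # (n,) real-valued classifier scores
    num_samples: int = 1000,
    level: float = 0.95,
    seed: int = 42,
) -> tuple[float, float]:
    """Hypothetical vectorized bootstrap CI for the AUROC (no Python loop)."""
    n = labels.shape[0]
    rng = torch.Generator(device=labels.device).manual_seed(seed)

    # All bootstrap index sets at once: (num_samples, n), sampled with replacement.
    idx = torch.randint(0, n, (num_samples, n), generator=rng, device=labels.device)
    y = labels[idx].float()
    s = scores[idx]

    # Per-replicate AUROC via the Mann-Whitney U statistic. argsort-of-argsort
    # yields ordinal ranks, so ties are broken arbitrarily in this sketch.
    ranks = s.argsort(dim=1).argsort(dim=1).float() + 1.0
    n_pos = y.sum(dim=1)
    n_neg = n - n_pos  # assumes both classes appear in every replicate
    u = (ranks * y).sum(dim=1) - n_pos * (n_pos + 1) / 2
    aurocs = u / (n_pos * n_neg)

    # The "stderrs" of the PR title would simply be aurocs.std(); here we
    # return the central confidence interval instead.
    alpha = 1.0 - level
    return float(aurocs.quantile(alpha / 2)), float(aurocs.quantile(1 - alpha / 2))

Because the resampling is a single (num_samples, n) gather and the ranking is a pair of batched sorts, this can run on a GPU in one shot, rather than looping num_samples calls to sklearn's roc_auc_score.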
Move pseudo AUROC stuff to CcsReporter
norabelrose committed Apr 14, 2023
commit 1e9ce06bd1ca50b345ec14b5321c1aea587fc42d
12 changes: 7 additions & 5 deletions elk/evaluation/evaluate.py
@@ -28,19 +28,21 @@ class Eval(Serializable):
`elk.training.preprocessing.normalize()` for details.
num_gpus: The number of GPUs to use. Defaults to -1, which means
"use all available GPUs".
skip_supervised: Whether to skip training the supervised classifier. Defaults to
False.
debug: When in debug mode, a useful log file is saved to the memorably-named
output directory. Defaults to False.
"""

data: Extract
source: str = field(positional=True)

concatenated_layer_offset: int = 0
debug: bool = False
out_dir: Path | None = None
num_gpus: int = -1
min_gpu_mem: int | None = None
skip_baseline: bool = False
concatenated_layer_offset: int = 0
num_gpus: int = -1
out_dir: Path | None = None
skip_supervised: bool = False

def execute(self):
datasets = self.data.prompts.datasets
@@ -86,7 +88,7 @@ def evaluate_reporter(
)

lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_baseline and lr_dir.exists():
if not self.cfg.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

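The skip_baseline → skip_supervised rename gates the same load-if-present logic shown in the hunk above. A condensed restatement of that pattern — the helper name and argument list here are hypothetical, not part of the diff:

from pathlib import Path

import torch

def maybe_load_lr_model(
    experiment_dir: Path, layer: int, skip_supervised: bool, device: str
):
    """Return the supervised logistic-regression baseline for one layer, or
    None if the user opted out or no model was saved for this layer."""
    path = experiment_dir / "lr_models" / f"layer_{layer}.pt"
    if skip_supervised or not path.exists():
        return None
    with open(path, "rb") as f:
        return torch.load(f, map_location=device).eval()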
75 changes: 61 additions & 14 deletions elk/training/ccs_reporter.py
@@ -7,11 +7,13 @@

import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch import Tensor
from torch.nn.functional import binary_cross_entropy as bce

from ..parsing import parse_loss
from ..utils.typing import assert_type
from .classifier import Classifier
from .losses import LOSSES
from .normalizer import Normalizer
from .reporter import Reporter, ReporterConfig
@@ -55,7 +57,6 @@ class CcsReporterConfig(ReporterConfig):
init: Literal["default", "pca", "spherical", "zero"] = "default"
loss: list[str] = field(default_factory=lambda: ["ccs"])
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
normalization: Literal["none", "meanonly", "full"] = "full"
num_layers: int = 1
pre_ln: bool = False
seed: int = 42
@@ -96,12 +97,8 @@ def __init__(

hidden_size = cfg.hidden_size or 4 * in_features // 3

self.neg_norm = Normalizer(
(in_features,), device=device, dtype=dtype, mode=cfg.normalization
)
self.pos_norm = Normalizer(
(in_features,), device=device, dtype=dtype, mode=cfg.normalization
)
self.neg_norm = Normalizer((in_features,), device=device, dtype=dtype)
self.pos_norm = Normalizer((in_features,), device=device, dtype=dtype)

self.probe = nn.Sequential(
nn.Linear(
@@ -131,6 +128,56 @@ def __init__(
)
)

def check_separability(
self,
train_pair: tuple[Tensor, Tensor],
val_pair: tuple[Tensor, Tensor],
) -> float:
"""Measure how linearly separable the pseudo-labels are for a contrast pair.

Args:
train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the
contrastive representations. Used for training the classifier.
val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the
contrastive representations. Used for evaluating the classifier.

Returns:
The AUROC of a linear classifier fit on the pseudo-labels.
"""
_x0, _x1 = train_pair
_val_x0, _val_x1 = val_pair

x0, x1 = self.neg_norm(_x0), self.pos_norm(_x1)
val_x0, val_x1 = self.neg_norm(_val_x0), self.pos_norm(_val_x1)

pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore
pseudo_train_labels = torch.cat(
[
x0.new_zeros(x0.shape[0]),
x0.new_ones(x0.shape[0]),
]
).repeat_interleave(
x0.shape[1]
) # make num_variants copies of each pseudo-label
pseudo_val_labels = torch.cat(
[
val_x0.new_zeros(val_x0.shape[0]),
val_x0.new_ones(val_x0.shape[0]),
]
).repeat_interleave(val_x0.shape[1])

pseudo_clf.fit(
# b v d -> (b v) d
torch.cat([x0, x1]).flatten(0, 1),
pseudo_train_labels,
)
with torch.no_grad():
pseudo_preds = pseudo_clf(
# b v d -> (b v) d
torch.cat([val_x0, val_x1]).flatten(0, 1)
)
return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu()))

def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
loss = sum(
LOSSES[name](logit0, logit1, coef)
@@ -175,9 +222,9 @@ def forward(self, x: Tensor) -> Tensor:
"""Return the raw score output of the probe on `x`."""
return self.probe(x).squeeze(-1)

def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor:
x_pos = self.pos_norm(x_pos)
def predict(self, x_neg: Tensor, x_pos: Tensor) -> Tensor:
x_neg = self.neg_norm(x_neg)
x_pos = self.pos_norm(x_pos)
return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid()))

def loss(
@@ -226,14 +273,14 @@ def loss(

def fit(
self,
x_pos: Tensor,
x_neg: Tensor,
x_pos: Tensor,
labels: Optional[Tensor] = None,
) -> float:
"""Fit the probe to the contrast pair (x0, x1).
"""Fit the probe to the contrast pair (neg, pos).

Args:
contrast_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the
contrast_pair: A tuple of tensors, (neg, pos), where x0 and x1 are the
contrastive representations.
labels: The labels of the contrast pair. Defaults to None.

@@ -280,8 +327,8 @@ def fit(

def train_loop_adam(
self,
x_pos: Tensor,
x_neg: Tensor,
x_pos: Tensor,
labels: Optional[Tensor] = None,
) -> float:
"""Adam train loop, returning the final loss. Modifies params in-place."""
@@ -302,8 +349,8 @@

def train_loop_lbfgs(
self,
x_pos: Tensor,
x_neg: Tensor,
x_pos: Tensor,
labels: Optional[Tensor] = None,
) -> float:
"""LBFGS train loop, returning the final loss. Modifies params in-place."""
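With the move, check_separability is an instance method that routes each half of the contrast pair through the reporter's own neg_norm/pos_norm before fitting the pseudo-label classifier, so it should be called on a reporter whose normalizers have been fit. A hedged usage sketch — the constructor call and tensor shapes (batch, variants, hidden dim) are assumptions for illustration:

import torch

cfg = CcsReporterConfig()  # defaults assumed to be valid
reporter = CcsReporter(cfg, in_features=1024, device="cuda")  # ctor args assumed

train_x0 = torch.randn(128, 4, 1024, device="cuda")  # (n, v, d) negatives
train_x1 = torch.randn(128, 4, 1024, device="cuda")  # (n, v, d) positives
val_x0 = torch.randn(64, 4, 1024, device="cuda")
val_x1 = torch.randn(64, 4, 1024, device="cuda")

reporter.fit(train_x0, train_x1)  # assumed to fit the probe and both Normalizers

pseudo_auroc = reporter.check_separability(
    train_pair=(train_x0, train_x1),
    val_pair=(val_x0, val_x1),
)
# Near 0.5: the two halves are not linearly separable after normalization.
# Near 1.0: a probe could "cheat" by reading off which half an example came
# from, so a high CCS AUROC should be treated with suspicion.
print(f"pseudo AUROC: {pseudo_auroc:.3f}")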
49 changes: 0 additions & 49 deletions elk/training/reporter.py
@@ -12,7 +12,6 @@
from torch import Tensor

from ..calibration import CalibrationError
from .classifier import Classifier


class EvalResult(NamedTuple):
@@ -68,54 +67,6 @@ class Reporter(nn.Module, ABC):
n: Tensor
config: ReporterConfig

@classmethod
def check_separability(
cls,
train_pair: tuple[Tensor, Tensor],
val_pair: tuple[Tensor, Tensor],
) -> float:
"""Measure how linearly separable the pseudo-labels are for a contrast pair.

Args:
train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the
contrastive representations. Used for training the classifier.
val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the
contrastive representations. Used for evaluating the classifier.

Returns:
The AUROC of a linear classifier fit on the pseudo-labels.
"""
x0, x1 = train_pair
val_x0, val_x1 = val_pair

pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore
pseudo_train_labels = torch.cat(
[
x0.new_zeros(x0.shape[0]),
x0.new_ones(x0.shape[0]),
]
).repeat_interleave(
x0.shape[1]
) # make num_variants copies of each pseudo-label
pseudo_val_labels = torch.cat(
[
val_x0.new_zeros(val_x0.shape[0]),
val_x0.new_ones(val_x0.shape[0]),
]
).repeat_interleave(val_x0.shape[1])

pseudo_clf.fit(
# b v d -> (b v) d
torch.cat([x0, x1]).flatten(0, 1),
pseudo_train_labels,
)
with torch.no_grad():
pseudo_preds = pseudo_clf(
# b v d -> (b v) d
torch.cat([val_x0, val_x1]).flatten(0, 1)
)
return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu()))

def reset_parameters(self):
"""Reset the parameters of the probe."""

76 changes: 22 additions & 54 deletions elk/training/train.py
@@ -3,11 +3,10 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Optional
from typing import Callable

import pandas as pd
import torch
from einops import rearrange, repeat
from simple_parsing import Serializable, field, subgroups
from sklearn.metrics import accuracy_score, roc_auc_score

@@ -17,7 +16,6 @@
from ..utils import select_usable_devices
from ..utils.typing import assert_type
from .ccs_reporter import CcsReporter, CcsReporterConfig
from .classifier import Classifier
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .reporter import OptimConfig, ReporterConfig

@@ -34,7 +32,7 @@ class Elicit(Serializable):
"use all available GPUs".
normalization: The normalization method to use. Defaults to "meanonly". See
`elk.training.preprocessing.normalize()` for details.
skip_baseline: Whether to skip training the supervised classifier. Defaults to
skip_supervised: Whether to skip training the supervised classifier. Defaults to
False.
debug: When in debug mode, a useful log file is saved to the memorably-named
output directory. Defaults to False.
@@ -46,13 +44,12 @@
)
optim: OptimConfig = field(default_factory=OptimConfig)

num_gpus: int = -1
min_gpu_mem: int | None = None
skip_baseline: bool = False
concatenated_layer_offset: int = 0
# if nonzero, appends the hidden states of layer concatenated_layer_offset before
debug: bool = False
out_dir: Optional[Path] = None
min_gpu_mem: int | None = None
num_gpus: int = -1
out_dir: Path | None = None
skip_supervised: bool = False

def execute(self):
train_run = Train(cfg=self, out_dir=self.out_dir)
@@ -89,7 +86,6 @@ def train_reporter(
# Can't figure out a way to make this line less ugly
hidden_size = next(iter(train_dict.values()))[0].shape[-1]
reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
pseudo_clf = self.get_pseudo_classifier(train_dict, device)

if isinstance(self.cfg.net, CcsReporterConfig):
assert len(train_dict) == 1, "CCS only supports single-task training"
@@ -98,13 +94,19 @@
(x0, x1, labels, _) = next(iter(train_dict.values()))
train_loss = reporter.fit(x0, x1, labels)

(val_x0, val_x1, val_gt, _) = next(iter(val_dict.values()))
pseudo_auroc = reporter.check_separability(
train_pair=(x0, x1), val_pair=(val_x0, val_x1)
)

elif isinstance(self.cfg.net, EigenReporterConfig):
# To enable training on multiple tasks with different numbers of variants,
# we update the statistics in a streaming fashion and then fit
reporter = EigenReporter(hidden_size, self.cfg.net, device=device)
for ds_name, (x0, x1, labels, _) in train_dict.items():
reporter.update(x0, x1)

pseudo_auroc = None
train_loss = reporter.fit_streaming()
else:
raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}")
@@ -114,9 +116,12 @@
torch.save(reporter, file)

# Fit supervised logistic regression model
lr_model = train_supervised(train_dict, device=device)
with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
torch.save(lr_model, file)
if not self.cfg.skip_supervised:
lr_model = train_supervised(train_dict, device=device)
with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
torch.save(lr_model, file)
else:
lr_model = None

row_buf = []
for ds_name, (val_x0, val_x1, val_gt, val_lm_preds) in val_dict.items():
@@ -125,23 +130,6 @@
val_x0,
val_x1,
)
with torch.no_grad():
(n, v, d) = val_x0.shape

pseudo_preds = pseudo_clf(
# b v d -> (b v) d
torch.cat([val_x0, val_x1]).flatten(0, 1)
)
pseudo_labels = torch.cat(
[
val_x0.new_zeros(n),
val_x0.new_ones(n),
]
)
pseudo_labels = repeat(pseudo_labels, "n -> (n v)", v=v)
pseudo_auroc = float(
roc_auc_score(pseudo_labels.cpu(), pseudo_preds.cpu())
)

if val_lm_preds is not None:
val_gt_cpu = (
@@ -167,35 +155,15 @@
}
)

lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt)
if lr_model is not None:
row["lr_auroc"], row["lr_acc"] = evaluate_supervised(
lr_model, val_x0, val_x1, val_gt
)

row["lr_auroc"] = lr_auroc
row["lr_acc"] = lr_acc
row_buf.append(row)

return pd.DataFrame(row_buf)

def get_pseudo_classifier(self, data: dict[str, tuple], device: str) -> Classifier:
"""Check the separability of the pseudo-labels at a given layer."""

x0s, x1s = [], []
for x0, x1, _, _ in data.values():
x0s.append(rearrange(x0, "n v d -> (n v) d"))
x1s.append(rearrange(x1, "n v d -> (n v) d"))

# Simple de-meaning normalization
X0 = torch.cat(x0s)
X1 = torch.cat(x1s)
X0 -= X0.mean(dim=0)
X1 -= X1.mean(dim=0)

X = torch.cat([X0, X1])
Y = torch.cat([X0.new_zeros(X0.shape[0]), X0.new_ones(X1.shape[0])])

pseudo_clf = Classifier(X.shape[-1], device=device)
pseudo_clf.fit(X, Y)
return pseudo_clf

def train(self):
"""Train a reporter on each layer of the network."""
devices = select_usable_devices(self.cfg.num_gpus)
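Net effect of the train.py changes: the pseudo-AUROC is computed only in the CCS branch (by the reporter itself, on normalized data), the supervised baseline becomes optional, and its metrics are written only when a model exists. A condensed restatement of the new control flow — not runnable on its own, since it leans on the reporters and helpers defined in this diff:

# Per-layer flow after this PR (condensed; names taken from the diff).
if isinstance(cfg.net, CcsReporterConfig):
    (x0, x1, labels, _) = next(iter(train_dict.values()))
    train_loss = reporter.fit(x0, x1, labels)
    (val_x0, val_x1, _, _) = next(iter(val_dict.values()))
    pseudo_auroc = reporter.check_separability(
        train_pair=(x0, x1), val_pair=(val_x0, val_x1)
    )
else:  # EigenReporter: streaming stats, no pseudo-AUROC
    pseudo_auroc = None
    train_loss = reporter.fit_streaming()

lr_model = None
if not cfg.skip_supervised:
    lr_model = train_supervised(train_dict, device=device)

rows = []
for ds_name, (val_x0, val_x1, val_gt, _) in val_dict.items():
    row = {"dataset": ds_name, "pseudo_auroc": pseudo_auroc}
    if lr_model is not None:  # only report supervised metrics if trained
        row["lr_auroc"], row["lr_acc"] = evaluate_supervised(
            lr_model, val_x0, val_x1, val_gt
        )
    rows.append(row)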