Blazing fast bootstrap stderrs for AUROC #190

Merged Apr 16, 2023 (58 commits)

Commits:
d292c7c  LM output evaluation for autoregressive models (norabelrose, Apr 4, 2023)
7ed5ccd  move to own baseline file (lauritowal, Apr 4, 2023)
ba1d3b2  cleanup (lauritowal, Apr 4, 2023)
a20d4ca  Support encoder-decoder model LM output (norabelrose, Apr 5, 2023)
088758e  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 5, 2023)
77d7418  isort (norabelrose, Apr 5, 2023)
5bf63f4  Bug fixes (norabelrose, Apr 5, 2023)
819cfed  Merge branch 'main' into lm-output (norabelrose, Apr 5, 2023)
d3d9a8d  Merge branch 'main' into lm-output (norabelrose, Apr 5, 2023)
b89e23c  Remove test_log_csv_elements (norabelrose, Apr 5, 2023)
9aef842  Remove Python 3.9 support (norabelrose, Apr 5, 2023)
0851d4f  Add Pandas to pyproject.toml (norabelrose, Apr 5, 2023)
207a375  add code (contains still same device cuda error) (lauritowal, Apr 5, 2023)
e7efcce  fix multiple cuda error, save evals to right folder + cleanup (lauritowal, Apr 7, 2023)
b5fa54c  Merge branch 'main' into eval_lr (lauritowal, Apr 7, 2023)
4f8bdc5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2023)
9ca72ba  Fix bug noticed by Waree (norabelrose, Apr 7, 2023)
d7e4893  Merge remote-tracking branch 'origin/eval_lr' into lm-output (norabelrose, Apr 7, 2023)
bcdca8a  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 7, 2023)
713a251  Add sanity check to load_prompts and refactor binarize (norabelrose, Apr 7, 2023)
0c35bc7  Changing a ton of stuff (norabelrose, Apr 8, 2023)
f6a762a  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 10, 2023)
f547744  Revert changes to binarize (norabelrose, Apr 10, 2023)
ab1909f  Stupid prompt_counter bug (norabelrose, Apr 10, 2023)
f58290f  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 10, 2023)
f912ee6  Remove stupid second set_start_method call (norabelrose, Apr 10, 2023)
606dcad  Merge remote-tracking branch 'origin/lm-output' into multiclass (norabelrose, Apr 10, 2023)
0038792  Merge remote-tracking branch 'origin/main' into multiclass (norabelrose, Apr 10, 2023)
83b480b  Fix bugs in binary case (norabelrose, Apr 11, 2023)
3e66262  Various little refactors (norabelrose, Apr 11, 2023)
a8c21a6  Remove .predict and .predict_prob on Reporter; trying to get SciQ to … (norabelrose, Apr 11, 2023)
5f478b1  Bugfix for Reporter.score on binary tasks (norabelrose, Apr 11, 2023)
97b26ac  Fix bug where cached hidden states aren’t used when num_gpus is diffe… (norabelrose, Apr 12, 2023)
11fda87  Actually works now (norabelrose, Apr 12, 2023)
da4c72f  Refactor handling of multiple datasets (norabelrose, Apr 13, 2023)
e1675f7  Various fixes (norabelrose, Apr 13, 2023)
8cc325b  Merge remote-tracking branch 'origin/main' into multi-ds-eval (norabelrose, Apr 13, 2023)
14987e1  Fix math tests (norabelrose, Apr 13, 2023)
88683fa  Fix smoke tests (norabelrose, Apr 13, 2023)
a6c382e  All tests working ostensibly (norabelrose, Apr 13, 2023)
ecc53cb  Make CCS normalization customizable (norabelrose, Apr 13, 2023)
18c7f4c  log each dataset individually (AlexTMallen, Apr 13, 2023)
94a900c  Merge branch 'multi-ds-eval' into multiclass (norabelrose, Apr 13, 2023)
5173649  Fix label_column bug (norabelrose, Apr 13, 2023)
3e6c39c  GLUE MNLI works on Deberta (norabelrose, Apr 14, 2023)
1e9ce06  Move pseudo AUROC stuff to CcsReporter (norabelrose, Apr 14, 2023)
35a8f34  Make 'datasets' and 'label_columns' config options more opinionated (norabelrose, Apr 14, 2023)
615bbb1  tiny spacing change (norabelrose, Apr 14, 2023)
f021404  Allow for toggling CV (norabelrose, Apr 14, 2023)
f6629ec  Merge branch 'multi-ds-eval' into multiclass (norabelrose, Apr 14, 2023)
99f01c3  Remove duplicate dbpedia template (norabelrose, Apr 14, 2023)
f415f8d  Merge branch 'main' into multiclass (norabelrose, Apr 14, 2023)
d16c96b  Training on datasets with different numbers of classes now works (norabelrose, Apr 15, 2023)
044774e  Efficient bootstrap CIs for AUROCs (norabelrose, Apr 15, 2023)
a7f1ea0  Fix CCS smoke test failure (norabelrose, Apr 15, 2023)
3abeb60  Update extraction.py (lauritowal, Apr 16, 2023)
1e4a6b9  Merge branch 'main' into roc_auc (lauritowal, Apr 16, 2023)
4c60061  Update extraction.py (lauritowal, Apr 16, 2023)
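The PR's headline feature, commit 044774e ("Efficient bootstrap CIs for AUROCs"), is not part of the single-commit diff shown below. The idea is to vectorize the bootstrap: draw every resampled index set as one tensor and compute all of the AUROCs in parallel from rank statistics, rather than re-scoring each resample in a Python loop. The following is a minimal sketch of that approach only; the function name, signature, and tie handling are assumptions, not the PR's actual code.

import torch


def bootstrap_auroc_ci(
    labels: torch.Tensor,   # (n,) binary labels in {0, 1}
    scores: torch.Tensor,   # (n,) scores, higher means "more positive"
    num_samples: int = 1000,
    level: float = 0.95,
) -> tuple[float, float]:
    """Bootstrap CI for the AUROC, with all resamples batched into one op."""
    n = len(labels)
    idx = torch.randint(0, n, (num_samples, n))   # every resample at once
    y = labels[idx].float()                       # (num_samples, n)
    s = scores[idx]

    # 1-based ranks of the scores within each resample (ties broken arbitrarily)
    ranks = s.argsort(dim=1).argsort(dim=1).float() + 1
    n_pos = y.sum(dim=1)
    n_neg = n - n_pos   # assumes every resample contains both classes

    # Mann-Whitney identity: AUROC = (R_pos - n_pos(n_pos + 1) / 2) / (n_pos * n_neg)
    aurocs = ((ranks * y).sum(dim=1) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

    alpha = (1 - level) / 2
    lo, hi = aurocs.quantile(torch.tensor([alpha, 1 - alpha]))
    return lo.item(), hi.item()

For example, bootstrap_auroc_ci(labels, scores) on 2000 points returns the pair bracketing the middle 95% of the 1000 resampled AUROCs; a naive loop calling roc_auc_score per resample computes the same interval far more slowly.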
Changes shown below are from the first commit only:
commit d292c7c96080e7960a4b95151e5206a8284e29b1
"LM output evaluation for autoregressive models" (norabelrose, Apr 4, 2023)
20 changes: 4 additions & 16 deletions elk/evaluation/evaluate.py

@@ -1,30 +1,18 @@
-import csv
-import os
 from dataclasses import dataclass
-from functools import partial
-from pathlib import Path
-from typing import Callable, Literal, Optional, cast
+from typing import Callable, Literal, Optional
 
 import torch
 import torch.multiprocessing as mp
 from simple_parsing.helpers import Serializable, field
-from torch import Tensor
-from tqdm.auto import tqdm
 
 from datasets import DatasetDict
 from elk.evaluation.evaluate_log import EvalLog
 from elk.extraction.extraction import Extract
 from elk.run import Run
 from elk.training import Reporter
 
-from ..files import elk_reporter_dir, memorably_named_dir
-from ..training.preprocessing import normalize
-from ..utils import (
-    assert_type,
-    int16_to_float32,
-    select_train_val_splits,
-    select_usable_devices,
-)
+from ..files import elk_reporter_dir
+from ..utils import select_usable_devices
 
 
 @dataclass
@@ -71,7 +59,7 @@ def evaluate_reporter(
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)
 
-        _, _, test_x0, test_x1, _, test_labels = self.prepare_data(
+        _, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data(
            device,
            layer,
        )
2 changes: 0 additions & 2 deletions elk/extraction/balanced_sampler.py

@@ -2,7 +2,6 @@
 from ..utils import infer_label_column
 from ..utils.typing import assert_type
 from collections import deque
-from dataclasses import dataclass
 from datasets import IterableDataset, Features
 from itertools import cycle
 from random import Random
@@ -13,7 +12,6 @@
 class BalancedSampler(TorchIterableDataset):
     """
     Approximately balances a binary classification dataset in a streaming fashion.
-    Written mostly by GPT-4.
 
     Args:
        dataset (IterableDataset): The HuggingFace IterableDataset to balance.
68 changes: 48 additions & 20 deletions elk/extraction/extraction.py

@@ -1,15 +1,7 @@
 """Functions for extracting the hidden states of a model."""
-import logging
-import os
 from dataclasses import InitVar, dataclass
-from itertools import islice
-from typing import Iterable, Literal, Optional, Union
-
-import torch
-from simple_parsing import Serializable, field
-from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel
-
 from datasets import (
     Array2D,
     Array3D,
     ClassLabel,
     DatasetDict,
@@ -20,11 +12,19 @@
     Value,
     get_dataset_config_info,
 )
-from elk.utils.typing import float32_to_int16
+from itertools import islice
+from simple_parsing import Serializable, field
+from transformers import AutoConfig, AutoTokenizer, PreTrainedModel
+from typing import Iterable, Literal, Optional, Union
+import logging
+import os
+import torch
 
 from ..utils import (
     assert_type,
     infer_label_column,
+    convert_span,
+    float32_to_int16,
+    get_model_class,
     select_train_val_splits,
     select_usable_devices,
 )
@@ -101,10 +101,12 @@ def extract_hiddens(
         world_size=world_size,
     )  # this dataset is already sharded, but hasn't been truncated to max_examples
 
-    # AutoModel should do the right thing here in nearly all cases. We don't actually
-    # care what head the model has, since we are just extracting hidden states.
-    model = AutoModel.from_pretrained(
-        cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
+    model_cls = get_model_class(cfg.model)
+    model = assert_type(
+        PreTrainedModel,
+        model_cls.from_pretrained(
+            cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
+        ),
     ).to(device)
     # TODO: Maybe also make this configurable?
     # We want to make sure the answer is never truncated
@@ -126,7 +128,6 @@
 
     # Iterating over questions
     layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
-    # print(f"Using {prompt_ds} variants for each dataset")
 
     global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
     # break `max_examples` among the processes roughly equally
@@ -135,8 +136,6 @@
     if rank == world_size - 1:
         max_examples += global_max_examples % world_size
 
-    print(f"Extracting {max_examples} examples from {prompt_ds} on {device}")
-
     for example in islice(BalancedSampler(prompt_ds), max_examples):
         num_variants = len(example["prompts"])
         hidden_dict = {
@@ -149,24 +148,47 @@
             )
             for layer_idx in layer_indices
         }
+        model_preds = torch.empty(
+            num_variants,
+            2,  # contrast pair
+            device=device,
+            dtype=torch.float32,
+        )
         text_inputs = []
 
         # Iterate over variants
         for i, record in enumerate(example["prompts"]):
             variant_inputs = []
 
             # Iterate over answers
-            for j in range(2):
-                text = record[j]["text"]
+            for j, choice in enumerate(record):
+                text = choice["text"]
                 variant_inputs.append(text)
 
                 inputs = tokenizer(
                     text,
+                    return_offsets_mapping=True,
                     return_tensors="pt",
                     truncation=True,
                 ).to(device)
 
+                # The offset_mapping is a sorted list of (start, end) tuples. We locate
+                # the start of the answer in the tokenized sequence with binary search.
+                offsets = inputs.pop("offset_mapping").squeeze().tolist()
+
                 outputs = model(**inputs, output_hidden_states=True)
 
+                # TODO: Do something smarter than "rindex" here. Really we'd like to
+                # get the span of the answer directly from Jinja, but that doesn't seem
+                # to be supported. The current approach may fail for complex templates.
+                answer_start = text.rindex(choice["answer"])
+                start, end = convert_span(
+                    offsets, (answer_start, answer_start + len(choice["answer"]))
+                )
+                log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(dim=-1)
+                tokens = inputs.input_ids[..., start:end, None]
+                model_preds[i, j] = log_p.gather(-1, tokens).sum()
+
                 hiddens = (
                     outputs.get("decoder_hidden_states") or outputs["hidden_states"]
                 )
@@ -193,6 +215,8 @@
 
         yield dict(
             label=example["label"],
+            # We only need the probability of the positive example since this is binary
+            model_preds=model_preds.softmax(dim=-1)[..., 1],
             variant_ids=example["template_names"],
             text_inputs=text_inputs,
             **hidden_dict,
@@ -245,6 +269,10 @@ def get_splits() -> SplitDict:
                 length=num_variants,
             ),
             "label": ClassLabel(names=["neg", "pos"]),
+            "model_preds": Sequence(
+                Value(dtype="float32"),
+                length=num_variants,
+            ),
             "text_inputs": Sequence(
                 Sequence(
                     Value(dtype="string"),
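A note on the indexing in the hunk above: for an autoregressive model, logits[..., t, :] is the predictive distribution for the token at position t + 1, which is why the answer tokens occupying [start, end) are scored from logit positions [start - 1, end - 1). A self-contained toy illustration of the same gather pattern (all shapes are made up):

import torch

logits = torch.randn(1, 10, 50)            # (batch, seq_len, vocab)
input_ids = torch.randint(0, 50, (1, 10))  # tokenized prompt + answer
start, end = 6, 9                          # token span of the answer

# Logits are shifted one step left relative to the tokens they predict.
log_p = logits[..., start - 1 : end - 1, :].log_softmax(dim=-1)
tokens = input_ids[..., start:end, None]   # (1, 3, 1) gather indices
answer_log_prob = log_p.gather(-1, tokens).sum()  # total log P(answer | prompt)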
5 changes: 4 additions & 1 deletion elk/run.py

@@ -88,7 +88,10 @@ def prepare_data(
         x0, x1 = train_h.unbind(dim=-2)
         val_x0, val_x1 = val_h.unbind(dim=-2)
 
-        return x0, x1, val_x0, val_x1, train_labels, val_labels
+        with self.dataset.formatted_as("numpy"):
+            val_lm_preds = assert_type(np.ndarray, val["model_preds"])
+
+        return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds
 
     def concatenate(self, layers):
         """Concatenate hidden states from a previous layer."""
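For context on formatted_as above: it is a datasets context manager that temporarily changes the dataset's return format, so the column comes back as a NumPy array only inside the with block. A minimal standalone example with toy data:

import numpy as np
from datasets import Dataset

ds = Dataset.from_dict({"model_preds": [[0.9, 0.1], [0.2, 0.8]]})

with ds.formatted_as("numpy"):
    preds = ds["model_preds"]
    assert isinstance(preds, np.ndarray)    # (2, 2) array inside the block

assert isinstance(ds["model_preds"], list)  # back to Python lists outside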
16 changes: 10 additions & 6 deletions elk/training/train.py

@@ -124,7 +124,7 @@ def train_reporter(
 
         device = self.get_device(devices, world_size)
 
-        x0, x1, val_x0, val_x1, train_labels, val_labels = self.prepare_data(
+        x0, x1, val_x0, val_x1, train_gt, val_gt, val_lm_preds = self.prepare_data(
             device, layer
         )
         pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1)
@@ -136,19 +136,23 @@
         else:
             raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}")
 
-        train_loss = reporter.fit(x0, x1, train_labels)
+        train_loss = reporter.fit(x0, x1, train_gt)
         val_result = reporter.score(
-            val_labels,
+            val_gt,
             val_x0,
             val_x1,
         )
 
         reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
-        stats: ElicitLog = ElicitLog(
+        val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu()
+
+        stats = ElicitLog(
             layer=layer,
             pseudo_auroc=pseudo_auroc,
             train_loss=train_loss,
             eval_result=val_result,
+            lm_auroc=float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())),
+            lm_acc=float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)),
         )
 
         if not self.cfg.skip_baseline:
@@ -157,8 +161,8 @@
                 x1,
                 val_x0,
                 val_x1,
-                train_labels,
-                val_labels,
+                train_gt,
+                val_gt,
                 device,
             )
             stats.lr_auroc = lr_auroc
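The repeat_interleave in the hunk above aligns the per-question ground truth with the flattened LM predictions: val_lm_preds holds one prediction per (question, prompt variant) pair, while val_gt holds one label per question. A small check of the alignment, assuming the (questions, variants) layout implied by the extraction schema:

import torch

val_gt = torch.tensor([1, 0])               # one label per question
val_lm_preds = torch.tensor([[0.9, 0.8],
                             [0.3, 0.4]])   # (n_questions, n_variants)

labels = val_gt.repeat_interleave(val_lm_preds.shape[1])
assert labels.tolist() == [1, 1, 0, 0]
# val_lm_preds.flatten() is row-major, so each prediction keeps its label:
# [0.9, 0.8, 0.3, 0.4] pairs with [1, 1, 0, 0]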
13 changes: 10 additions & 3 deletions elk/training/train_log.py

@@ -1,17 +1,20 @@
-from .reporter import EvalResult
 from dataclasses import dataclass
 from typing import Optional
 
+from elk.training.reporter import EvalResult
 
 
 @dataclass
 class ElicitLog:
     """The result of running elicit on a layer of a dataset"""
 
     layer: int
-    pseudo_auroc: float
     train_loss: float
     eval_result: EvalResult
+    pseudo_auroc: float
+
+    lm_auroc: float
+    lm_acc: float
 
     # Only available if reporting baseline
     lr_auroc: Optional[float] = None
     # Only available if reporting baseline
@@ -28,6 +31,8 @@ def csv_columns(skip_baseline: bool) -> list[str]:
             "cal_acc",
             "auroc",
             "ece",
+            "lm_auroc",
+            "lm_acc",
         ]
         if not skip_baseline:
             cols += ["lr_auroc", "lr_acc"]
@@ -43,6 +48,8 @@ def to_csv_line(self, skip_baseline: bool) -> list[str]:
             self.eval_result.cal_acc,
             self.eval_result.auroc,
             self.eval_result.ece,
+            self.lm_auroc,
+            self.lm_acc,
         ]
         if not skip_baseline:
             items += [self.lr_auroc, self.lr_acc]
2 changes: 2 additions & 0 deletions elk/utils/__init__.py

@@ -1,11 +1,13 @@
 from .data_utils import (
     binarize,
+    convert_span,
     get_columns_all_equal,
     infer_label_column,
     infer_num_classes,
     select_train_val_splits,
 )
 
 from .gpu_utils import select_usable_devices
+from .hf_utils import get_model_class
 from .tree_utils import pytree_map
 from .typing import assert_type, float32_to_int16, int16_to_float32
24 changes: 21 additions & 3 deletions elk/utils/data_utils.py

@@ -1,19 +1,37 @@
 from .typing import assert_type
 from ..promptsource.templates import Template
+from bisect import bisect_left, bisect_right
 from datasets import (
     ClassLabel,
     DatasetDict,
     Features,
     Split,
     Value,
 )
+from operator import itemgetter
 from random import Random
-import torch
-from typing import Iterable, Optional, List, Any
+import numpy as np
+from typing import Iterable, List, Any
 import copy
 
 
+def convert_span(
+    offsets: list[tuple[int, int]], span: tuple[int, int]
+) -> tuple[int, int]:
+    """Convert `span` from string coordinates to token coordinates.
+
+    Args:
+        offsets: The offset mapping of the target tokenization.
+        span: The span to convert.
+
+    Returns:
+        (start, end): The converted span.
+    """
+    start, end = span
+    start = bisect_right(offsets, start, key=itemgetter(1))
+    end = bisect_left(offsets, end, lo=start, key=itemgetter(0))
+    return start, end
+
+
 def get_columns_all_equal(dataset: DatasetDict) -> list[str]:
     """Get columns of a `DatasetDict`, asserting all splits have the same columns."""
     pivot, *rest = dataset.column_names.values()
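A usage sketch for convert_span, assuming a fast tokenizer so that an offset mapping is available. (Note that the key= argument to bisect requires Python 3.10 or newer, consistent with commit 9aef842 dropping 3.9 support.)

from elk.utils import convert_span
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "The answer is yes"
enc = tokenizer(text, return_offsets_mapping=True)
offsets = enc["offset_mapping"]  # [(0, 3), (3, 10), (10, 13), (13, 17)]

char_start = text.rindex("yes")  # character-level span of the answer
start, end = convert_span(offsets, (char_start, char_start + len("yes")))
answer_token_ids = enc["input_ids"][start:end]  # the tokens covering "yes"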
32 changes: 32 additions & 0 deletions elk/utils/hf_utils.py (new file)

@@ -0,0 +1,32 @@
+from .typing import assert_type
+from transformers import AutoConfig, PreTrainedModel
+from typing import Type
+import transformers
+
+
+def get_model_class(model_str: str) -> Type[PreTrainedModel]:
+    """Get the appropriate model class for a model string."""
+    model_cfg = AutoConfig.from_pretrained(model_str)
+    archs = assert_type(list, model_cfg.architectures)
+
+    # Ordered by preference
+    suffixes = [
+        # Fine-tuned for classification
+        "SequenceClassification",
+        # Encoder-decoder models
+        "ConditionalGeneration",
+        # Autoregressive models
+        "CausalLM",
+        "LMHeadModel",
+    ]
+
+    for suffix in suffixes:
+        # Check if any of the architectures in the config end with the suffix.
+        # If so, return the corresponding model class.
+        for arch_str in archs:
+            if arch_str.endswith(suffix):
+                return getattr(transformers, arch_str)
+
+    raise ValueError(
+        f"'{model_str}' does not have any supported architectures: {archs}"
+    )
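Usage is straightforward; for example, GPT-2's config lists architectures as ["GPT2LMHeadModel"], which matches the "LMHeadModel" suffix. The preference ordering means a checkpoint fine-tuned for classification resolves to its ...ForSequenceClassification class before any generative head is considered.

from elk.utils import get_model_class

model_cls = get_model_class("gpt2")  # resolves to transformers.GPT2LMHeadModel
model = model_cls.from_pretrained("gpt2")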