Blazing fast bootstrap stderrs for AUROC #190

Merged Apr 16, 2023: 58 commits into main from roc_auc.
The diff below shows the changes from 10 of the 58 commits.

Commits
d292c7c  LM output evaluation for autoregressive models (norabelrose, Apr 4, 2023)
7ed5ccd  move to own baseline file (lauritowal, Apr 4, 2023)
ba1d3b2  cleanup (lauritowal, Apr 4, 2023)
a20d4ca  Support encoder-decoder model LM output (norabelrose, Apr 5, 2023)
088758e  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 5, 2023)
77d7418  isort (norabelrose, Apr 5, 2023)
5bf63f4  Bug fixes (norabelrose, Apr 5, 2023)
819cfed  Merge branch 'main' into lm-output (norabelrose, Apr 5, 2023)
d3d9a8d  Merge branch 'main' into lm-output (norabelrose, Apr 5, 2023)
b89e23c  Remove test_log_csv_elements (norabelrose, Apr 5, 2023)
9aef842  Remove Python 3.9 support (norabelrose, Apr 5, 2023)
0851d4f  Add Pandas to pyproject.toml (norabelrose, Apr 5, 2023)
207a375  add code (contains still same device cuda error) (lauritowal, Apr 5, 2023)
e7efcce  fix multiple cuda error, save evals to right folder + cleanup (lauritowal, Apr 7, 2023)
b5fa54c  Merge branch 'main' into eval_lr (lauritowal, Apr 7, 2023)
4f8bdc5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2023)
9ca72ba  Fix bug noticed by Waree (norabelrose, Apr 7, 2023)
d7e4893  Merge remote-tracking branch 'origin/eval_lr' into lm-output (norabelrose, Apr 7, 2023)
bcdca8a  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 7, 2023)
713a251  Add sanity check to load_prompts and refactor binarize (norabelrose, Apr 7, 2023)
0c35bc7  Changing a ton of stuff (norabelrose, Apr 8, 2023)
f6a762a  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 10, 2023)
f547744  Revert changes to binarize (norabelrose, Apr 10, 2023)
ab1909f  Stupid prompt_counter bug (norabelrose, Apr 10, 2023)
f58290f  Merge remote-tracking branch 'origin/main' into lm-output (norabelrose, Apr 10, 2023)
f912ee6  Remove stupid second set_start_method call (norabelrose, Apr 10, 2023)
606dcad  Merge remote-tracking branch 'origin/lm-output' into multiclass (norabelrose, Apr 10, 2023)
0038792  Merge remote-tracking branch 'origin/main' into multiclass (norabelrose, Apr 10, 2023)
83b480b  Fix bugs in binary case (norabelrose, Apr 11, 2023)
3e66262  Various little refactors (norabelrose, Apr 11, 2023)
a8c21a6  Remove .predict and .predict_prob on Reporter; trying to get SciQ to … (norabelrose, Apr 11, 2023)
5f478b1  Bugfix for Reporter.score on binary tasks (norabelrose, Apr 11, 2023)
97b26ac  Fix bug where cached hidden states aren’t used when num_gpus is diffe… (norabelrose, Apr 12, 2023)
11fda87  Actually works now (norabelrose, Apr 12, 2023)
da4c72f  Refactor handling of multiple datasets (norabelrose, Apr 13, 2023)
e1675f7  Various fixes (norabelrose, Apr 13, 2023)
8cc325b  Merge remote-tracking branch 'origin/main' into multi-ds-eval (norabelrose, Apr 13, 2023)
14987e1  Fix math tests (norabelrose, Apr 13, 2023)
88683fa  Fix smoke tests (norabelrose, Apr 13, 2023)
a6c382e  All tests working ostensibly (norabelrose, Apr 13, 2023)
ecc53cb  Make CCS normalization customizable (norabelrose, Apr 13, 2023)
18c7f4c  log each dataset individually (AlexTMallen, Apr 13, 2023)
94a900c  Merge branch 'multi-ds-eval' into multiclass (norabelrose, Apr 13, 2023)
5173649  Fix label_column bug (norabelrose, Apr 13, 2023)
3e6c39c  GLUE MNLI works on Deberta (norabelrose, Apr 14, 2023)
1e9ce06  Move pseudo AUROC stuff to CcsReporter (norabelrose, Apr 14, 2023)
35a8f34  Make 'datasets' and 'label_columns' config options more opinionated (norabelrose, Apr 14, 2023)
615bbb1  tiny spacing change (norabelrose, Apr 14, 2023)
f021404  Allow for toggling CV (norabelrose, Apr 14, 2023)
f6629ec  Merge branch 'multi-ds-eval' into multiclass (norabelrose, Apr 14, 2023)
99f01c3  Remove duplicate dbpedia template (norabelrose, Apr 14, 2023)
f415f8d  Merge branch 'main' into multiclass (norabelrose, Apr 14, 2023)
d16c96b  Training on datasets with different numbers of classes now works (norabelrose, Apr 15, 2023)
044774e  Efficient bootstrap CIs for AUROCs (norabelrose, Apr 15, 2023)
a7f1ea0  Fix CCS smoke test failure (norabelrose, Apr 15, 2023)
3abeb60  Update extraction.py (lauritowal, Apr 16, 2023)
1e4a6b9  Merge branch 'main' into roc_auc (lauritowal, Apr 16, 2023)
4c60061  Update extraction.py (lauritowal, Apr 16, 2023)
5 changes: 2 additions & 3 deletions elk/__main__.py
@@ -5,22 +5,21 @@
 from simple_parsing import ArgumentParser
 
 from elk.evaluation.evaluate import Eval
-from elk.extraction.extraction import Extract
 from elk.training.train import Elicit
 
 
 @dataclass
 class Command:
     """Some top-level command"""
 
-    command: Elicit | Eval | Extract
+    command: Elicit | Eval
 
     def execute(self):
         return self.command.execute()
 
 
 def run():
-    parser = ArgumentParser(add_help=False)
+    parser = ArgumentParser(add_help=False, add_config_path_arg=True)
     parser.add_arguments(Command, dest="run")
     args = parser.parse_args()
     run: Command = args.run
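
Aside from dropping `Extract` as a top-level command, the one behavioral change here is `add_config_path_arg=True`, which tells simple_parsing to add a `--config_path` flag so a run can be configured from a file. A minimal sketch of how that flag behaves, assuming a simple_parsing version that supports it:

```python
from dataclasses import dataclass

from simple_parsing import ArgumentParser


@dataclass
class Options:
    lr: float = 1e-3
    seed: int = 0


parser = ArgumentParser(add_config_path_arg=True)
parser.add_arguments(Options, dest="opts")

# `python run.py --config_path run.yaml --seed 42` loads defaults from
# run.yaml, then lets explicit CLI flags like --seed override them.
args = parser.parse_args()
print(args.opts)
```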
56 changes: 27 additions & 29 deletions elk/evaluation/evaluate.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Callable, Literal, Optional
+from typing import Callable
 
 import pandas as pd
 import torch
@@ -11,7 +11,7 @@
 from ..files import elk_reporter_dir
 from ..run import Run
 from ..training import Reporter
-from ..training.baseline import evaluate_baseline, load_baseline
+from ..training.supervised import evaluate_supervised
 from ..utils import select_usable_devices
 
 
@@ -34,13 +34,12 @@ class Eval(Serializable):
 
     data: Extract
     source: str = field(positional=True)
-    normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
 
     debug: bool = False
-    out_dir: Optional[Path] = None
+    out_dir: Path | None = None
     num_gpus: int = -1
+    min_gpu_mem: int | None = None
     skip_baseline: bool = False
-    concatenated_layer_offset: int = 0
 
     def execute(self):
         datasets = self.data.prompts.datasets
@@ -58,50 +57,49 @@ class Evaluate(Run):
 
     def evaluate_reporter(
         self, layer: int, devices: list[str], world_size: int = 1
-    ) -> pd.Series:
+    ) -> pd.DataFrame:
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)
 
-        _, test_h, _, test_labels, _ = self.prepare_data(
-            device,
-            layer,
-        )
+        val_output = self.prepare_data(device, layer, "val")
 
         experiment_dir = elk_reporter_dir() / self.cfg.source
 
         reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
         reporter: Reporter = torch.load(reporter_path, map_location=device)
         reporter.eval()
 
-        test_result = reporter.score(test_labels, test_h)
-        stats_row = pd.Series(
-            {
-                "layer": layer,
-                **test_result._asdict(),
-            }
-        )
+        row_buf = []
+        for ds_name, (val_h, val_gt, _) in val_output.items():
+            val_result = reporter.score(val_gt, val_h)
 
-        lr_dir = experiment_dir / "lr_models"
-        if not self.cfg.skip_baseline and lr_dir.exists():
-            lr_model = load_baseline(lr_dir, layer)
-            lr_model.eval()
-            lr_auroc, lr_acc = evaluate_baseline(
-                lr_model.cuda(), test_h.cuda(), test_labels
-            )
+            stats_row = pd.Series(
+                {
+                    "dataset": ds_name,
+                    "layer": layer,
+                    **val_result._asdict(),
+                }
+            )
 
-            stats_row["lr_auroc"] = lr_auroc
-            stats_row["lr_acc"] = lr_acc
+            lr_dir = experiment_dir / "lr_models"
+            if not self.cfg.skip_baseline and lr_dir.exists():
+                with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
+                    lr_model = torch.load(f, map_location=device).eval()
 
-        return stats_row
+                lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
+
+                stats_row["lr_auroc"] = lr_auroc
+                stats_row["lr_acc"] = lr_acc
+
+            row_buf.append(stats_row)
+
+        return pd.DataFrame(row_buf)
 
     def evaluate(self):
         """Evaluate the reporter on all layers."""
         devices = select_usable_devices(
-            self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem
+            self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem
         )
 
         num_devices = len(devices)
-        func: Callable[[int], pd.Series] = partial(
+        func: Callable[[int], pd.DataFrame] = partial(
             self.evaluate_reporter, devices=devices, world_size=num_devices
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
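
The rewritten method scores the reporter once per dataset and appends an optional logistic-regression baseline via `evaluate_supervised`, whose body is not part of this diff. A rough sketch of what such a helper typically computes for a binary probe; the name `evaluate_probe` and the flat `(n, d)` feature shape are simplifying assumptions (elk's real hiddens carry variant and class dimensions):

```python
import torch


def evaluate_probe(
    probe: torch.nn.Module,
    hiddens: torch.Tensor,  # (n, d) features
    labels: torch.Tensor,   # (n,) integer labels in {0, 1}
) -> tuple[float, float]:
    """Return (auroc, accuracy) for a binary probe. Sketch only."""
    with torch.no_grad():
        logits = probe(hiddens).squeeze(-1)

    acc = ((logits > 0).long() == labels).float().mean().item()

    # Rank-based (Mann-Whitney) AUROC; ties broken arbitrarily
    ranks = logits.argsort().argsort().float() + 1
    n_pos = labels.float().sum()
    n_neg = labels.numel() - n_pos
    u = (ranks * labels.float()).sum() - n_pos * (n_pos + 1) / 2
    return (u / (n_pos * n_neg)).item(), acc


# Toy usage with a random probe:
probe = torch.nn.Linear(8, 1)
h = torch.randn(100, 8)
y = torch.randint(0, 2, (100,))
print(evaluate_probe(probe, h, y))
```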
2 changes: 1 addition & 1 deletion elk/extraction/balanced_sampler.py
@@ -7,8 +7,8 @@
 from datasets import Features, IterableDataset
 from torch.utils.data import IterableDataset as TorchIterableDataset
 
-from ..math_util import stochastic_round_constrained
 from ..utils import infer_label_column
+from ..utils.math_util import stochastic_round_constrained
 from ..utils.typing import assert_type
55 changes: 34 additions & 21 deletions elk/extraction/extraction.py
@@ -4,7 +4,7 @@
 from copy import copy
 from dataclasses import InitVar, dataclass
 from itertools import islice
-from typing import Any, Iterable, Literal, Optional
+from typing import Any, Iterable, Literal
 
 import torch
 from datasets import (
@@ -23,6 +23,7 @@
 from transformers import AutoConfig, AutoTokenizer
 from transformers.modeling_outputs import Seq2SeqLMOutput
 
+from ..promptsource import DatasetTemplates
 from ..utils import (
     assert_type,
     convert_span,
@@ -49,7 +50,6 @@ class Extract(Serializable):
         layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`.
         token_loc: The location of the token to extract hidden states from. Can be
             either "first", "last", or "mean". Defaults to "last".
-        min_gpu_mem: Minimum amount of free memory (in bytes) required to select a GPU.
     """
 
     prompts: PromptConfig
@@ -58,8 +58,6 @@
     layers: tuple[int, ...] = ()
     layer_stride: InitVar[int] = 1
     token_loc: Literal["first", "last", "mean"] = "last"
-    min_gpu_mem: Optional[int] = None
-    num_gpus: int = -1
 
     def __post_init__(self, layer_stride: int):
         if self.layers and layer_stride > 1:
@@ -75,8 +73,16 @@ def __post_init__(self, layer_stride: int):
             )
         self.layers = tuple(range(0, config.num_hidden_layers, layer_stride))
 
-    def execute(self):
-        extract(cfg=self, num_gpus=self.num_gpus)
+    def explode(self) -> list["Extract"]:
+        """Explode this config into a list of configs, one for each dataset."""
+        copies = []
+
+        for prompt_cfg in self.prompts.explode():
+            cfg = copy(self)
+            cfg.prompts = prompt_cfg
+            copies.append(cfg)
+
+        return copies
 
 
 @torch.no_grad()
@@ -95,9 +101,12 @@ def extract_hiddens(
     if rank != 0:
         logging.disable(logging.CRITICAL)
 
+    ds_names = cfg.prompts.datasets
+    assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."
+
     prompt_ds = load_prompts(
-        *cfg.prompts.datasets,
-        label_column=cfg.prompts.label_column,
+        ds_names[0],
+        label_column=cfg.prompts.label_columns[0],
         num_classes=cfg.prompts.num_classes,
         split_type=split_type,
         stream=cfg.prompts.stream,
@@ -244,14 +253,19 @@ def _extraction_worker(**kwargs):
     yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})
 
 
-def extract(cfg: "Extract", num_gpus: int = -1) -> DatasetDict:
+def extract(
+    cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
+) -> DatasetDict:
     """Extract hidden states from a model and return a `DatasetDict` containing them."""
 
     def get_splits() -> SplitDict:
         available_splits = assert_type(SplitDict, info.splits)
         train_name, val_name = select_train_val_splits(available_splits)
-        print(f"Using '{train_name}' for training and '{val_name}' for validation")
 
+        print(
+            # Cyan color for dataset name
+            f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and"
+            f" '{val_name}' for validation"
+        )
         limit_list = cfg.prompts.max_examples
 
         return SplitDict(
@@ -267,14 +281,17 @@ def get_splits() -> SplitDict:
         )
 
     model_cfg = AutoConfig.from_pretrained(cfg.model)
-    num_variants = cfg.prompts.num_variants
 
     ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
     info = get_dataset_config_info(ds_name, config_name or None)
 
     ds_features = assert_type(Features, info.features)
-    label_col = cfg.prompts.label_column or infer_label_column(ds_features)
+    label_col = cfg.prompts.label_columns[0] or infer_label_column(ds_features)
     num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
+    num_variants = cfg.prompts.num_variants
+    if num_variants < 0:
+        prompter = DatasetTemplates(ds_name, config_name)
+        num_variants = len(prompter.templates)
 
     layer_cols = {
         f"hidden_{layer}": Array3D(
@@ -304,22 +321,18 @@ def get_splits() -> SplitDict:
dtype="float32",
)

devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem)

# Prevent the GPU-related config options from invalidating the cache
_cfg = copy(cfg)
_cfg.min_gpu_mem = None
_cfg.num_gpus = -1

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
builders = {
split_name: _GeneratorBuilder(
builder_name=info.builder_name,
config_name=info.config_name,
cache_dir=None,
features=Features({**layer_cols, **other_cols}),
generator=_extraction_worker,
split_name=split_name,
split_info=split_info,
gen_kwargs=dict(
cfg=[_cfg] * len(devices),
cfg=[cfg] * len(devices),
device=devices,
rank=list(range(len(devices))),
split_type=[split_name] * len(devices),
Expand Down
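
The new `Extract.explode` above splits one multi-dataset config into independent single-dataset copies, so each dataset can be extracted and cached on its own while the GPU options (`num_gpus`, `min_gpu_mem`) stay out of the cached config entirely. A toy illustration of the same pattern, with hypothetical stand-in config classes:

```python
from copy import copy
from dataclasses import dataclass, field


@dataclass
class PromptConfig:
    datasets: list[str] = field(default_factory=list)

    def explode(self) -> list["PromptConfig"]:
        # One single-dataset prompt config per entry in `datasets`
        return [PromptConfig(datasets=[ds]) for ds in self.datasets]


@dataclass
class ExtractConfig:
    prompts: PromptConfig

    def explode(self) -> list["ExtractConfig"]:
        # Shallow-copy the parent config, swapping in each exploded child
        copies = []
        for prompt_cfg in self.prompts.explode():
            cfg = copy(self)
            cfg.prompts = prompt_cfg
            copies.append(cfg)
        return copies


cfg = ExtractConfig(PromptConfig(datasets=["imdb", "amazon_polarity"]))
assert [c.prompts.datasets for c in cfg.explode()] == [["imdb"], ["amazon_polarity"]]
```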
40 changes: 29 additions & 11 deletions elk/extraction/generator.py
@@ -1,17 +1,22 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import Any, Callable, Optional
-
-import datasets
-from datasets import Features
+from typing import Any, Callable
+
+from datasets import (
+    BuilderConfig,
+    DatasetInfo,
+    Features,
+    GeneratorBasedBuilder,
+    SplitInfo,
+)
+from datasets.splits import NamedSplit
 
 
 @dataclass
-class _GeneratorConfig(datasets.BuilderConfig):
-    generator: Optional[Callable] = None
+class _GeneratorConfig(BuilderConfig):
+    generator: Callable | None = None
     gen_kwargs: dict[str, Any] = field(default_factory=dict)
-    features: Optional[datasets.Features] = None
+    features: Features | None = None
 
     def create_config_id(
         self, config_kwargs: dict, custom_features: Features | None
@@ -37,28 +42,41 @@ class _SplitGenerator:
     """
 
     name: str
-    split_info: datasets.SplitInfo
+    split_info: SplitInfo
     gen_kwargs: dict = field(default_factory=dict)
 
     def __post_init__(self):
         self.name = str(self.name)  # Make sure we convert NamedSplits to strings
         NamedSplit(self.name)  # check that it's a valid split name
 
 
-class _GeneratorBuilder(datasets.GeneratorBasedBuilder):
+class _GeneratorBuilder(GeneratorBasedBuilder):
     """Patched version of `datasets.Generator` allowing for splits besides `train`"""
 
     BUILDER_CONFIG_CLASS = _GeneratorConfig
     config: _GeneratorConfig
 
-    def __init__(self, split_name: str, split_info: datasets.SplitInfo, **kwargs):
+    def __init__(
+        self,
+        builder_name: str | None,
+        config_name: str | None,
+        split_name: str,
+        split_info: SplitInfo,
+        **kwargs,
+    ):
         self.split_name = split_name
         self.split_info = split_info
 
         super().__init__(**kwargs)
 
+        # Weirdly we need to set DatasetInfo.builder_name and DatasetInfo.config_name
+        # here, not in _info, because super().__init__ modifies them
+        self.info.builder_name = builder_name
+        self.info.config_name = config_name
+
     def _info(self):
-        return datasets.DatasetInfo(features=self.config.features)
+        # Use the same builder and config name as the original builder
+        return DatasetInfo(features=self.config.features)
 
     def _split_generators(self, dl_manager):
         return [
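
The stock `datasets` Generator builder hard-codes its one split to `train`; the patched class threads an arbitrary `split_name`/`split_info` pair through and restores `builder_name`/`config_name` so cache keys stay stable. A sketch of how the patched builder might be driven, assuming `_GeneratorBuilder` is importable and that the truncated `_split_generators` wires `self.split_name` into a `_SplitGenerator`:

```python
from datasets import Features, SplitInfo, Value

from elk.extraction.generator import _GeneratorBuilder


def gen_rows():
    # Toy generator; the real one streams hidden states from a model
    yield {"text": "hello", "label": 0}


builder = _GeneratorBuilder(
    builder_name="demo_model",
    config_name="imdb",
    cache_dir=None,
    features=Features({"text": Value("string"), "label": Value("int64")}),
    generator=gen_rows,
    split_name="validation",
    split_info=SplitInfo(name="validation"),
)
builder.download_and_prepare()
ds = builder.as_dataset(split="validation")  # a split the stock builder can't produce
```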