Multiple datasets refactor (EleutherAI#189)
* Fix bug where cached hidden states aren’t used when num_gpus is different

* Actually works now

* Refactor handling of multiple datasets

* Various fixes

* Fix math tests

* Fix smoke tests

* All tests working ostensibly

* Make CCS normalization customizable

* log each dataset individually

* Move pseudo AUROC stuff to CcsReporter

* Make 'datasets' and 'label_columns' config options more opinionated

* tiny spacing change

* Allow for toggling CV

* add typing to logging; rename logging

* Fix eval logging bug

---------

Co-authored-by: Alex Mallen <[email protected]>
norabelrose and AlexTMallen committed Apr 14, 2023
1 parent 361fb9b commit 16dc1ca
Showing 28 changed files with 687 additions and 627 deletions.
5 changes: 2 additions & 3 deletions elk/__main__.py
@@ -5,22 +5,21 @@
from simple_parsing import ArgumentParser

from elk.evaluation.evaluate import Eval
from elk.extraction.extraction import Extract
from elk.training.train import Elicit


@dataclass
class Command:
"""Some top-level command"""

command: Elicit | Eval | Extract
command: Elicit | Eval

def execute(self):
return self.command.execute()


def run():
parser = ArgumentParser(add_help=False)
parser = ArgumentParser(add_help=False, add_config_path_arg=True)
parser.add_arguments(Command, dest="run")
args = parser.parse_args()
run: Command = args.run
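Note: the hunk above relies on simple_parsing turning the union-typed `command` field into subcommands. Below is a minimal, self-contained sketch of that pattern with toy stand-ins for Elicit and Eval; the class names, options, and the example invocation are hypothetical, not taken from this diff (subcommand names typically follow the lowercased dataclass name).

from dataclasses import dataclass
from simple_parsing import ArgumentParser


@dataclass
class Train:
    """Hypothetical stand-in for Elicit."""
    lr: float = 1e-3


@dataclass
class Test:
    """Hypothetical stand-in for Eval."""
    split: str = "val"


@dataclass
class Command:
    # A union-typed field becomes a set of subcommands in simple_parsing.
    command: Train | Test

    def execute(self):
        print(self.command)


if __name__ == "__main__":
    parser = ArgumentParser(add_help=False, add_config_path_arg=True)
    parser.add_arguments(Command, dest="run")
    args = parser.parse_args()  # e.g. `python cli.py train --lr 3e-4`
    args.run.execute()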
57 changes: 57 additions & 0 deletions elk/debug_logging.py
@@ -0,0 +1,57 @@
import logging
from pathlib import Path

from datasets import DatasetDict

from .utils import get_dataset_name, select_train_val_splits


def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None:
"""
Save a debug log to the output directory. This is useful for debugging
training issues.
"""

logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s %(levelname)s:\n%(message)s",
filename=out_dir / "debug.log",
filemode="w",
)

for ds in datasets:
logging.info(
"=========================================\n"
f"Dataset: {get_dataset_name(ds)}\n"
"========================================="
)

train_split, val_split = select_train_val_splits(ds)
text_inputs = ds[val_split][0]["text_inputs"]
template_ids = ds[val_split][0]["variant_ids"]
label = ds[val_split][0]["label"]

# log the train size and val size
logging.info(f"Train size: {len(ds[train_split])}")
logging.info(f"Val size: {len(ds[val_split])}")

templates_text = f"{len(text_inputs)} templates used:\n"
trailing_whitespace = False
for (text0, text1), id in zip(text_inputs, template_ids):
templates_text += (
f'***---TEMPLATE "{id}"---***\n'
f"{'false' if label else 'true'}:\n"
f'"""{text0}"""\n'
f"{'true' if label else 'false'}:\n"
f'"""{text1}"""\n\n\n'
)
if text0[-1].isspace() or text1[-1].isspace():
trailing_whitespace = True
if trailing_whitespace:
logging.warning(
"Some inputs to the model have trailing whitespace! "
"Check that the jinja templates are not adding "
"trailing whitespace. If `token_loc` is 'last', this "
"will extract hidden states from the whitespace token."
)
logging.info(templates_text)
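Note: a hedged usage sketch for the new module. The wrapper function and its names below are illustrative, not part of the diff; only the save_debug_log call itself comes from elk/debug_logging.py.

from pathlib import Path

from datasets import DatasetDict

from elk.debug_logging import save_debug_log


def maybe_save_debug_log(datasets: list[DatasetDict], out_dir: Path, debug: bool) -> None:
    # When debug mode is enabled, write out_dir / "debug.log" containing the
    # per-dataset train/val sizes and the rendered prompt templates.
    if debug:
        out_dir.mkdir(parents=True, exist_ok=True)
        save_debug_log(datasets, out_dir)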
75 changes: 37 additions & 38 deletions elk/evaluation/evaluate.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Literal, Optional
from typing import Callable

import pandas as pd
import torch
@@ -11,7 +11,7 @@
from ..files import elk_reporter_dir
from ..run import Run
from ..training import Reporter
from ..training.baseline import evaluate_baseline, load_baseline
from ..training.supervised import evaluate_supervised
from ..utils import select_usable_devices


@@ -28,26 +28,25 @@ class Eval(Serializable):
`elk.training.preprocessing.normalize()` for details.
num_gpus: The number of GPUs to use. Defaults to -1, which means
"use all available GPUs".
skip_supervised: Whether to skip evaluation of the supervised classifier.
debug: When in debug mode, a useful log file is saved to the memorably-named
output directory. Defaults to False.
"""

data: Extract
source: str = field(positional=True)
normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"

concatenated_layer_offset: int = 0
debug: bool = False
out_dir: Optional[Path] = None
min_gpu_mem: int | None = None
num_gpus: int = -1
skip_baseline: bool = False
concatenated_layer_offset: int = 0
out_dir: Path | None = None
skip_supervised: bool = False

def execute(self):
datasets = self.data.prompts.datasets

transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"

for dataset in datasets:
for dataset in self.data.prompts.datasets:
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()

@@ -58,55 +57,55 @@ class Evaluate(Run):

def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
) -> pd.Series:
) -> pd.DataFrame:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)

_, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data(
device,
layer,
)
val_output = self.prepare_data(device, layer, "val")

experiment_dir = elk_reporter_dir() / self.cfg.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter.eval()

test_result = reporter.score(
test_labels,
test_x0,
test_x1,
)

stats_row = pd.Series(
{
"layer": layer,
**test_result._asdict(),
}
)
row_buf = []
for ds_name, (val_x0, val_x1, val_gt, _) in val_output.items():
val_result = reporter.score(
val_gt,
val_x0,
val_x1,
)

lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_baseline and lr_dir.exists():
lr_model = load_baseline(lr_dir, layer)
lr_model.eval()
lr_auroc, lr_acc = evaluate_baseline(
lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels
stats_row = pd.Series(
{
"dataset": ds_name,
"layer": layer,
**val_result._asdict(),
}
)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_acc"] = lr_acc
lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_acc"] = lr_acc

row_buf.append(stats_row)

return stats_row
return pd.DataFrame(row_buf)

def evaluate(self):
"""Evaluate the reporter on all layers."""
devices = select_usable_devices(
self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem
self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem
)

num_devices = len(devices)
func: Callable[[int], pd.Series] = partial(
func: Callable[[int], pd.DataFrame] = partial(
self.evaluate_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)
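Note: the refactor above changes evaluate_reporter from returning one pd.Series per layer to one pd.DataFrame with a row per dataset. A tiny self-contained sketch of that pattern follows; the dataset names and metric values are made up for illustration.

import pandas as pd

row_buf = []
for ds_name, metrics in {"imdb": {"acc": 0.91}, "boolq": {"acc": 0.84}}.items():
    # One row per dataset, mirroring evaluate_reporter's stats_row.
    row_buf.append(pd.Series({"dataset": ds_name, "layer": 5, **metrics}))

df = pd.DataFrame(row_buf)  # each Series becomes one row
print(df)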
2 changes: 1 addition & 1 deletion elk/extraction/balanced_sampler.py
@@ -6,8 +6,8 @@
from datasets import Features, IterableDataset
from torch.utils.data import IterableDataset as TorchIterableDataset

from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from ..utils.math_util import stochastic_round_constrained
from ..utils.typing import assert_type


52 changes: 33 additions & 19 deletions elk/extraction/extraction.py
@@ -4,7 +4,7 @@
from copy import copy
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Any, Iterable, Literal, Optional
from typing import Any, Iterable, Literal

import torch
from datasets import (
@@ -23,6 +23,7 @@
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput

from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
convert_span,
@@ -48,7 +49,6 @@ class Extract(Serializable):
layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`.
token_loc: The location of the token to extract hidden states from. Can be
either "first", "last", or "mean". Defaults to "last".
min_gpu_mem: Minimum amount of free memory (in bytes) required to select a GPU.
"""

prompts: PromptConfig
@@ -57,8 +57,6 @@
layers: tuple[int, ...] = ()
layer_stride: InitVar[int] = 1
token_loc: Literal["first", "last", "mean"] = "last"
min_gpu_mem: Optional[int] = None
num_gpus: int = -1

def __post_init__(self, layer_stride: int):
if self.layers and layer_stride > 1:
@@ -74,8 +72,16 @@ def __post_init__(self, layer_stride: int):
)
self.layers = tuple(range(0, config.num_hidden_layers, layer_stride))

def execute(self):
extract(cfg=self, num_gpus=self.num_gpus)
def explode(self) -> list["Extract"]:
"""Explode this config into a list of configs, one for each layer."""
copies = []

for prompt_cfg in self.prompts.explode():
cfg = copy(self)
cfg.prompts = prompt_cfg
copies.append(cfg)

return copies


@torch.no_grad()
Expand All @@ -94,8 +100,11 @@ def extract_hiddens(
if rank != 0:
logging.disable(logging.CRITICAL)

ds_names = cfg.prompts.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

prompt_ds = load_prompts(
*cfg.prompts.datasets,
ds_names[0],
split_type=split_type,
stream=cfg.prompts.stream,
rank=rank,
@@ -240,14 +249,19 @@ def _extraction_worker(**kwargs):
yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})


def extract(cfg: "Extract", num_gpus: int = -1) -> DatasetDict:
def extract(
cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
) -> DatasetDict:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""

def get_splits() -> SplitDict:
available_splits = assert_type(SplitDict, info.splits)
train_name, val_name = select_train_val_splits(available_splits)
print(f"Using '{train_name}' for training and '{val_name}' for validation")

print(
# Cyan color for dataset name
f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and"
f" '{val_name}' for validation"
)
limit_list = cfg.prompts.max_examples

return SplitDict(
@@ -263,11 +277,15 @@ def get_splits() -> SplitDict:
)

model_cfg = AutoConfig.from_pretrained(cfg.model)
num_variants = cfg.prompts.num_variants

ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

num_variants = cfg.prompts.num_variants
if num_variants < 0:
prompter = DatasetTemplates(ds_name, config_name)
num_variants = len(prompter.templates)

layer_cols = {
f"hidden_{layer}": Array3D(
dtype="int16",
@@ -297,22 +315,18 @@ def get_splits() -> SplitDict:
length=num_variants,
)

devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem)

# Prevent the GPU-related config options from invalidating the cache
_cfg = copy(cfg)
_cfg.min_gpu_mem = None
_cfg.num_gpus = -1

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
builders = {
split_name: _GeneratorBuilder(
builder_name=info.builder_name,
config_name=info.config_name,
cache_dir=None,
features=Features({**layer_cols, **other_cols}),
generator=_extraction_worker,
split_name=split_name,
split_info=split_info,
gen_kwargs=dict(
cfg=[_cfg] * len(devices),
cfg=[cfg] * len(devices),
device=devices,
rank=list(range(len(devices))),
split_type=[split_name] * len(devices),
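Note: a toy, self-contained sketch (hypothetical class names) of the explode() fan-out introduced above: one multi-dataset config becomes one single-dataset config per dataset, so each dataset's hidden states are extracted and cached independently. PromptConfig.explode() is assumed here to split by dataset; it is not shown in this diff.

from copy import copy
from dataclasses import dataclass, field


@dataclass
class ToyPromptConfig:
    datasets: list[str] = field(default_factory=list)

    def explode(self) -> list["ToyPromptConfig"]:
        # One prompt config per dataset.
        return [ToyPromptConfig(datasets=[d]) for d in self.datasets]


@dataclass
class ToyExtract:
    prompts: ToyPromptConfig

    def explode(self) -> list["ToyExtract"]:
        copies = []
        for prompt_cfg in self.prompts.explode():
            cfg = copy(self)
            cfg.prompts = prompt_cfg
            copies.append(cfg)
        return copies


cfg = ToyExtract(prompts=ToyPromptConfig(datasets=["imdb", "boolq"]))
print([c.prompts.datasets for c in cfg.explode()])  # [['imdb'], ['boolq']]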
