Store final layer LM output and record AUROC and acc (EleutherAI#165)
* LM output evaluation for autoregressive models

* move to own baseline file

* cleanup

* Support encoder-decoder model LM output

* isort

* Bug fixes

* Remove test_log_csv_elements

* Remove Python 3.9 support

* Add Pandas to pyproject.toml

* add code (still contains same-device CUDA error)

* fix multiple-CUDA error, save evals to the right folder + cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix bug noticed by Waree

* Add sanity check to load_prompts and refactor binarize

* Revert changes to binarize

* Stupid prompt_counter bug

* Remove stupid second set_start_method call

---------

Co-authored-by: Walter Laurito <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people committed Apr 10, 2023
1 parent 99fe004 commit 14a7c2a
Showing 20 changed files with 362 additions and 517 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu_ci.yml
@@ -6,7 +6,7 @@ jobs:
run-tests:
strategy:
matrix:
python-versions: [ 3.9, "3.10", "3.11" ]
python-versions: [ "3.10", "3.11" ]
os: [ ubuntu-latest, macos-latest ]
runs-on: ${{ matrix.os }}
steps:
3 changes: 1 addition & 2 deletions elk/__main__.py
@@ -1,7 +1,6 @@
"""Main entry point for `elk`."""

from dataclasses import dataclass
from typing import Union

from simple_parsing import ArgumentParser

@@ -14,7 +13,7 @@
class Command:
"""Some top-level command"""

command: Union[Elicit, Eval, Extract]
command: Elicit | Eval | Extract

def execute(self):
return self.command.execute()
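The Union[Elicit, Eval, Extract] → Elicit | Eval | Extract change above uses PEP 604 union syntax, which can only be evaluated at runtime on Python 3.10+; that is why 3.9 is dropped from the CI matrix earlier in this diff. A minimal illustration with stand-in types (not the repo's classes):

from dataclasses import dataclass


@dataclass
class Command:
    # Evaluated when the class body runs: `int | str` raises TypeError on
    # Python 3.9, where `typing.Union[int, str]` would be required instead.
    command: int | str


print(Command(command=3))  # Command(command=3)
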
56 changes: 32 additions & 24 deletions elk/evaluation/evaluate.py
@@ -3,18 +3,16 @@
from pathlib import Path
from typing import Callable, Literal, Optional

import pandas as pd
import torch
from simple_parsing.helpers import Serializable, field

from elk.evaluation.evaluate_log import EvalLog
from elk.extraction.extraction import Extract
from elk.run import Run
from elk.training import Reporter

from ..extraction.extraction import Extract
from ..files import elk_reporter_dir
from ..utils import (
select_usable_devices,
)
from ..run import Run
from ..training import Reporter
from ..training.baseline import evaluate_baseline, load_baseline
from ..utils import select_usable_devices


@dataclass
@@ -41,7 +39,7 @@ class Eval(Serializable):
debug: bool = False
out_dir: Optional[Path] = None
num_gpus: int = -1

skip_baseline: bool = False
concatenated_layer_offset: int = 0

def execute(self):
@@ -60,18 +58,18 @@ class Evaluate(Run):

def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
) -> EvalLog:
) -> pd.Series:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)

_, _, test_x0, test_x1, _, test_labels = self.prepare_data(
_, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data(
device,
layer,
)

reporter_path = (
elk_reporter_dir() / self.cfg.source / "reporters" / f"layer_{layer}.pt"
)
experiment_dir = elk_reporter_dir() / self.cfg.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter.eval()

@@ -81,24 +79,34 @@ def evaluate_reporter(
test_x1,
)

return EvalLog(
layer=layer,
eval_result=test_result,
stats_row = pd.Series(
{
"layer": layer,
**test_result._asdict(),
}
)

lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_baseline and lr_dir.exists():
lr_model = load_baseline(lr_dir, layer)
lr_model.eval()
lr_auroc, lr_acc = evaluate_baseline(
lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels
)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_acc"] = lr_acc

return stats_row

def evaluate(self):
"""Evaluate the reporter on all layers."""
devices = select_usable_devices(
self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem
)

num_devices = len(devices)
func: Callable[[int], EvalLog] = partial(
func: Callable[[int], pd.Series] = partial(
self.evaluate_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(
func=func,
num_devices=num_devices,
to_csv_line=lambda item: item.to_csv_line(),
csv_columns=EvalLog.csv_columns(),
)
self.apply_to_layers(func=func, num_devices=num_devices)
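
For context: evaluate_reporter now returns a plain pandas Series per layer, and the runner's apply_to_layers collects the rows itself, so the old CSV plumbing is gone. The load_baseline/evaluate_baseline helpers live in elk/training/baseline.py, which is outside this excerpt, so the metric code below is an illustrative stand-in under that assumption, not the repo's implementation; the field names in test_result are also hypothetical:

import pandas as pd
import torch
from sklearn.metrics import roc_auc_score


def baseline_metrics(logits: torch.Tensor, labels: torch.Tensor) -> tuple[float, float]:
    """Illustrative AUROC/accuracy for a binary probe; not the repo's evaluate_baseline."""
    probs = logits.sigmoid().cpu().numpy()
    y = labels.cpu().numpy()
    auroc = roc_auc_score(y, probs)
    acc = ((probs > 0.5) == y).mean()
    return float(auroc), float(acc)


# One row per layer: reporter metrics plus optional logistic-regression baseline columns.
test_result = {"acc": 0.87, "cal_acc": 0.88, "auroc": 0.93}  # hypothetical values
stats_row = pd.Series({"layer": 12, **test_result})
stats_row["lr_auroc"], stats_row["lr_acc"] = baseline_metrics(
    torch.randn(64), torch.randint(0, 2, (64,))
)
print(stats_row)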
27 changes: 0 additions & 27 deletions elk/evaluation/evaluate_log.py

This file was deleted.

1 change: 0 additions & 1 deletion elk/extraction/balanced_sampler.py
@@ -14,7 +14,6 @@
class BalancedSampler(TorchIterableDataset):
"""
Approximately balances a binary classification dataset in a streaming fashion.
Written mostly by GPT-4.
Args:
dataset (IterableDataset): The HuggingFace IterableDataset to balance.
108 changes: 78 additions & 30 deletions elk/extraction/extraction.py
@@ -3,7 +3,7 @@
import os
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Iterable, Literal, Optional, Union
from typing import Any, Iterable, Literal, Optional

import torch
from datasets import (
@@ -18,12 +18,17 @@
get_dataset_config_info,
)
from simple_parsing import Serializable, field
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

from elk.utils.typing import float32_to_int16
from torch import Tensor
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput

# import torch.nn.functional as F
from ..utils import (
assert_type,
convert_span,
float32_to_int16,
instantiate_model,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
)
@@ -77,7 +82,7 @@ def execute(self):
def extract_hiddens(
cfg: "Extract",
*,
device: Union[str, torch.device] = "cpu",
device: str | torch.device = "cpu",
split_type: Literal["train", "val"] = "train",
rank: int = 0,
world_size: int = 1,
@@ -100,32 +105,18 @@ def extract_hiddens(
world_size=world_size,
) # this dataset is already sharded, but hasn't been truncated to max_examples

# AutoModel should do the right thing here in nearly all cases. We don't actually
# care what head the model has, since we are just extracting hidden states.
model = AutoModel.from_pretrained(
model = instantiate_model(
cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
).to(device)
# TODO: Maybe also make this configurable?
# We want to make sure the answer is never truncated
tokenizer = AutoTokenizer.from_pretrained(
cfg.model, truncation_side="left", verbose=False
)
is_enc_dec = model.config.is_encoder_decoder

# If this is an encoder-decoder model we don't need to run the decoder at all.
# Just strip it off, making the problem equivalent to a regular encoder-only model.
if is_enc_dec:
# This isn't actually *guaranteed* by HF, but it's true for all existing models
if not hasattr(model, "get_encoder") or not callable(model.get_encoder):
raise ValueError(
"Encoder-decoder model doesn't have expected get_encoder() method"
)

model = assert_type(PreTrainedModel, model.get_encoder())
has_lm_preds = is_autoregressive(model.config)
if has_lm_preds and rank == 0:
print("Model has language model head, will store predictions.")

# Iterating over questions
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
# print(f"Using {prompt_ds} variants for each dataset")

global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
# break `max_examples` among the processes roughly equally
@@ -134,8 +125,6 @@
if rank == world_size - 1:
max_examples += global_max_examples % world_size

print(f"Extracting {max_examples} examples from {prompt_ds} on {device}")

for example in islice(BalancedSampler(prompt_ds), max_examples):
num_variants = len(example["prompts"])
hidden_dict = {
@@ -148,24 +137,70 @@
)
for layer_idx in layer_indices
}
lm_preds = torch.empty(
num_variants,
2, # contrast pair
device=device,
dtype=torch.float32,
)
text_inputs = []

# Iterate over variants
for i, record in enumerate(example["prompts"]):
variant_inputs = []

# Iterate over answers
for j in range(2):
text = record[j]["text"]
variant_inputs.append(text)
for j, choice in enumerate(record):
text = choice["text"]

# TODO: Do something smarter than "rindex" here. Really we want to
# get the span of the answer directly from Jinja, but that doesn't
# seem possible. This approach may fail for complex templates.
answer_start = text.rindex(choice["answer"])

# Only feed question, not the answer, to the encoder for enc-dec models
if model.config.is_encoder_decoder:
# TODO: Maybe make this more generic for complex templates?
text = text[:answer_start].rstrip()
target = choice["answer"]
else:
target = None

# Record the EXACT string we fed to the model
variant_inputs.append(text)
inputs = tokenizer(
text,
return_offsets_mapping=True,
return_tensors="pt",
text_target=target, # type: ignore[arg-type]
truncation=True,
).to(device)
)

# The offset_mapping is a sorted list of (start, end) tuples. We locate
# the start of the answer in the tokenized sequence with binary search.
offsets = inputs.pop("offset_mapping").squeeze().tolist()
inputs = inputs.to(device)

# Run the forward pass
outputs = model(**inputs, output_hidden_states=True)

# Compute the log probability of the answer tokens if available
if has_lm_preds:
start, end = convert_span(
offsets, (answer_start, answer_start + len(choice["answer"]))
)
log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(
dim=-1
)
tokens = inputs.input_ids[..., start:end, None]
lm_preds[i, j] = log_p.gather(-1, tokens).sum()

elif isinstance(outputs, Seq2SeqLMOutput):
# The cross entropy loss is averaged over tokens, so we need to
# multiply by the length to get the total log probability.
length = inputs.labels.shape[-1]
lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length

hiddens = (
outputs.get("decoder_hidden_states") or outputs["hidden_states"]
)
@@ -190,12 +225,17 @@

text_inputs.append(variant_inputs)

yield dict(
out_record: dict[str, Any] = dict(
label=example["label"],
variant_ids=example["template_names"],
text_inputs=text_inputs,
**hidden_dict,
)
if has_lm_preds:
# We only need the probability of the positive example since this is binary
out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1]

yield out_record


# Dataset.from_generator wraps all the arguments in lists, so we unpack them here
@@ -252,6 +292,14 @@ def get_splits() -> SplitDict:
length=num_variants,
),
}

# Only add model_preds if the model is an autoregressive model
if is_autoregressive(model_cfg):
other_cols["model_preds"] = Sequence(
Value(dtype="float32"),
length=num_variants,
)

devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem)
builders = {
split_name: _GeneratorBuilder(
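Putting the new autoregressive path together: locate the answer's character span with rindex, map it to a token span via the tokenizer's offset_mapping, and sum the log probabilities of the answer tokens. The span helper below is a simplified linear-scan stand-in for the repo's convert_span (not shown in this diff), and GPT-2 is used purely for illustration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def char_span_to_token_span(offsets, char_start, char_end):
    # Simplified stand-in for convert_span: offsets is the (start, end) list from
    # a fast tokenizer with return_offsets_mapping=True. The repo uses binary
    # search; a linear scan is enough for illustration.
    tok_start, tok_end = 0, len(offsets)
    for i, (s, e) in enumerate(offsets):
        if e <= char_start:
            tok_start = i + 1
        if s < char_end:
            tok_end = i + 1
    return tok_start, tok_end


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

text = "Q: Is the sky blue?\nA: yes"
answer = "yes"
answer_start = text.rindex(answer)

enc = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
offsets = enc.pop("offset_mapping").squeeze(0).tolist()
start, end = char_span_to_token_span(offsets, answer_start, answer_start + len(answer))

with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, vocab_size)

# Logits at position t predict token t + 1, hence the shift by one.
log_p = logits[:, start - 1 : end - 1, :].log_softmax(dim=-1)
answer_ids = enc["input_ids"][:, start:end, None]
answer_log_prob = log_p.gather(-1, answer_ids).sum()
print(float(answer_log_prob))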
