add transfer eval for latest main
lauritowal committed Mar 10, 2023
1 parent a44fb01 commit 55edef1
Showing 3 changed files with 107 additions and 59 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -20,6 +20,12 @@ elk elicit microsoft/deberta-v2-xxlarge-mnli imdb

This will automatically download the model and dataset, run the model to extract the relevant representations (unless they are already cached on disk), fit reporters on them, and save the reporter checkpoints to the `elk-reporters` folder in your home directory. It will also evaluate the reporters' classification performance on a held-out test set and save the results to a CSV file in the same folder.
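
For example, once a run has finished, you can list the per-layer reporter checkpoints it saved (run names are generated automatically; `naughty-northcutt` below is the example run reused in the next command):

```bash
# one checkpoint per layer, e.g. layer_0.pt, layer_1.pt, ...
ls ~/elk-reporters/naughty-northcutt/reporters
```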

```bash
elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
```

This will evaluate the probe from the run `naughty-northcutt` on hidden states extracted from `microsoft/deberta-v2-xxlarge-mnli` for the `imdb` dataset. The results are written to an `eval.csv` and a `cfg.yaml` file in a subfolder of `elk-reporters/naughty-northcutt/transfer_eval`.
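
The resulting CSV is easy to inspect programmatically. A minimal sketch, assuming pandas is installed; the `<eval-run>` subfolder name is auto-generated, so substitute the directory that `elk eval` prints:

```python
import pandas as pd

# Illustrative path: replace <eval-run> with the auto-generated subfolder name
df = pd.read_csv("~/elk-reporters/naughty-northcutt/transfer_eval/<eval-run>/eval.csv")
print(df.sort_values("auroc", ascending=False).head())  # best layers by AUROC
```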

## Caching

The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
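
For example, to see what has been cached so far (the location can be overridden with the `HF_DATASETS_CACHE` environment variable):

```bash
ls ~/.cache/huggingface/datasets
```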
157 changes: 99 additions & 58 deletions elk/evaluation/evaluate.py
Removed (the old 71-line implementation):

```python
import csv
import os
import pickle
import torch

from dataclasses import dataclass
from functools import partial
from hashlib import md5
from elk.training.preprocessing import load_hidden_states, normalize
from simple_parsing.helpers import field, Serializable
from typing import Literal, List

from ..extraction import ExtractionConfig, extract
from ..files import elk_reporter_dir, memorably_named_dir
from ..utils import select_usable_gpus


@dataclass
class EvaluateConfig(Serializable):
    source: str
    targets: List[str]
    normalization: Literal["legacy", "elementwise", "meanonly"] = "meanonly"
    device: str = "cuda"


def evaluate_reporters(cfg: EvaluateConfig):
    for target in cfg.targets:
        hiddens, labels = load_hidden_states(
            path=out_dir / target / "validation_hiddens.pt"
        )
        assert len(set(labels)) > 1

        _, hiddens = normalize(hiddens, hiddens, cfg.normalization)

        reporter_root_path = elk_reporter_dir() / cfg.source / "reporters"

        transfer_eval = elk_reporter_dir() / cfg.source / "transfer_eval"
        transfer_eval.mkdir(parents=True, exist_ok=True)

        L = hiddens.shape[1]
        layers = list(hiddens.unbind(1))
        layers.reverse()
        csv_file = transfer_eval / f"{target}.csv"

        for path in reporter_root_path.glob("*.pt"):
            reporter = torch.load(path, map_location=cfg.device)
            reporter.eval()

            with torch.no_grad(), open(csv_file, "w") as f:
                for layer_idx, hidden_state in enumerate(layers):
                    x0, x1 = hidden_state.to(cfg.device).float().chunk(2, dim=-1)
                    result = reporter.score(
                        (x0, x1),
                        labels.to(cfg.device),
                    )
                    stats = [*result]
                    stats += [cfg.normalization, cfg.source, target]

                    writer = csv.writer(f)
                    if not csv_file.exists():
                        # write column names only once
                        cols = [
                            "layer",
                            "acc",
                            "cal_acc",
                            "auroc",
                            "normalization",
                            "name",
                            "targets",
                        ]
                        writer.writerow(cols)
                    writer.writerow([L - layer_idx] + [stats])

            print("Eval file generated: ", csv_file)
```

Added (the new 112-line implementation):

```python
import csv
import os
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import List, Literal, Optional, cast

import torch
import torch.multiprocessing as mp
import yaml
from simple_parsing.helpers import Serializable, field
from torch import Tensor
from tqdm.auto import tqdm

from datasets import DatasetDict
from elk.training.preprocessing import normalize

from ..extraction import ExtractionConfig, extract
from ..files import elk_reporter_dir, memorably_named_dir
from ..utils import assert_type, held_out_split, int16_to_float32, select_usable_devices


@dataclass
class EvaluateConfig(Serializable):
    source: str = field(positional=True)
    target: ExtractionConfig
    normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
    max_gpus: int = -1


def evaluate_reporter(
    cfg: EvaluateConfig,
    dataset: DatasetDict,
    layer: int,
    devices: list[str],
    world_size: int = 1,
):
    """Evaluate a single reporter on a single layer."""
    rank = os.getpid() % world_size
    device = devices[rank]

    # Note: currently we're just upcasting to float32 so we don't have to deal with
    # grad scaling (which isn't supported for LBFGS), while the hidden states are
    # saved in float16 to save disk space. In the future we could try to use mixed
    # precision training in at least some cases.
    with dataset.formatted_as("torch", device=device, dtype=torch.int16):
        train, test = dataset["train"], held_out_split(dataset)
        test_labels = cast(Tensor, test["label"])

        _, test_h = normalize(
            int16_to_float32(assert_type(Tensor, train[f"hidden_{layer}"])),
            int16_to_float32(assert_type(Tensor, test[f"hidden_{layer}"])),
            method=cfg.normalization,
        )

    reporter_path = elk_reporter_dir() / cfg.source / "reporters" / f"layer_{layer}.pt"
    reporter = torch.load(reporter_path, map_location=device)
    reporter.eval()

    test_x0, test_x1 = test_h.unbind(dim=-2)

    test_result = reporter.score(
        (test_x0, test_x1),
        test_labels,
    )

    stats = [layer, *test_result]
    return stats


def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
    ds = extract(cfg.target, max_gpus=cfg.max_gpus)

    layers = [
        int(feat[len("hidden_") :])
        for feat in ds["train"].features
        if feat.startswith("hidden_")
    ]

    devices = select_usable_devices(cfg.max_gpus)
    num_devices = len(devices)

    transfer_eval = elk_reporter_dir() / cfg.source / "transfer_eval"
    transfer_eval.mkdir(parents=True, exist_ok=True)

    if out_dir is None:
        out_dir = memorably_named_dir(transfer_eval)
    else:
        out_dir.mkdir(parents=True, exist_ok=True)

    # Print the output directory in bold with escape codes
    print(f"Saving results to \033[1m{out_dir}\033[0m")

    with open(out_dir / "cfg.yaml", "w") as f:
        cfg.dump_yaml(f)

    cols = ["layer", "loss", "acc", "cal_acc", "auroc"]
    # Evaluate reporters for each layer in parallel
    with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
        fn = partial(
            evaluate_reporter, cfg, ds, devices=devices, world_size=num_devices
        )
        writer = csv.writer(f)
        writer.writerow(cols)

        mapper = pool.imap if num_devices > 1 else map
        for i, *stats in tqdm(mapper(fn, layers), total=len(layers)):
            writer.writerow([i] + [f"{s:.4f}" for s in stats])

    print("Results saved")
```
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
  # 4.0 introduced the breaking change of using return_dict=True by default
  "transformers>=4.0.0",
  # required by promptsource, which doesn't specify a version
- "jinja2"
+ "jinja2",
+ "pyyaml"
  ]
version = "0.1.1"
