forked from EleutherAI/elk
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a44fb01
commit 55edef1
Showing 3 changed files with 107 additions and 59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,112 @@ | ||
import csv | ||
import os | ||
import pickle | ||
import torch | ||
|
||
from dataclasses import dataclass | ||
from functools import partial | ||
from hashlib import md5 | ||
from elk.training.preprocessing import load_hidden_states, normalize | ||
from simple_parsing.helpers import field, Serializable | ||
from typing import Literal, List | ||
from pathlib import Path | ||
from typing import List, Literal, Optional, cast | ||
|
||
import torch | ||
import torch.multiprocessing as mp | ||
import yaml | ||
from simple_parsing.helpers import Serializable, field | ||
from torch import Tensor | ||
from tqdm.auto import tqdm | ||
|
||
from datasets import DatasetDict | ||
from elk.training.preprocessing import normalize | ||
|
||
from ..extraction import ExtractionConfig, extract | ||
from ..files import elk_reporter_dir, memorably_named_dir | ||
from ..utils import select_usable_gpus | ||
from ..utils import assert_type, held_out_split, int16_to_float32, select_usable_devices | ||
|
||
|
||
@dataclass
class EvaluateConfig(Serializable):
    """Settings for evaluating previously trained reporters on a target dataset."""

    # Name of the run whose reporters are evaluated; it is joined onto
    # `elk_reporter_dir()` to locate the saved reporters. Positional argument.
    source: str = field(positional=True)
    # Extraction settings passed to `extract` to obtain the hidden states.
    target: ExtractionConfig
    # Normalization method forwarded to `normalize`.
    normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
    # Forwarded to `extract` and `select_usable_devices`.
    # NOTE(review): -1 presumably means "use every available GPU" -- confirm.
    max_gpus: int = -1
|
||
|
||
def evaluate_reporter(
    cfg: EvaluateConfig,
    dataset: DatasetDict,
    layer: int,
    devices: list[str],
    world_size: int = 1,
):
    """Evaluate a single reporter on a single layer.

    Args:
        cfg: Evaluation settings; `source` locates the saved reporter and
            `normalization` selects the normalization method.
        dataset: Extracted dataset providing a "train" split plus
            `hidden_{layer}` and "label" columns.
        layer: Index of the layer whose hidden states and reporter are used.
        devices: Device strings, indexed by the rank computed below.
        world_size: Number of worker processes sharing `devices`.

    Returns:
        The list `[layer, *scores]`, where `scores` is whatever
        `reporter.score` yields for the held-out split.
    """
    # The device is chosen from the process ID.
    # NOTE(review): PIDs of pool workers are not guaranteed to be consecutive,
    # so two workers may map to the same device -- confirm this is acceptable.
    rank = os.getpid() % world_size
    device = devices[rank]

    # Note: currently we're just upcasting to float32 so we don't have to deal with
    # grad scaling (which isn't supported for LBFGS), while the hidden states are
    # saved in float16 to save disk space. In the future we could try to use mixed
    # precision training in at least some cases.
    with dataset.formatted_as("torch", device=device, dtype=torch.int16):
        train, test = dataset["train"], held_out_split(dataset)
        test_labels = cast(Tensor, test["label"])

        # Only the normalized held-out hidden states are needed; the train
        # split is passed in as the first argument to `normalize`.
        _, test_h = normalize(
            int16_to_float32(assert_type(Tensor, train[f"hidden_{layer}"])),
            int16_to_float32(assert_type(Tensor, test[f"hidden_{layer}"])),
            method=cfg.normalization,
        )

    reporter_path = elk_reporter_dir() / cfg.source / "reporters" / f"layer_{layer}.pt"
    reporter = torch.load(reporter_path, map_location=device)
    reporter.eval()

    # Split the second-to-last dimension into the two inputs `score` expects.
    # NOTE(review): assumes that dimension has size 2 -- confirm.
    test_x0, test_x1 = test_h.unbind(dim=-2)

    test_result = reporter.score(
        (test_x0, test_x1),
        test_labels,
    )

    stats = [layer, *test_result]
    return stats
|
||
|
||
def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
    """Evaluate the saved reporters of `cfg.source` on every extracted layer.

    Hidden states come from `extract(cfg.target, ...)`. Each `hidden_<n>`
    feature of the train split is scored by `evaluate_reporter`, and one row
    per layer is written to `eval.csv` in the output directory, next to a
    `cfg.yaml` dump of the configuration.

    Args:
        cfg: Evaluation settings.
        out_dir: Directory for the results. When None, the directory is
            obtained from `memorably_named_dir` under
            `elk_reporter_dir() / cfg.source / "transfer_eval"`.
    """
    ds = extract(cfg.target, max_gpus=cfg.max_gpus)

    # Recover the layer indices from the `hidden_<n>` feature names.
    layers = [
        int(feat[len("hidden_") :])
        for feat in ds["train"].features
        if feat.startswith("hidden_")
    ]

    devices = select_usable_devices(cfg.max_gpus)
    num_devices = len(devices)

    transfer_eval = elk_reporter_dir() / cfg.source / "transfer_eval"
    transfer_eval.mkdir(parents=True, exist_ok=True)

    if out_dir is None:
        out_dir = memorably_named_dir(transfer_eval)
    else:
        out_dir.mkdir(parents=True, exist_ok=True)

    # Print the output directory in bold with escape codes
    print(f"Saving results to \033[1m{out_dir}\033[0m")

    with open(out_dir / "cfg.yaml", "w") as f:
        cfg.dump_yaml(f)

    cols = ["layer", "loss", "acc", "cal_acc", "auroc"]
    # Evaluate reporters for each layer in parallel.
    # `newline=""` is what the csv module requires of files it writes to;
    # without it, extra blank rows appear on platforms that translate "\n".
    with mp.Pool(num_devices) as pool, open(
        out_dir / "eval.csv", "w", newline=""
    ) as f:
        fn = partial(
            evaluate_reporter, cfg, ds, devices=devices, world_size=num_devices
        )
        writer = csv.writer(f)
        writer.writerow(cols)

        # With a single device, run in this process instead of the pool.
        mapper = pool.imap if num_devices > 1 else map
        for i, *stats in tqdm(mapper(fn, layers), total=len(layers)):
            writer.writerow([i] + [f"{s:.4f}" for s in stats])

    print("Results saved")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters