Binarized meta-templates; some extraction refactoring #218

Closed. Wants to merge 24 commits.
Changes shown from 4 of 24 commits.

Commits
bbee489
Initial support for FEVER
norabelrose Apr 22, 2023
5ba1ddd
Start saving and fitting a reporter to the input embeddings
norabelrose Apr 22, 2023
3b1f74d
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
51ba54f
Rename layer 0 to 'input' to make it more clear
norabelrose Apr 22, 2023
544b485
Actually rename layer 0 correctly
norabelrose Apr 22, 2023
43da44e
Handle layer_stride correctly
norabelrose Apr 22, 2023
9056e00
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
756fa53
label_choices
norabelrose Apr 22, 2023
93b7ae0
Clean up train and eval commands; do transfer in sweep
norabelrose Apr 22, 2023
57d0b8b
Support INLP and split eval output into multiple CSVs
norabelrose Apr 22, 2023
228a6a0
Merge branch 'inlp' into template-filtering
norabelrose Apr 22, 2023
b086f0b
Merge branch 'inlp' into template-filtering
norabelrose Apr 25, 2023
934cd54
Log ensembled metrics
norabelrose Apr 26, 2023
dff69bf
Fixing pyright version
norabelrose Apr 26, 2023
b181d3e
Merge remote-tracking branch 'origin/main' into ensembling
norabelrose Apr 26, 2023
15254bf
Merge main
norabelrose Apr 26, 2023
69c2d55
Tons of stuff, preparing for sciq_binary experiment
norabelrose Apr 27, 2023
960ff01
Support --binarize again
norabelrose Apr 27, 2023
c9e62ea
Partial support for truthful_qa
norabelrose Apr 27, 2023
eb71a6c
Merge branch 'main' into template-filtering
norabelrose Apr 29, 2023
88bb15e
Merge remote-tracking branch 'origin/main' into template-filtering
norabelrose Apr 29, 2023
c648ff0
Remove crap
norabelrose Apr 29, 2023
ef12130
EleutherAI/truthful_qa_mc
norabelrose Apr 29, 2023
5d60ebd
Update templates
norabelrose Apr 30, 2023
elk/evaluation/evaluate.py: 53 changes (34 additions, 19 deletions)
@@ -5,7 +5,7 @@
import torch
from simple_parsing.helpers import field

from ..files import elk_reporter_dir
from ..files import elk_reporter_dir, transfer_eval_directory
from ..metrics import evaluate_preds
from ..run import Run
from ..training import Reporter
@@ -21,8 +21,12 @@ class Eval(Run):
def __post_init__(self):
assert self.source, "Must specify a source experiment."

transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"
self.out_dir = transfer_dir / "+".join(self.data.prompts.datasets)
# Set the output directory to the transfer directory if it's not specified
self.out_dir = (
transfer_eval_directory(self.source)
if self.out_dir is None
else self.out_dir
)

def execute(self, highlight_color: str = "cyan"):
return super().execute(highlight_color, split_type="val")
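For context, the deleted lines wrote transfer evaluations under `elk_reporter_dir() / source / "transfer_eval"` (with the joined dataset names appended). A hypothetical sketch of what the new `transfer_eval_directory` helper in `elk/files.py` might look like, reconstructed only from that removed code; the stand-in `elk_reporter_dir` default below is illustrative, not the library's actual cache location:

```python
from pathlib import Path


def elk_reporter_dir() -> Path:
    # Stand-in for elk.files.elk_reporter_dir; the default shown here is
    # illustrative only.
    return Path.home() / "elk-reporters"


def transfer_eval_directory(source: str) -> Path:
    # Hypothetical reconstruction from the removed lines above; the real
    # helper may compose the path differently.
    return elk_reporter_dir() / source / "transfer_eval"
```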
@@ -44,21 +48,32 @@ def apply_to_layer(
for ds_name, (val_h, val_gt, _) in val_output.items():
meta = {"dataset": ds_name, "layer": layer}

val_result = evaluate_preds(val_gt, reporter(val_h))
row_bufs["eval"].append({**meta, **val_result.to_dict()})

lr_dir = experiment_dir / "lr_models"
if not self.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_models = torch.load(f, map_location=device)
if not isinstance(lr_models, list): # backward compatibility
lr_models = [lr_models]

for i, model in enumerate(lr_models):
model.eval()
lr_result = evaluate_preds(val_gt, model(val_h))
row_bufs["lr_eval"].append(
{"inlp_iter": i, **meta, **lr_result.to_dict()}
)
val_credences = reporter(val_h)
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
}
)

lr_dir = experiment_dir / "lr_models"
if not self.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_models = torch.load(f, map_location=device)
if not isinstance(lr_models, list): # backward compatibility
lr_models = [lr_models]

for i, model in enumerate(lr_models):
model.eval()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"inlp_iter": i,
**meta,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
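The supervised-probe block keeps its backward-compatibility shim: older runs saved a single logistic-regression model per layer, while newer runs save a list with one model per INLP iteration. A small sketch of that load pattern in isolation (the checkpoint path is illustrative):

```python
import torch

# A per-layer checkpoint may hold either a single model (old format) or a
# list of models, one per INLP iteration (new format). Path is illustrative.
with open("lr_models/layer_5.pt", "rb") as f:
    lr_models = torch.load(f, map_location="cpu")

if not isinstance(lr_models, list):  # old checkpoints stored a single model
    lr_models = [lr_models]

for i, model in enumerate(lr_models):
    model.eval()  # switch to inference mode before scoring
```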
elk/extraction/extraction.py: 6 changes (3 additions, 3 deletions)
@@ -211,7 +211,7 @@ def extract_hiddens(
input_ids = torch.cat([input_ids, answer], dim=-1)
if max_len := tokenizer.model_max_length:
cur_len = input_ids.shape[-1]
input_ids = input_ids[..., -min(max_len, cur_len) :]
input_ids = input_ids[..., -min(cur_len, max_len) :]

# Make sure we only pass the arguments that the model expects
inputs = dict(input_ids=input_ids)
@@ -335,7 +335,7 @@ def extract(
num_gpus: int = -1,
min_gpu_mem: int | None = None,
split_type: Literal["train", "val", None] = None,
) -> DatasetDict:
) -> DatasetDictWithName:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""
info, features = hidden_features(cfg)

@@ -389,6 +389,6 @@

dataset_dict = DatasetDict(ds)
return DatasetDictWithName(
name=ds_name,
name=cfg.prompts.datasets[0],
dataset=dataset_dict,
)
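The swap from `min(max_len, cur_len)` to `min(cur_len, max_len)` is purely cosmetic; either way the slice keeps only the most recent `model_max_length` tokens when the prompt is too long, and keeps everything otherwise. A standalone illustration with made-up sizes:

```python
import torch

max_len = 8                                 # stand-in for tokenizer.model_max_length
input_ids = torch.arange(12).unsqueeze(0)   # shape (1, 12), longer than max_len

cur_len = input_ids.shape[-1]
truncated = input_ids[..., -min(cur_len, max_len):]

print(truncated)  # tensor([[ 4,  5,  6,  7,  8,  9, 10, 11]]) -- last 8 tokens kept
```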
elk/metrics/eval.py: 36 changes (27 additions, 9 deletions)
@@ -1,4 +1,5 @@
from dataclasses import asdict, dataclass
from typing import Literal

import torch
from einops import repeat
@@ -37,33 +38,50 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
else {}
)
auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()}
return {**acc_dict, **cal_acc_dict, **cal_dict, **auroc_dict}
return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}


def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult:
def evaluate_preds(
y_true: Tensor,
y_logits: Tensor,
ensembling: Literal["none", "partial", "full"] = "none",
) -> EvalResult:
"""
Evaluate the performance of a classification model.

Args:
y_true: Ground truth tensor of shape (N,).
y_pred: Predicted class tensor of shape (N, variants, n_classes).
y_logits: Predicted class tensor of shape (N, variants, n_classes).

Returns:
dict: A dictionary containing the accuracy, AUROC, and ECE.
"""
(n, v, c) = y_logits.shape
assert y_true.shape == (n,)

# Clustered bootstrap confidence intervals for AUROC
y_true = repeat(y_true, "n -> n v", v=v)
auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
acc = accuracy_ci(y_true, y_logits.argmax(dim=-1))

if ensembling == "full":
y_logits = y_logits.mean(dim=1)
else:
y_true = repeat(y_true, "n -> n v", v=v)

y_pred = y_logits.argmax(dim=-1)
if ensembling == "none":
auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
elif ensembling in ("partial", "full"):
# Pool together the negative and positive class logits
if c == 2:
auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0])
else:
auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits)
else:
raise ValueError(f"Unknown mode: {ensembling}")

acc = accuracy_ci(y_true, y_pred)
cal_acc = None
cal_err = None

if c == 2:
pos_probs = y_logits.softmax(-1)[..., 1]
pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])

# Calibrated accuracy
cal_thresh = pos_probs.float().quantile(y_true.float().mean())
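Two details of the reworked `evaluate_preds` are easy to miss. For binary tasks, `sigmoid(z_1 - z_0)` is mathematically identical to the positive-class softmax probability used before, so that change is behavior-preserving; and `ensembling="full"` pools prompt variants by averaging logits before scoring. A small self-contained check of both points (shapes and values are illustrative):

```python
import torch

torch.manual_seed(0)
y_logits = torch.randn(4, 3, 2)  # (N, variants, classes), illustrative shape

# sigmoid of the logit difference == softmax probability of the positive class
pos_softmax = y_logits.softmax(-1)[..., 1]
pos_sigmoid = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])
assert torch.allclose(pos_softmax, pos_sigmoid, atol=1e-6)

# "full" ensembling pools the prompt variants by averaging logits
pooled = y_logits.mean(dim=1)   # shape (N, classes)
print(pooled.shape)             # torch.Size([4, 2])
```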
elk/run.py: 6 changes (3 additions, 3 deletions)
@@ -12,13 +12,13 @@
import torch
import torch.multiprocessing as mp
import yaml
from datasets import DatasetDict
from simple_parsing.helpers import Serializable, field
from torch import Tensor
from tqdm import tqdm

from .debug_logging import save_debug_log
from .extraction import Extract, extract
from .extraction.dataset_name import DatasetDictWithName
from .files import elk_reporter_dir, memorably_named_dir
from .utils import (
assert_type,
@@ -36,7 +36,7 @@ class Run(ABC, Serializable):
"""Directory to save results to. If None, a directory will be created
automatically."""

datasets: list[DatasetDict] = field(default_factory=list, init=False)
datasets: list[DatasetDictWithName] = field(default_factory=list, init=False)
"""Datasets containing hidden states and labels for each layer."""

concatenated_layer_offset: int = 0
@@ -178,7 +178,7 @@ def apply_to_layers(
finally:
# Make sure the CSVs are written even if we crash or get interrupted
for name, dfs in df_buffers.items():
df = pd.concat(dfs).sort_values(by="layer")
df = pd.concat(dfs).sort_values(by=["layer", "ensembling"])

# Rename layer 0 to "input" to make it more clear
df["layer"].replace(0, "input", inplace=True)
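Since every row buffer now carries an `ensembling` column, the final CSVs are sorted on both keys before layer 0 is relabeled. A tiny sketch of that tidy-up with made-up values (using plain assignment rather than the in-place `replace` in the diff):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "layer": [1, 0, 0, 1],
        "ensembling": ["none", "full", "none", "full"],
        "auroc_estimate": [0.7, 0.8, 0.6, 0.9],
    }
)

df = df.sort_values(by=["layer", "ensembling"])
df["layer"] = df["layer"].replace(0, "input")  # relabel the embedding layer
print(df)
```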
elk/training/sweep.py: 9 changes (5 additions, 4 deletions)
@@ -89,11 +89,12 @@ def execute(self):
if eval_dataset in train_datasets:
continue

data = deepcopy(run.data)
data.model = model_str
data.prompts.datasets = [eval_dataset]

eval = Eval(
data=Extract(
model=model_str,
prompts=PromptConfig(datasets=[eval_dataset]),
),
data=data,
source=str(run.out_dir),
out_dir=out_dir,
)
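Deep-copying `run.data` before overriding the model and datasets keeps the sweep's own `Extract` config untouched across eval datasets. A minimal illustration of why the copy matters, using stand-in dataclasses rather than the real `Extract`/`PromptConfig`:

```python
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class PromptCfg:                      # stand-in for PromptConfig
    datasets: list[str] = field(default_factory=list)


@dataclass
class ExtractCfg:                     # stand-in for Extract
    model: str = ""
    prompts: PromptCfg = field(default_factory=PromptCfg)


base = ExtractCfg(model="gpt2", prompts=PromptCfg(datasets=["imdb"]))

data = deepcopy(base)                 # mutate a copy, not the shared config
data.prompts.datasets = ["boolq"]

assert base.prompts.datasets == ["imdb"]  # the original sweep config is unchanged
```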
elk/training/train.py: 47 changes (29 additions, 18 deletions)
@@ -124,24 +124,35 @@ def apply_to_layer(
for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
meta = {"dataset": ds_name, "layer": layer}

val_result = evaluate_preds(val_gt, reporter(val_h))
row_bufs["eval"].append(
{
**meta,
"pseudo_auroc": pseudo_auroc,
"train_loss": train_loss,
**val_result.to_dict(),
}
)

if val_lm_preds is not None:
lm_result = evaluate_preds(val_gt, val_lm_preds)
row_bufs["lm_eval"].append({**meta, **lm_result.to_dict()})

for i, model in enumerate(lr_models):
lr_result = evaluate_preds(val_gt, model(val_h))
row_bufs["lr_eval"].append(
{"inlp_iter": i, **meta, **lr_result.to_dict()}
val_credences = reporter(val_h)
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
"pseudo_auroc": pseudo_auroc,
"train_loss": train_loss,
}
)

if val_lm_preds is not None:
row_bufs["lm_eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
}
)

for i, model in enumerate(lr_models):
row_bufs["lr_eval"].append(
{
**meta,
"ensembling": mode,
"inlp_iter": i,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
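All three buffers (`eval`, `lm_eval`, `lr_eval`) now emit one row per ensembling mode per layer. A minimal sketch of the buffering pattern, with placeholder metrics standing in for `evaluate_preds(...).to_dict()`:

```python
from collections import defaultdict

import pandas as pd

row_bufs: dict[str, list[dict]] = defaultdict(list)
meta = {"dataset": "sciq", "layer": 12}      # illustrative values

for mode in ("none", "partial", "full"):
    metrics = {"auroc_estimate": 0.9}        # placeholder for evaluate_preds(...).to_dict()
    row_bufs["eval"].append({**meta, "ensembling": mode, **metrics})

dfs = {name: pd.DataFrame(rows) for name, rows in row_bufs.items()}
print(dfs["eval"])  # three rows for this layer, one per ensembling mode
```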
elk/utils/typing.py: 8 changes (4 additions, 4 deletions)
@@ -15,11 +15,11 @@ def assert_type(typ: Type[T], obj: Any) -> T:

def float32_to_int16(x: torch.Tensor) -> torch.Tensor:
"""Converts float32 to float16, then reinterprets as int16."""
fp16 = x.type(torch.float16)
if not fp16.isfinite().all():
raise ValueError("Tensor contains non-finite values!")
downcast = x.type(torch.float16)
if not downcast.isfinite().all():
raise ValueError("Cannot convert to 16 bit: values are not finite")

return fp16.view(torch.int16)
return downcast.view(torch.int16)


def int16_to_float32(x: torch.Tensor) -> torch.Tensor:
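The `fp16` to `downcast` rename is behavior-preserving: the tensor is still cast to float16 and its bit pattern reinterpreted as int16 for compact storage, with `int16_to_float32` doing the reverse. A quick round-trip check (values chosen to be exactly representable in float16):

```python
import torch

# Round-trip: store float16 bit patterns in an int16 tensor, then recover them.
x = torch.tensor([0.5, -1.25, 3.0], dtype=torch.float32)

packed = x.to(torch.float16).view(torch.int16)       # reinterpret the fp16 bits
restored = packed.view(torch.float16).to(torch.float32)

assert torch.equal(x, restored)  # exact, since these values fit in float16
```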
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -41,7 +41,7 @@ dev = [
"hypothesis",
"pre-commit",
"pytest",
"pyright",
"pyright==1.1.304",
"scikit-learn",
]
