Binarized meta-templates; some extraction refactoring #218

Closed · wants to merge 24 commits

Changes from 1 commit

Commits (24)
bbee489  Initial support for FEVER (norabelrose, Apr 22, 2023)
5ba1ddd  Start saving and fitting a reporter to the input embeddings (norabelrose, Apr 22, 2023)
3b1f74d  Merge branch 'input-embeddings' into template-filtering (norabelrose, Apr 22, 2023)
51ba54f  Rename layer 0 to 'input' to make it more clear (norabelrose, Apr 22, 2023)
544b485  Actually rename layer 0 correctly (norabelrose, Apr 22, 2023)
43da44e  Handle layer_stride correctly (norabelrose, Apr 22, 2023)
9056e00  Merge branch 'input-embeddings' into template-filtering (norabelrose, Apr 22, 2023)
756fa53  label_choices (norabelrose, Apr 22, 2023)
93b7ae0  Clean up train and eval commands; do transfer in sweep (norabelrose, Apr 22, 2023)
57d0b8b  Support INLP and split eval output into multiple CSVs (norabelrose, Apr 22, 2023)
228a6a0  Merge branch 'inlp' into template-filtering (norabelrose, Apr 22, 2023)
b086f0b  Merge branch 'inlp' into template-filtering (norabelrose, Apr 25, 2023)
934cd54  Log ensembled metrics (norabelrose, Apr 26, 2023)
dff69bf  Fixing pyright version (norabelrose, Apr 26, 2023)
b181d3e  Merge remote-tracking branch 'origin/main' into ensembling (norabelrose, Apr 26, 2023)
15254bf  Merge main (norabelrose, Apr 26, 2023)
69c2d55  Tons of stuff, preparing for sciq_binary experiment (norabelrose, Apr 27, 2023)
960ff01  Support --binarize again (norabelrose, Apr 27, 2023)
c9e62ea  Partial support for truthful_qa (norabelrose, Apr 27, 2023)
eb71a6c  Merge branch 'main' into template-filtering (norabelrose, Apr 29, 2023)
88bb15e  Merge remote-tracking branch 'origin/main' into template-filtering (norabelrose, Apr 29, 2023)
c648ff0  Remove crap (norabelrose, Apr 29, 2023)
ef12130  EleutherAI/truthful_qa_mc (norabelrose, Apr 29, 2023)
5d60ebd  Update templates (norabelrose, Apr 30, 2023)
Support INLP and split eval output into multiple CSVs
norabelrose committed Apr 22, 2023
commit 57d0b8b754d7c856794b4e76ae1e6542fe0c2102
elk/evaluation/evaluate.py (34 changes: 18 additions & 16 deletions)

@@ -1,3 +1,4 @@
+from collections import defaultdict
 from dataclasses import dataclass
 
 import pandas as pd
@@ -25,7 +26,7 @@ def __post_init__(self):
 
     def apply_to_layer(
         self, layer: int, devices: list[str], world_size: int
-    ) -> pd.DataFrame:
+    ) -> dict[str, pd.DataFrame]:
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)
         val_output = self.prepare_data(device, layer, "val")
@@ -36,24 +37,25 @@ def apply_to_layer(
         reporter: Reporter = torch.load(reporter_path, map_location=device)
         reporter.eval()
 
-        row_buf = []
+        row_bufs = defaultdict(list)
         for ds_name, (val_h, val_gt, _) in val_output.items():
-            val_result = evaluate_preds(val_gt, reporter(val_h))
+            meta = {"dataset": ds_name, "layer": layer}
 
-            stats_row = {
-                "dataset": ds_name,
-                "layer": layer,
-                **val_result.to_dict(),
-            }
+            val_result = evaluate_preds(val_gt, reporter(val_h))
+            row_bufs["eval"].append({**meta, **val_result.to_dict()})
 
             lr_dir = experiment_dir / "lr_models"
             if not self.skip_supervised and lr_dir.exists():
                 with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
-                    lr_model = torch.load(f, map_location=device).eval()
-
-                    lr_result = evaluate_preds(val_gt, lr_model(val_h))
-                    stats_row.update(lr_result.to_dict(prefix="lr_"))
-
-            row_buf.append(stats_row)
-
-        return pd.DataFrame.from_records(row_buf)
+                    lr_models = torch.load(f, map_location=device)
+                    if not isinstance(lr_models, list):  # backward compatibility
+                        lr_models = [lr_models]
+
+                    for i, model in enumerate(lr_models):
+                        model.eval()
+                        lr_result = evaluate_preds(val_gt, model(val_h))
+                        row_bufs["lr_eval"].append(
+                            {"inlp_iter": i, **meta, **lr_result.to_dict()}
+                        )
+
+        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
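
Note on the new contract: `apply_to_layer` now returns a dict mapping a table name to a DataFrame instead of one flat frame, and `apply_to_layers` (below, in elk/run.py) writes one CSV per key. A minimal sketch of what a single layer might produce; the metric column names and values here are illustrative, not the actual `EvalResult` fields:

```python
from collections import defaultdict

import pandas as pd

row_bufs = defaultdict(list)

# One "eval" row per evaluated dataset for this layer...
row_bufs["eval"].append({"dataset": "sciq", "layer": 12, "auroc": 0.91})
# ...and one "lr_eval" row per INLP iteration of the supervised probe.
row_bufs["lr_eval"].append(
    {"inlp_iter": 0, "dataset": "sciq", "layer": 12, "auroc": 0.88}
)

tables = {k: pd.DataFrame(v) for k, v in row_bufs.items()}
print(sorted(tables))  # ['eval', 'lr_eval'] -> eval.csv, lr_eval.csv
```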
elk/run.py (24 changes: 13 additions & 11 deletions)

@@ -1,6 +1,7 @@
 import os
 import random
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
@@ -88,15 +89,15 @@ def execute(self, highlight_color: str = "cyan"):
 
         devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
         num_devices = len(devices)
-        func: Callable[[int], pd.DataFrame] = partial(
+        func: Callable[[int], dict[str, pd.DataFrame]] = partial(
            self.apply_to_layer, devices=devices, world_size=num_devices
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
 
     @abstractmethod
     def apply_to_layer(
         self, layer: int, devices: list[str], world_size: int
-    ) -> pd.DataFrame:
+    ) -> dict[str, pd.DataFrame]:
         """Train or eval a reporter on a single layer."""
 
     def make_reproducible(self, seed: int):
@@ -145,7 +146,7 @@ def concatenate(self, layers):
 
     def apply_to_layers(
         self,
-        func: Callable[[int], pd.DataFrame],
+        func: Callable[[int], dict[str, pd.DataFrame]],
         num_devices: int,
     ):
         """Apply a function to each layer of the datasets in parallel
@@ -165,17 +166,18 @@ def apply_to_layers(
         layers = self.concatenate(layers)
 
         ctx = mp.get_context("spawn")
-        with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f:
+        with ctx.Pool(num_devices) as pool:
             mapper = pool.imap_unordered if num_devices > 1 else map
-            df_buf = []
+            df_buffers = defaultdict(list)
 
             try:
-                for df in tqdm(mapper(func, layers), total=len(layers)):
-                    df_buf.append(df)
+                for df_dict in tqdm(mapper(func, layers), total=len(layers)):
+                    for k, v in df_dict.items():
+                        df_buffers[k].append(v)
             finally:
-                # Make sure the CSV is written even if we crash or get interrupted
-                if df_buf:
-                    df = pd.concat(df_buf).sort_values(by="layer")
-                    df.round(4).to_csv(f, index=False)
+                # Make sure the CSVs are written even if we crash or get interrupted
+                for name, dfs in df_buffers.items():
+                    df = pd.concat(dfs).sort_values(by="layer")
+                    df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
                 if self.debug:
                     save_debug_log(self.datasets, self.out_dir)
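
The finally block preserves the old crash-safety guarantee while generalizing the single hard-coded eval.csv to one file per table name. A self-contained sketch of the same buffer-then-write pattern; the function name and the `results` argument are hypothetical stand-ins for the pool mapper's output:

```python
from collections import defaultdict
from pathlib import Path

import pandas as pd


def write_tables(results: list[dict[str, pd.DataFrame]], out_dir: Path) -> None:
    # Group per-layer frames by table name ("eval", "lr_eval", ...).
    buffers: dict[str, list[pd.DataFrame]] = defaultdict(list)
    for table_dict in results:
        for name, df in table_dict.items():
            buffers[name].append(df)

    # One CSV per table, rows sorted by layer as in the run loop above.
    for name, dfs in buffers.items():
        pd.concat(dfs).sort_values(by="layer").round(4).to_csv(
            out_dir / f"{name}.csv", index=False
        )
```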
elk/training/classifier.py (60 changes: 58 additions & 2 deletions)

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import torch
 from torch import Tensor
@@ -10,6 +10,14 @@
 )
 
 
+@dataclass
+class InlpResult:
+    """Result of Iterative Nullspace Projection (NLP)."""
+
+    losses: list[float] = field(default_factory=list)
+    classifiers: list["Classifier"] = field(default_factory=list)
+
+
 @dataclass
 class RegularizationPath:
     """Result of cross-validation."""
@@ -175,10 +183,58 @@ def fit_cv(
         self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter)
         return RegularizationPath(l2_penalties, mean_losses.tolist())
 
+    @classmethod
+    def inlp(
+        cls, x: Tensor, y: Tensor, max_iter: int | None = None, tol: float = 0.01
+    ) -> InlpResult:
+        """Iterative Nullspace Projection (INLP) <https://arxiv.org/abs/2004.07667>.
+
+        Args:
+            x: Input tensor of shape (N, D), where N is the number of samples and D is
+                the input dimension.
+            y: Target tensor of shape (N,) for binary classification or (N, C) for
+                multiclass classification, where C is the number of classes.
+            max_iter: Maximum number of iterations to run. If `None`, run for the full
+                dimension of the input.
+            tol: Tolerance for the loss function. The algorithm will stop when the loss
+                is within `tol` of the entropy of the labels.
+
+        Returns:
+            `InlpResult` containing the classifiers and losses achieved at each
+            iteration.
+        """
+
+        y.shape[-1] if y.ndim > 1 else 2
+        d = x.shape[-1]
+        loss = 0.0
+
+        # Compute entropy of the labels
+        p = y.float().mean()
+        H = -p * torch.log(p) - (1 - p) * torch.log(1 - p)
+
+        if max_iter is not None:
+            d = min(d, max_iter)
+
+        # Iterate until the loss is within epsilon of the entropy
+        result = InlpResult()
+        for _ in range(d):
+            clf = cls(d, device=x.device, dtype=x.dtype)
+            loss = clf.fit(x, y)
+            result.classifiers.append(clf)
+            result.losses.append(loss)
+
+            if loss >= (1.0 - tol) * H:
+                break
+
+            # Project the data onto the nullspace of the classifier
+            x = clf.nullspace_project(x)
+
+        return result
+
     def nullspace_project(self, x: Tensor) -> Tensor:
         """Project the given data onto the nullspace of the classifier."""
 
         # https://en.wikipedia.org/wiki/Projection_(linear_algebra)
         A = self.linear.weight.data.T
         P = A @ torch.linalg.solve(A.mT @ A, A.mT)
-        return x - P @ x
+        return x - x @ P
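
Two things worth spelling out here. First, the `inlp` loop is the standard INLP recipe: fit a linear probe, project its direction out of the data, and repeat until the best achievable loss is within `tol` of the label entropy `H`, i.e. until the probe does no better than chance. Second, the `return x - x @ P` fix matters because `x` stores samples as rows of shape (N, D) while `P` is the (D, D) projection onto the span of the probe weights, so the projection has to act on the right; `P @ x` only type-checks when N == D and projects the wrong axis. A quick standalone check of the algebra, not the library's API:

```python
import torch

torch.manual_seed(0)
N, D = 8, 5
x = torch.randn(N, D)   # N samples, D features
w = torch.randn(D, 1)   # stand-in for self.linear.weight.data.T

# Orthogonal projection onto span(w): P = A (A^T A)^{-1} A^T, with A = w.
P = w @ torch.linalg.solve(w.mT @ w, w.mT)  # (D, D), symmetric, idempotent

x_null = x - x @ P  # remove each row's component along w

# Every projected sample is now orthogonal to w, so a probe with weight w
# produces (near-)zero logits: that direction is "used up" for the next
# INLP iteration.
print(torch.allclose(x_null @ w, torch.zeros(N, 1), atol=1e-5))  # True
```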
elk/training/supervised.py (19 changes: 13 additions & 6 deletions)

@@ -5,7 +5,9 @@
 from .classifier import Classifier
 
 
-def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier:
+def train_supervised(
+    data: dict[str, tuple], device: str, mode: str
+) -> list[Classifier]:
     Xs, train_labels = [], []
 
     for train_h, labels, _ in data.values():
@@ -19,10 +21,15 @@ def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier:
         train_labels.append(labels)
 
     X, train_labels = torch.cat(Xs), torch.cat(train_labels)
-    lr_model = Classifier(X.shape[-1], device=device)
-    if cv:
+    if mode == "cv":
+        lr_model = Classifier(X.shape[-1], device=device)
         lr_model.fit_cv(X, train_labels)
-    else:
+        return [lr_model]
+    elif mode == "inlp":
+        return Classifier.inlp(X, train_labels).classifiers
+    elif mode == "single":
+        lr_model = Classifier(X.shape[-1], device=device)
         lr_model.fit(X, train_labels)
-
-    return lr_model
+        return [lr_model]
+    else:
+        raise ValueError(f"Unknown mode: {mode}")
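
`train_supervised` now returns a list in every mode, which is what lets the callers in evaluate.py and train.py loop over classifiers uniformly: "single" and "cv" yield a one-element list, while "inlp" yields one classifier per INLP iteration. A hypothetical call site, assuming the surrounding names from the diffs above (`train_dict`, `val_gt`, `val_h`, `evaluate_preds`) are in scope:

```python
# train_dict maps dataset name -> (hiddens, labels, lm_preds), as in train.py.
lr_models = train_supervised(train_dict, device="cuda:0", mode="inlp")

for i, model in enumerate(lr_models):
    # i is the INLP iteration this probe was fit on; later iterations were
    # trained on data with the earlier probe directions projected out.
    lr_result = evaluate_preds(val_gt, model(val_h))
```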
elk/training/train.py (48 changes: 27 additions & 21 deletions)

@@ -1,5 +1,6 @@
 """Main training loop."""
 
+from collections import defaultdict
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Literal
@@ -27,7 +28,7 @@ class Elicit(Run):
     )
     """Config for building the reporter network."""
 
-    supervised: Literal["none", "single", "cv"] = "single"
+    supervised: Literal["none", "single", "inlp", "cv"] = "single"
     """Whether to train a supervised classifier, and if so, whether to use
     cross-validation. Defaults to "single", which means to train a single classifier
     on the training data. "cv" means to use cross-validation."""
@@ -47,7 +48,7 @@ def apply_to_layer(
         layer: int,
         devices: list[str],
         world_size: int,
-    ) -> pd.DataFrame:
+    ) -> dict[str, pd.DataFrame]:
         """Train a single reporter on a single layer."""
 
         self.make_reproducible(seed=self.net.seed + layer)
@@ -109,33 +110,38 @@ def apply_to_layer(
 
         # Fit supervised logistic regression model
         if self.supervised != "none":
-            lr_model = train_supervised(
-                train_dict, device=device, cv=self.supervised == "cv"
+            lr_models = train_supervised(
+                train_dict,
+                device=device,
+                mode=self.supervised,
             )
             with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
-                torch.save(lr_model, file)
+                torch.save(lr_models, file)
         else:
-            lr_model = None
+            lr_models = []
 
-        row_buf = []
+        row_bufs = defaultdict(list)
         for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
+            meta = {"dataset": ds_name, "layer": layer}
+
             val_result = evaluate_preds(val_gt, reporter(val_h))
-            row = {
-                "dataset": ds_name,
-                "layer": layer,
-                "pseudo_auroc": pseudo_auroc,
-                "train_loss": train_loss,
-                **val_result.to_dict(),
-            }
+            row_bufs["eval"].append(
+                {
+                    **meta,
+                    "pseudo_auroc": pseudo_auroc,
+                    "train_loss": train_loss,
+                    **val_result.to_dict(),
+                }
+            )
 
             if val_lm_preds is not None:
                 lm_result = evaluate_preds(val_gt, val_lm_preds)
-                row.update(lm_result.to_dict(prefix="lm_"))
+                row_bufs["lm_eval"].append({**meta, **lm_result.to_dict()})
 
-            if lr_model is not None:
-                lr_result = evaluate_preds(val_gt, lr_model(val_h))
-                row.update(lr_result.to_dict(prefix="lr_"))
-
-            row_buf.append(row)
+            for i, model in enumerate(lr_models):
+                lr_result = evaluate_preds(val_gt, model(val_h))
+                row_bufs["lr_eval"].append(
+                    {"inlp_iter": i, **meta, **lr_result.to_dict()}
+                )
 
-        return pd.DataFrame.from_records(row_buf)
+        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
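
Net effect on the output directory: instead of `lm_`/`lr_`-prefixed columns inside a single eval.csv, an elicit run now writes up to three tables (eval.csv, lm_eval.csv, lr_eval.csv) that share the `dataset` and `layer` key columns. A downstream consumer that wants the old wide view back can recover it with a join; a sketch, assuming a finished run directory and illustrative column names:

```python
from pathlib import Path

import pandas as pd

out_dir = Path("my-elicit-run")  # hypothetical output directory

eval_df = pd.read_csv(out_dir / "eval.csv")
lr_df = pd.read_csv(out_dir / "lr_eval.csv")

# lr_eval.csv has one row per INLP iteration; keep the last iteration per
# (dataset, layer) pair, then join it back onto the main table.
last = lr_df.loc[lr_df.groupby(["dataset", "layer"])["inlp_iter"].idxmax()]
wide = eval_df.merge(last, on=["dataset", "layer"], suffixes=("", "_lr"))
```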