Binarized meta-templates; some extraction refactoring #218

Closed · wants to merge 24 commits

Changes from 1 commit (of 24)
bbee489
Initial support for FEVER
norabelrose Apr 22, 2023
5ba1ddd
Start saving and fitting a reporter to the input embeddings
norabelrose Apr 22, 2023
3b1f74d
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
51ba54f
Rename layer 0 to 'input' to make it more clear
norabelrose Apr 22, 2023
544b485
Actually rename layer 0 correctly
norabelrose Apr 22, 2023
43da44e
Handle layer_stride correctly
norabelrose Apr 22, 2023
9056e00
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
756fa53
label_choices
norabelrose Apr 22, 2023
93b7ae0
Clean up train and eval commands; do transfer in sweep
norabelrose Apr 22, 2023
57d0b8b
Support INLP and split eval output into multiple CSVs
norabelrose Apr 22, 2023
228a6a0
Merge branch 'inlp' into template-filtering
norabelrose Apr 22, 2023
b086f0b
Merge branch 'inlp' into template-filtering
norabelrose Apr 25, 2023
934cd54
Log ensembled metrics
norabelrose Apr 26, 2023
dff69bf
Fixing pyright version
norabelrose Apr 26, 2023
b181d3e
Merge remote-tracking branch 'origin/main' into ensembling
norabelrose Apr 26, 2023
15254bf
Merge main
norabelrose Apr 26, 2023
69c2d55
Tons of stuff, preparing for sciq_binary experiment
norabelrose Apr 27, 2023
960ff01
Support --binarize again
norabelrose Apr 27, 2023
c9e62ea
Partial support for truthful_qa
norabelrose Apr 27, 2023
eb71a6c
Merge branch 'main' into template-filtering
norabelrose Apr 29, 2023
88bb15e
Merge remote-tracking branch 'origin/main' into template-filtering
norabelrose Apr 29, 2023
c648ff0
Remove crap
norabelrose Apr 29, 2023
ef12130
EleutherAI/truthful_qa_mc
norabelrose Apr 29, 2023
5d60ebd
Update templates
norabelrose Apr 30, 2023
Clean up train and eval commands; do transfer in sweep
norabelrose committed Apr 22, 2023
commit 93b7ae0d4a1a0eee1448d0a84ef908e455663a9f
70 changes: 12 additions & 58 deletions elk/evaluation/evaluate.py
@@ -1,70 +1,36 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable

import pandas as pd
import torch
from simple_parsing.helpers import Serializable, field
from simple_parsing.helpers import field

from ..extraction.extraction import Extract
from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..run import Run
from ..training import Reporter
from ..utils import select_usable_devices


@dataclass
class Eval(Serializable):
"""
Full specification of a reporter evaluation run.

Args:
data: Config specifying hidden states on which the reporter will be evaluated.
source: The name of the source run directory
which contains the reporters directory.
normalization: The normalization method to use. Defaults to "meanonly". See
`elk.training.preprocessing.normalize()` for details.
num_gpus: The number of GPUs to use. Defaults to -1, which means
"use all available GPUs".
skip_supervised: Whether to skip evaluation of the supervised classifier.
debug: When in debug mode, a useful log file is saved to the memorably-named
output directory. Defaults to False.
"""

data: Extract
source: str = field(positional=True)

concatenated_layer_offset: int = 0
debug: bool = False
min_gpu_mem: int | None = None
num_gpus: int = -1
out_dir: Path | None = None
class Eval(Run):
"""Full specification of a reporter evaluation run."""

source: str = field(default="", positional=True)
skip_supervised: bool = False

disable_cache: bool = field(default=False, to_dict=False)
def __post_init__(self):
assert self.source, "Must specify a source experiment."

def execute(self):
transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"
self.out_dir = transfer_dir / "+".join(self.data.prompts.datasets)

for dataset in self.data.prompts.datasets:
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()


@dataclass
class Evaluate(Run):
cfg: Eval

def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> pd.DataFrame:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)
val_output = self.prepare_data(device, layer, "val")

experiment_dir = elk_reporter_dir() / self.cfg.source
experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
@@ -81,7 +47,7 @@ def evaluate_reporter(
}

lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_supervised and lr_dir.exists():
if not self.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

@@ -91,15 +57,3 @@ def evaluate_reporter(
row_buf.append(stats_row)

return pd.DataFrame.from_records(row_buf)

def evaluate(self):
"""Evaluate the reporter on all layers."""
devices = select_usable_devices(
self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem
)

num_devices = len(devices)
func: Callable[[int], pd.DataFrame] = partial(
self.evaluate_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)
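
This commit folds the old Evaluate wrapper into Eval itself: Eval now subclasses Run, the per-layer work moves into apply_to_layer, and device selection plus per-layer dispatch live in the shared Run.execute(). A minimal sketch of driving the refactored class, using only the fields visible in this diff (the Extract construction is elided because its constructor is not shown here, and the source run name is hypothetical):

    # Sketch only, not taken from this PR: how the refactored Eval might be invoked.
    from elk.evaluation.evaluate import Eval
    from elk.extraction import Extract

    extract_cfg = Extract(...)  # placeholder: model / prompt config fields not shown in this diff
    eval_run = Eval(data=extract_cfg, source="my-train-run", skip_supervised=False)
    eval_run.execute()  # inherited from Run: extract hiddens, then call apply_to_layer per layer
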
9 changes: 6 additions & 3 deletions elk/extraction/extraction.py
@@ -28,6 +28,7 @@
from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
colorize,
float32_to_int16,
infer_label_column,
infer_num_classes,
@@ -271,6 +272,7 @@ def extract(
cfg: "Extract",
*,
disable_cache: bool = False,
highlight_color: str = "cyan",
num_gpus: int = -1,
min_gpu_mem: int | None = None,
) -> DatasetDict:
@@ -279,10 +281,11 @@
def get_splits() -> SplitDict:
available_splits = assert_type(SplitDict, info.splits)
train_name, val_name = select_train_val_splits(available_splits)

pretty_name = colorize(assert_type(str, info.builder_name), highlight_color)
print(
# Cyan color for dataset name
f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and"
f" '{val_name}' for validation"
f"{pretty_name}: using '{train_name}' for training "
f"and '{val_name}' for validation"
)
limit_list = cfg.prompts.max_examples

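
Here the hard-coded cyan ANSI escape around the dataset name is replaced by a call to a shared colorize helper (imported from ..utils) with a configurable highlight_color. The helper's implementation is not part of this diff; a minimal ANSI sketch consistent with the escape sequence it replaces could look like this (the color table is an assumption, not elk's actual code):

    # Sketch of a colorize helper equivalent to the hard-coded "\033[36m...\033[0m" it replaces.
    _ANSI_CODES = {"red": 31, "green": 32, "yellow": 33, "blue": 34, "magenta": 35, "cyan": 36}

    def colorize(text: str, color: str = "cyan") -> str:
        """Wrap `text` in the ANSI escape codes for `color`."""
        return f"\033[{_ANSI_CODES[color]}m{text}\033[0m"

    print(colorize("super_glue", "cyan"))  # prints the dataset name in cyan
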
71 changes: 47 additions & 24 deletions elk/run.py
@@ -1,57 +1,68 @@
import os
import random
from abc import ABC
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Literal, Union
from typing import Callable, Literal

import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
import yaml
from datasets import DatasetDict
from simple_parsing.helpers import Serializable, field
from torch import Tensor
from tqdm import tqdm

from .debug_logging import save_debug_log
from .extraction import extract
from .extraction import Extract, extract
from .files import elk_reporter_dir, memorably_named_dir
from .utils import (
assert_type,
get_dataset_name,
get_layers,
int16_to_float32,
select_train_val_splits,
select_usable_devices,
)

if TYPE_CHECKING:
from .evaluation.evaluate import Eval
from .training.train import Elicit


@dataclass
class Run(ABC):
cfg: Union["Elicit", "Eval"]
class Run(ABC, Serializable):
data: Extract
out_dir: Path | None = None
"""Directory to save results to. If None, a directory will be created
automatically."""

datasets: list[DatasetDict] = field(default_factory=list, init=False)
"""Datasets containing hidden states and labels for each layer."""

concatenated_layer_offset: int = 0
debug: bool = False
min_gpu_mem: int | None = None
num_gpus: int = -1
out_dir: Path | None = None
datasets: list[DatasetDict] = field(init=False)
disable_cache: bool = field(default=False, to_dict=False)

def __post_init__(self):
def execute(self, highlight_color: str = "cyan"):
self.datasets = [
extract(
cfg,
disable_cache=self.cfg.disable_cache,
num_gpus=self.cfg.num_gpus,
min_gpu_mem=self.cfg.min_gpu_mem,
disable_cache=self.disable_cache,
highlight_color=highlight_color,
num_gpus=self.num_gpus,
min_gpu_mem=self.min_gpu_mem,
)
for cfg in self.cfg.data.explode()
for cfg in self.data.explode()
]

if self.out_dir is None:
# Save in a memorably-named directory inside of
# ELK_REPORTER_DIR/<model_name>/<dataset_name>
ds_name = ", ".join(self.cfg.data.prompts.datasets)
root = elk_reporter_dir() / self.cfg.data.model / ds_name
ds_name = ", ".join(self.data.prompts.datasets)
root = elk_reporter_dir() / self.data.model / ds_name

self.out_dir = memorably_named_dir(root)

@@ -61,7 +72,7 @@ def __post_init__(self):

path = self.out_dir / "cfg.yaml"
with open(path, "w") as f:
self.cfg.dump_yaml(f)
self.dump_yaml(f)

path = self.out_dir / "fingerprints.yaml"
with open(path, "w") as meta_f:
@@ -75,6 +86,19 @@
meta_f,
)

devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
num_devices = len(devices)
func: Callable[[int], pd.DataFrame] = partial(
self.apply_to_layer, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)

@abstractmethod
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> pd.DataFrame:
"""Train or eval a reporter on a single layer."""

def make_reproducible(self, seed: int):
"""Make the run reproducible by setting the random seed."""

@@ -114,8 +138,8 @@ def prepare_data(

def concatenate(self, layers):
"""Concatenate hidden states from a previous layer."""
for layer in range(self.cfg.concatenated_layer_offset, len(layers)):
layers[layer] += [layers[layer][0] - self.cfg.concatenated_layer_offset]
for layer in range(self.concatenated_layer_offset, len(layers)):
layers[layer] += [layers[layer][0] - self.concatenated_layer_offset]

return layers

@@ -137,10 +161,9 @@ def apply_to_layers(
layers, *rest = [get_layers(ds) for ds in self.datasets]
assert all(x == layers for x in rest), "All datasets must have the same layers"

if self.cfg.concatenated_layer_offset > 0:
if self.concatenated_layer_offset > 0:
layers = self.concatenate(layers)

# Should we write to different CSV files for elicit vs eval?
ctx = mp.get_context("spawn")
with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f:
mapper = pool.imap_unordered if num_devices > 1 else map
@@ -154,5 +177,5 @@
if df_buf:
df = pd.concat(df_buf).sort_values(by="layer")
df.round(4).to_csv(f, index=False)
if self.cfg.debug:
if self.debug:
save_debug_log(self.datasets, self.out_dir)
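
The net effect in run.py is a template-method refactor: Run is now a Serializable dataclass that owns the shared config fields (num_gpus, debug, out_dir, ...), its execute() performs extraction, output-directory setup, and device selection, and then fans out to an abstract apply_to_layer() that Elicit and Eval each implement. A stripped-down, standalone illustration of that control flow (simplified stand-ins, not elk's actual classes):

    # Standalone sketch of the Run.execute() -> apply_to_layer() pattern adopted here.
    from abc import ABC, abstractmethod
    from dataclasses import dataclass
    from functools import partial

    import pandas as pd

    @dataclass
    class BaseRun(ABC):
        num_gpus: int = 1

        def execute(self) -> pd.DataFrame:
            devices = [f"cuda:{i}" for i in range(self.num_gpus)]
            func = partial(self.apply_to_layer, devices=devices, world_size=len(devices))
            # elk dispatches this over a multiprocessing pool; a plain loop keeps the sketch simple
            frames = [func(layer) for layer in range(3)]
            return pd.concat(frames).sort_values(by="layer")

        @abstractmethod
        def apply_to_layer(self, layer: int, devices: list[str], world_size: int) -> pd.DataFrame:
            """Train or evaluate a reporter on a single layer."""

    @dataclass
    class ToyEval(BaseRun):
        def apply_to_layer(self, layer: int, devices: list[str], world_size: int) -> pd.DataFrame:
            return pd.DataFrame.from_records([{"layer": layer, "acc": 0.5}])

    print(ToyEval(num_gpus=1).execute())
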
3 changes: 1 addition & 2 deletions elk/training/__init__.py
@@ -2,7 +2,7 @@
from .classifier import Classifier
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .normalizer import Normalizer
from .reporter import OptimConfig, Reporter, ReporterConfig
from .reporter import Reporter, ReporterConfig

__all__ = [
"CcsReporter",
@@ -11,7 +11,6 @@
"EigenReporter",
"EigenReporterConfig",
"Normalizer",
"OptimConfig",
"Reporter",
"ReporterConfig",
]
2 changes: 1 addition & 1 deletion elk/training/eigen_reporter.py
@@ -213,7 +213,7 @@ def fit_streaming(self, truncated: bool = False) -> float:
)

if truncated:
L, Q = truncated_eigh(A, k=self.config.num_heads)
L, Q = truncated_eigh(A, k=self.config.num_heads, seed=self.config.seed)
else:
try:
L, Q = torch.linalg.eigh(A)
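
The only change here threads the reporter config's seed into truncated_eigh. Truncated eigensolvers typically start from a random initial subspace, so without a fixed seed repeated fits on the same data can return slightly different top-k eigenvectors. truncated_eigh's internals are not shown in this diff; a toy illustration of why the seed matters, using power iteration for the leading eigenvector, might look like:

    # Toy power iteration: the random start vector is drawn from a seeded generator,
    # so repeated calls on the same matrix give identical results. Illustration only,
    # not elk's truncated_eigh.
    import torch

    def leading_eigvec(A: torch.Tensor, iters: int = 100, seed: int = 42) -> torch.Tensor:
        gen = torch.Generator().manual_seed(seed)
        v = torch.randn(A.shape[0], generator=gen)
        for _ in range(iters):
            v = A @ v
            v = v / v.norm()
        return v

    A = torch.tensor([[2.0, 1.0], [1.0, 3.0]])
    assert torch.allclose(leading_eigvec(A), leading_eigvec(A))  # reproducible across calls
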
21 changes: 1 addition & 20 deletions elk/training/reporter.py
@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional
from typing import Optional

import torch
import torch.nn as nn
@@ -21,25 +21,6 @@ class ReporterConfig(Serializable):
seed: int = 42


@dataclass
class OptimConfig(Serializable):
"""
Args:
lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`.
Defaults to 1e-2.
num_epochs: The number of epochs to train for. Defaults to 1000.
num_tries: The number of times to try training the reporter. Defaults to 10.
optimizer: The optimizer to use. Defaults to "adam".
weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01.
"""

lr: float = 1e-2
num_epochs: int = 1000
num_tries: int = 10
optimizer: Literal["adam", "lbfgs"] = "lbfgs"
weight_decay: float = 0.01


class Reporter(nn.Module, ABC):
"""An ELK reporter network."""
