Multi datasets #123

Merged 46 commits into main from multi-datasets on Mar 28, 2023.

Changes from 1 commit (of 46 total).

Commits
681698d
add multiple datasets support
Benw8888 Mar 9, 2023
ac1b9f1
Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets
Benw8888 Mar 9, 2023
b864c77
train_reporter works on a list of layers now
Benw8888 Mar 10, 2023
7d7d97c
changing printed layer names
Benw8888 Mar 10, 2023
4fe61e9
fixed concatenation bug
Benw8888 Mar 11, 2023
fe61d67
minor edits
Benw8888 Mar 13, 2023
74da878
fixed pyright issues
Benw8888 Mar 13, 2023
569ef05
Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets
Benw8888 Mar 14, 2023
b62b679
Merge branch 'main' into multi-datasets
norabelrose Mar 20, 2023
fe94c22
Fix tests
norabelrose Mar 20, 2023
bba24d8
Now working sorta
norabelrose Mar 22, 2023
03ba6e0
Skip slow BalancedBatchSampler test
norabelrose Mar 22, 2023
15ab351
Slightly relax test_output_is_roughly_balanced
norabelrose Mar 22, 2023
a80369e
Make BalancedSampler deterministic
norabelrose Mar 22, 2023
d304ab3
InitVar
norabelrose Mar 22, 2023
761c82d
Support multi class again
norabelrose Mar 22, 2023
f29743b
Fix naming issue
norabelrose Mar 22, 2023
b7b7e23
Support few shot prompts
norabelrose Mar 23, 2023
1afb563
Merge branch 'main' into multi-datasets
norabelrose Mar 23, 2023
225d4c7
fix multiclass labels
AlexTMallen Mar 23, 2023
9368dc8
Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi…
AlexTMallen Mar 23, 2023
a858b65
Merge branch 'main' into multi-datasets
norabelrose Mar 24, 2023
5dc2ec6
Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi…
norabelrose Mar 24, 2023
b1b95e5
Fix dumb part of test failures
norabelrose Mar 25, 2023
ee3911e
Fix assert_allclose warning
norabelrose Mar 25, 2023
a55b3de
Switch to torch.testing.assert_close in EigenReporter test
norabelrose Mar 25, 2023
44dc25c
Shuffle load_prompts output by default
norabelrose Mar 25, 2023
93d8d87
Fix smoke test failure
norabelrose Mar 25, 2023
fad4d74
Remove debug prints
AlexTMallen Mar 25, 2023
0a054f4
Remove more debug print statements
AlexTMallen Mar 25, 2023
177eec2
make min_memory usable; broadcast mmax_examples in __post_init__
AlexTMallen Mar 26, 2023
3a762b0
prompt loading refactor to enable better streaming
AlexTMallen Mar 26, 2023
f66c054
remove shuffle arg
AlexTMallen Mar 26, 2023
d3d87fc
remove unused @dataclass
lauritowal Mar 26, 2023
3d08147
merge
lauritowal Mar 27, 2023
c9a43e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 27, 2023
94290aa
add concatenated_layer_offset to eval
lauritowal Mar 27, 2023
f9298e4
Merge branch 'multi-datasets' of https://github.com/EleutherAI/elk in…
lauritowal Mar 27, 2023
3765c4f
add self.
lauritowal Mar 27, 2023
2b05193
replace target with data
lauritowal Mar 27, 2023
83731bb
add self.
lauritowal Mar 27, 2023
764fda9
remove second arg
lauritowal Mar 27, 2023
d2c66b0
fix passing the wrong params for world size / rank
thejaminator Mar 28, 2023
9186326
Update prompt_loading.py
lauritowal Mar 28, 2023
3f99a4d
fix pre-commit errors
lauritowal Mar 28, 2023
148130d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2023
Viewing changes from commit 3d08147835eb22ca913cdc8370c7128cd08604a1 ("merge"), committed by lauritowal on Mar 27, 2023.
2 changes: 1 addition & 1 deletion elk/__init__.py
@@ -1 +1 @@
-from .extraction import extract_hiddens, ExtractionConfig
+from .extraction import extract_hiddens, Extract
71 changes: 18 additions & 53 deletions elk/__main__.py
@@ -1,67 +1,32 @@
"""Main entry point for `elk`."""

from .extraction import extract, ExtractionConfig
from .evaluation.evaluate import EvaluateConfig, evaluate_reporters
from .training import RunConfig
from .training.train import train
from dataclasses import dataclass
from pathlib import Path
from typing import Union

from simple_parsing import ArgumentParser

from elk.evaluation.evaluate import Eval
from elk.extraction.extraction import Extract
from elk.training.train import Elicit

def run():
parser = ArgumentParser(add_help=False)
subparsers = parser.add_subparsers(dest="command", required=True)

extract_parser = subparsers.add_parser(
"extract", help="Extract hidden states from a model."
)
extract_parser.add_arguments(ExtractionConfig, dest="extraction")
extract_parser.add_argument(
"--output",
"-o",
type=Path,
help="Path to save hidden states to.",
required=True,
)
extract_parser.add_argument(
"--num_gpus",
type=int,
help="Maximum number of GPUs to use.",
required=False,
default=-1,
)
@dataclass
class Command:
"""Some top-level command"""

elicit_parser = subparsers.add_parser(
"elicit",
help=(
"Extract and train a set of ELK reporters "
"on hidden states from `elk extract`. "
),
conflict_handler="resolve",
)
elicit_parser.add_arguments(RunConfig, dest="run")
elicit_parser.add_argument(
"--output",
"-o",
type=Path,
help="Path to save checkpoints to.",
)
command: Union[Elicit, Eval, Extract]

subparsers.add_parser(
"eval",
help="Evaluate a set of ELK reporters generated by `elk train`.",
).add_arguments(EvaluateConfig, dest="eval")
def execute(self):
return self.command.execute()

args = parser.parse_args()

if args.command == "extract":
extract(args.extraction, args.num_gpus).save_to_disk(args.output)
elif args.command == "elicit":
train(args.run, args.output)
elif args.command == "eval":
evaluate_reporters(args.eval)
else:
raise ValueError(f"Unknown command {args.command}")
def run():
parser = ArgumentParser(add_help=False)
parser.add_arguments(Command, dest="run")
args = parser.parse_args()
run: Command = args.run
run.execute()


if __name__ == "__main__":
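The new __main__.py above leans on a simple_parsing feature: annotating a dataclass field with a Union of dataclasses (here `command: Union[Elicit, Eval, Extract]`) turns each member into a subcommand. Below is a minimal standalone sketch of that pattern; `Train` and `Test` and their options are hypothetical stand-ins for the elk command classes, not the real ones.

# Sketch of the Union-as-subcommands pattern used in the new __main__.py.
# Train/Test are hypothetical stand-ins for Elicit/Eval/Extract.
from dataclasses import dataclass
from typing import Union

from simple_parsing import ArgumentParser


@dataclass
class Train:
    lr: float = 1e-3  # hypothetical option

    def execute(self):
        print(f"training with lr={self.lr}")


@dataclass
class Test:
    ckpt: str = "best.pt"  # hypothetical option

    def execute(self):
        print(f"evaluating checkpoint {self.ckpt}")


@dataclass
class Command:
    """Some top-level command."""

    # simple_parsing names the subcommands after the member classes,
    # e.g. `python cli.py train --lr 0.01` parses into a Train instance.
    command: Union[Train, Test]

    def execute(self):
        return self.command.execute()


if __name__ == "__main__":
    parser = ArgumentParser(add_help=False)
    parser.add_arguments(Command, dest="run")
    args = parser.parse_args()
    args.run.execute()

Dispatch then reduces to `args.run.execute()`, replacing the old if/elif chain on `args.command`.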
155 changes: 75 additions & 80 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
@@ -6,13 +6,13 @@
from simple_parsing.helpers import Serializable, field
from torch import Tensor
from tqdm.auto import tqdm
from typing import Literal, Optional, cast
from typing import Literal, Optional, cast, Callable
import csv
import os
import torch
import torch.multiprocessing as mp

from ..extraction import ExtractionConfig, extract
from elk.extraction.extraction import Extract
from ..files import elk_reporter_dir, memorably_named_dir
from ..utils import (
assert_type,
@@ -21,96 +21,91 @@
select_usable_devices,
)

import torch
from simple_parsing import Serializable, field

from elk.files import elk_reporter_dir
from elk.run import Run
from elk.training import Reporter
from elk.evaluation.evaluate_log import EvalLog
from elk.utils import select_usable_devices


@dataclass
class EvaluateConfig(Serializable):
target: ExtractionConfig
class Eval(Serializable):
"""
Full specification of a reporter evaluation run.

Args:
data: Config specifying hidden states on which the reporter will be evaluated.
source: The name of the source run directory
which contains the reporters directory.
normalization: The normalization method to use. Defaults to "meanonly". See
`elk.training.preprocessing.normalize()` for details.
num_gpus: The number of GPUs to use. Defaults to -1, which means
"use all available GPUs".
debug: When in debug mode, a useful log file is saved to the memorably-named
output directory. Defaults to False.
"""

data: Extract
source: str = field(positional=True)
normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
num_gpus: int = -1


def evaluate_reporter(
cfg: EvaluateConfig,
dataset: DatasetDict,
layer: int,
devices: list[str],
world_size: int = 1,
):
"""Evaluate a single reporter on a single layer."""
rank = os.getpid() % world_size
device = devices[rank]

# Note: currently we're just upcasting to float32 so we don't have to deal with
# grad scaling (which isn't supported for LBFGS), while the hidden states are
# saved in float16 to save disk space. In the future we could try to use mixed
# precision training in at least some cases.
with dataset.formatted_as("torch", device=device, dtype=torch.int16):
train_split, val_split = select_train_val_splits(dataset)
train, val = dataset[train_split], dataset[val_split]
test_labels = cast(Tensor, val["label"])

_, test_h = normalize(
int16_to_float32(assert_type(Tensor, train[f"hidden_{layer}"])),
int16_to_float32(assert_type(Tensor, val[f"hidden_{layer}"])),
method=cfg.normalization,
)

reporter_path = elk_reporter_dir() / cfg.source / "reporters" / f"layer_{layer}.pt"
reporter = torch.load(reporter_path, map_location=device)
reporter.eval()

test_x0, test_x1 = test_h.unbind(dim=-2)

test_result = reporter.score(test_labels, test_x0, test_x1)

stats = [layer, *test_result]
return stats
debug: bool = False
out_dir: Optional[Path] = None
num_gpus: int = -1

def execute(self):
transfer_eval = elk_reporter_dir() / self.source / "transfer_eval"

def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
ds = extract(cfg.target, num_gpus=cfg.num_gpus)
run = Evaluate(cfg=self, out_dir=transfer_eval)
run.evaluate()

layers = [
int(feat[len("hidden_") :])
for feat in ds["train"].features
if feat.startswith("hidden_")
]

devices = select_usable_devices(cfg.num_gpus, min_memory=cfg.target.min_gpu_mem)
num_devices = len(devices)
@dataclass
class Evaluate(Run):
cfg: Eval

def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
) -> EvalLog:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)

_, _, test_x0, test_x1, _, test_labels = self.prepare_data(
device,
layer,
)

transfer_eval = elk_reporter_dir() / cfg.source / "transfer_eval"
transfer_eval.mkdir(parents=True, exist_ok=True)
reporter_path = (
elk_reporter_dir() / self.cfg.source / "reporters" / f"layer_{layer}.pt"
)
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter.eval()

if out_dir is None:
out_dir = memorably_named_dir(transfer_eval)
else:
out_dir.mkdir(parents=True, exist_ok=True)
test_result = reporter.score(
test_labels,
test_x0,
test_x1,
)

# Print the output directory in bold with escape codes
print(f"Saving results to \033[1m{out_dir}\033[0m")
return EvalLog(
layer=layer,
eval_result=test_result,
)

with open(out_dir / "cfg.yaml", "w") as f:
cfg.dump_yaml(f)
def evaluate(self):
"""Evaluate the reporter on all layers."""
devices = select_usable_devices(self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem)

cols = ["layer", "loss", "acc", "cal_acc", "auroc"]
# Evaluate reporters for each layer in parallel
with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
fn = partial(
evaluate_reporter, cfg, ds, devices=devices, world_size=num_devices
num_devices = len(devices)
func: Callable[[int], EvalLog] = partial(
self.evaluate_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(
func=func,
num_devices=num_devices,
to_csv_line=lambda item: item.to_csv_line(),
csv_columns=EvalLog.csv_columns(),
)
writer = csv.writer(f)
writer.writerow(cols)

mapper = pool.imap_unordered if num_devices > 1 else map
row_buf = []
try:
for i, *stats in tqdm(mapper(fn, layers), total=len(layers)):
row_buf.append([i] + [f"{s:.4f}" for s in stats])
finally:
# Make sure the CSV is written even if we crash or get interrupted
for row in sorted(row_buf):
writer.writerow(row)

print("Results saved")
27 changes: 27 additions & 0 deletions elk/evaluation/evaluate_log.py
@@ -0,0 +1,27 @@
from dataclasses import dataclass

from elk.training.reporter import EvalResult


@dataclass
class EvalLog:
"""The result of running eval on a layer of a dataset"""

layer: int
eval_result: EvalResult

@staticmethod
def csv_columns() -> list[str]:
return ["layer", "acc", "cal_acc", "auroc", "ece"]

def to_csv_line(self) -> list[str]:
items = [
self.layer,
self.eval_result.acc,
self.eval_result.cal_acc,
self.eval_result.auroc,
self.eval_result.ece,
]
return [
f"{item:.4f}" if isinstance(item, float) else str(item) for item in items
]
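For reference, a quick round-trip of the CSV helpers above, treating the EvalLog dataclass as importable. The `EvalResult` here is a stand-in NamedTuple so the example is self-contained; in elk it comes from elk.training.reporter, and its exact field set is an assumption based on `to_csv_line`.

# Stand-in EvalResult (assumption: four float fields matching to_csv_line).
from typing import NamedTuple


class EvalResult(NamedTuple):
    acc: float
    cal_acc: float
    auroc: float
    ece: float


log = EvalLog(layer=5, eval_result=EvalResult(0.8123, 0.8301, 0.9012, 0.0412))
print(EvalLog.csv_columns())  # ['layer', 'acc', 'cal_acc', 'auroc', 'ece']
print(log.to_csv_line())      # ['5', '0.8123', '0.8301', '0.9012', '0.0412']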
4 changes: 2 additions & 2 deletions elk/extraction/__init__.py
@@ -1,4 +1,4 @@
 from .balanced_sampler import BalancedSampler, FewShotSampler
-from .extraction import ExtractionConfig, extract_hiddens, extract
+from .extraction import Extract, extract_hiddens, extract
 from .generator import _GeneratorConfig, _GeneratorBuilder
-from .prompt_loading import PromptConfig, load_prompts
\ No newline at end of file
+from .prompt_loading import PromptConfig, load_prompts
You are viewing a condensed version of this merge commit.