Commit

Add hyperparameter sweep to sweep.py; Fall back to eig when eigh fails (#235)

* reduced reporter filesize by 4x; still unsure why the pickle file stores 1 remaining cov matrix

* add save_reporter_stats CLA

* sweep hparam_step working

* fix out_dir naming, skipping supervised

* use regular eig as backup for eigh

* split hparam directories vertically

* remove duplicate line

* fix num_variants arg

* also pass num_shots into prompt_loader

* Remove superfluous assert

---------

Co-authored-by: Nora Belrose <[email protected]>
AlexTMallen and norabelrose committed May 5, 2023
1 parent 8fff559 commit 2509092
Showing 3 changed files with 79 additions and 36 deletions.
.pre-commit-config.yaml: file mode changed 100644 → 100755 (no content changes)
24 changes: 14 additions & 10 deletions elk/training/eigen_reporter.py
@@ -267,16 +267,20 @@ def fit_streaming(self, truncated: bool = False) -> float:
         else:
             try:
                 L, Q = torch.linalg.eigh(A)
-            except torch.linalg.LinAlgError as e:
-                # Check if the matrix has non-finite values
-                if not A.isfinite().all():
-                    raise ValueError(
-                        "Fitting the reporter failed because the VINC matrix has "
-                        "non-finite entries. Usually this means the hidden states "
-                        "themselves had non-finite values."
-                    ) from e
-                else:
-                    raise e
+            except torch.linalg.LinAlgError:
+                try:
+                    L, Q = torch.linalg.eig(A)
+                    L, Q = L.real, Q.real
+                except torch.linalg.LinAlgError as e:
+                    # Check if the matrix has non-finite values
+                    if not A.isfinite().all():
+                        raise ValueError(
+                            "Fitting the reporter failed because the VINC matrix has "
+                            "non-finite entries. Usually this means the hidden states "
+                            "themselves had non-finite values."
+                        ) from e
+                    else:
+                        raise e
 
         L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :]
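A minimal standalone sketch of the eigh-to-eig fallback introduced above (robust_symmetric_eig is an illustrative name, not a function in this repository): torch.linalg.eigh assumes a symmetric (Hermitian) input and raises LinAlgError when it fails to converge, while torch.linalg.eig handles general square matrices but returns complex tensors, hence the .real.

import torch


def robust_symmetric_eig(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Eigendecompose a nominally symmetric matrix, preferring the fast eigh path."""
    try:
        # eigh: assumes symmetry, returns real eigenvalues sorted in ascending order.
        return torch.linalg.eigh(A)
    except torch.linalg.LinAlgError:
        # eig: general solver; returns complex tensors even for real inputs,
        # and does not guarantee any particular eigenvalue ordering.
        L, Q = torch.linalg.eig(A)
        return L.real, Q.real


L, Q = robust_symmetric_eig(torch.eye(4))
print(L)  # tensor([1., 1., 1., 1.])

One caveat: eigh sorts eigenvalues in ascending order while eig makes no ordering guarantee, which matters when the caller slices off the top num_heads eigenvectors as fit_streaming does.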
91 changes: 65 additions & 26 deletions elk/training/sweep.py
100644 → 100755
@@ -1,8 +1,12 @@
 from dataclasses import InitVar, dataclass, replace
 
+import numpy as np
+import torch
+
 from ..evaluation import Eval
 from ..extraction import Extract
 from ..files import elk_reporter_dir, memorably_named_dir
+from ..training.eigen_reporter import EigenReporterConfig
 from ..utils import colorize
 from .train import Elicit
 
@@ -19,6 +23,12 @@ class Sweep:
 
     add_pooled: InitVar[bool] = False
     """Whether to add a dataset that pools all of the other datasets together."""
+    hparam_step: float = -1.0
+    """The step size for hyperparameter sweeps. Performs a 2D sweep over a and b in
+    (var_weight, inv_weight, neg_cov_weight) = (a, 1 - b, b). If negative, no
+    hyperparameter sweeps will be performed. Only valid for Eigen."""
+    skip_transfer_eval: bool = False
+    """Whether to skip transfer eval between every pair of datasets."""
 
     name: str | None = None
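To make the hparam_step semantics above concrete, here is a hedged sketch of the grid it induces (hparam_grid is a hypothetical helper written for this note, not part of the commit):

import numpy as np


def hparam_grid(step: float) -> list[tuple[float | None, float | None]]:
    """Illustrative helper: the (var_weight, neg_cov_weight) pairs a sweep visits.

    A (None, None) entry means "train once with the config's default weights".
    """
    weights = np.arange(0.0, 1.0 + step, step) if step > 0 else [None]
    return [
        (a if a is None else float(a), b if b is None else float(b))
        for a in weights
        for b in weights
    ]


# hparam_step=0.25 gives weights 0.0, 0.25, 0.5, 0.75, 1.0 on each axis,
# i.e. a 5 x 5 = 25-point grid; inv_weight is implicitly 1 - neg_cov_weight.
assert len(hparam_grid(0.25)) == 25
assert hparam_grid(-1.0) == [(None, None)]

Because both endpoints are included, the number of training runs per model and dataset grows as roughly (1/step + 1) squared.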

@@ -35,6 +45,13 @@ def __post_init__(self, add_pooled: bool):
             raise ValueError("No datasets specified")
         if not self.models:
             raise ValueError("No models specified")
+        # can only use hparam_step if we're using an eigen net
+        if self.hparam_step > 0 and not isinstance(
+            self.run_template.net, EigenReporterConfig
+        ):
+            raise ValueError("Can only use hparam_step with EigenReporterConfig")
+        elif self.hparam_step > 1:
+            raise ValueError("hparam_step must be in [0, 1]")
 
         # Check for the magic dataset "burns" which is a shortcut for all of the
         # datasets used in Burns et al., except Story Cloze, which is not available
@@ -89,38 +106,60 @@ def execute(self):
                 }
             )
 
+        step = self.hparam_step
+        weights = np.arange(0.0, 1.0 + step, step) if step > 0 else [None]
+
         for i, model in enumerate(self.models):
             print(colorize(f"===== {model} ({i + 1} of {M}) =====", "magenta"))
 
             for dataset_str in self.datasets:
-                out_dir = sweep_dir / model / dataset_str
-
                 # Allow for multiple datasets to be specified in a single string with
                 # plus signs. This means we can pool datasets together inside of a
                 # single sweep.
                 train_datasets = tuple(ds.strip() for ds in dataset_str.split("+"))
 
-                data = replace(
-                    self.run_template.data, model=model, datasets=train_datasets
-                )
-                run = replace(self.run_template, data=data, out_dir=out_dir)
-                run.execute()
-
-                if len(eval_datasets) > 1:
-                    print(colorize("== Transfer eval ==", "green"))
-
-                # Now evaluate the reporter on the other datasets
-                for eval_dataset in eval_datasets:
-                    # We already evaluated on this one during training
-                    if eval_dataset in train_datasets:
-                        continue
-
-                    assert run.out_dir is not None
-                    eval = Eval(
-                        data=replace(run.data, model=model, datasets=(eval_dataset,)),
-                        source=run.out_dir,
-                        out_dir=out_dir / "transfer" / eval_dataset,
-                        num_gpus=run.num_gpus,
-                        min_gpu_mem=run.min_gpu_mem,
-                    )
-                    eval.execute(highlight_color="green")
+                for var_weight in weights:
+                    for neg_cov_weight in weights:
+                        out_dir = sweep_dir / model / dataset_str
+
+                        data = replace(
+                            self.run_template.data, model=model, datasets=train_datasets
+                        )
+                        run = replace(self.run_template, data=data, out_dir=out_dir)
+                        if var_weight is not None and neg_cov_weight is not None:
+                            assert isinstance(run.net, EigenReporterConfig)
+                            run.net.var_weight = var_weight
+                            run.net.neg_cov_weight = neg_cov_weight
+
+                            # Add hyperparameter values to output directory if needed
+                            out_dir /= f"var_weight={var_weight}"
+                            out_dir /= f"neg_cov_weight={neg_cov_weight}"
+
+                        try:
+                            run.execute()
+                        except torch._C._LinAlgError as e:  # type: ignore
+                            print(colorize(f"LinAlgError: {e}", "red"))
+                            continue
+
+                        if not self.skip_transfer_eval:
+                            if len(eval_datasets) > 1:
+                                print(colorize("== Transfer eval ==", "green"))
+
+                            # Now evaluate the reporter on the other datasets
+                            for eval_dataset in eval_datasets:
+                                # We already evaluated on this one during training
+                                if eval_dataset in train_datasets:
+                                    continue
+
+                                assert run.out_dir is not None
+                                eval = Eval(
+                                    data=replace(
+                                        run.data, model=model, datasets=(eval_dataset,)
+                                    ),
+                                    source=run.out_dir,
+                                    out_dir=out_dir / "transfer" / eval_dataset,
+                                    num_gpus=run.num_gpus,
+                                    min_gpu_mem=run.min_gpu_mem,
+                                    skip_supervised=run.supervised == "none",
+                                )
+                                eval.execute(highlight_color="green")
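The try/except around run.execute() is what lets a single numerically degenerate grid point fail without aborting the rest of the sweep. A minimal sketch of that pattern, assuming a hypothetical fit_one stand-in for run.execute(); torch.linalg.LinAlgError is the public spelling of the torch._C._LinAlgError caught above:

import torch


def fit_one(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for run.execute(): eigendecompose one VINC-like matrix."""
    L, Q = torch.linalg.eigh(A)
    return Q[:, -1]  # eigenvector for the largest eigenvalue


# The second matrix is deliberately pathological; on most backends eigh
# raises LinAlgError for it rather than converging.
matrices = [torch.eye(3), torch.full((3, 3), float("nan")), torch.ones(3, 3)]

for i, A in enumerate(matrices):
    try:
        top = fit_one(A)
    except torch.linalg.LinAlgError as e:
        # Mirror the sweep: log the failure and move on to the next grid point.
        print(f"matrix {i}: LinAlgError: {e}")
        continue
    print(f"matrix {i}: top eigenvector = {top.tolist()}")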
