Add hyperparameter sweep to sweep.py; Fall back to eig when eigh fails #235

Merged: 15 commits, May 5, 2023
.pre-commit-config.yaml: file mode changed 100644 → 100755, no content changes.
24 changes: 14 additions & 10 deletions elk/training/eigen_reporter.py
@@ -267,16 +267,20 @@ def fit_streaming(self, truncated: bool = False) -> float:
         else:
             try:
                 L, Q = torch.linalg.eigh(A)
-            except torch.linalg.LinAlgError as e:
-                # Check if the matrix has non-finite values
-                if not A.isfinite().all():
-                    raise ValueError(
-                        "Fitting the reporter failed because the VINC matrix has "
-                        "non-finite entries. Usually this means the hidden states "
-                        "themselves had non-finite values."
-                    ) from e
-                else:
-                    raise e
+            except torch.linalg.LinAlgError:
+                try:
+                    L, Q = torch.linalg.eig(A)
+                    L, Q = L.real, Q.real
+                except torch.linalg.LinAlgError as e:
+                    # Check if the matrix has non-finite values
+                    if not A.isfinite().all():
+                        raise ValueError(
+                            "Fitting the reporter failed because the VINC matrix has "
+                            "non-finite entries. Usually this means the hidden states "
+                            "themselves had non-finite values."
+                        ) from e
+                    else:
+                        raise e
 
         L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :]

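For readers unfamiliar with the fallback, the pattern added above can be exercised in isolation. Below is a minimal, standalone sketch (not part of the PR; the helper name and the random test matrix are illustrative only) showing why the extra .real step is needed: torch.linalg.eig returns complex eigenvalues and eigenvectors even for real input, whereas torch.linalg.eigh returns real ones for symmetric matrices.

    import torch


    def eigh_with_fallback(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Eigendecompose a (nominally) symmetric matrix, falling back to the
        general solver when the symmetric one fails to converge.

        Illustrative helper only, not part of the PR.
        """
        try:
            # Fast path: assumes A is symmetric; returns real eigenvalues
            # in ascending order and real eigenvectors.
            L, Q = torch.linalg.eigh(A)
        except torch.linalg.LinAlgError:
            # Fallback: the general solver returns complex tensors even for
            # real input, so keep only the real parts.
            L, Q = torch.linalg.eig(A)
            L, Q = L.real, Q.real
        return L, Q


    # Example: a random symmetric matrix
    A = torch.randn(8, 8)
    A = A + A.T
    L, Q = eigh_with_fallback(A)
    print(L.shape, Q.shape)  # torch.Size([8]) torch.Size([8, 8])

One caveat: eigh returns eigenvalues in ascending order, while eig makes no ordering guarantee, so callers that slice the top eigenpairs afterwards (as fit_streaming does with L[-self.config.num_heads :]) may want to sort the fallback results first.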
91 changes: 65 additions & 26 deletions elk/training/sweep.py
File mode changed 100644 → 100755
@@ -1,8 +1,12 @@
 from dataclasses import InitVar, dataclass, replace
 
+import numpy as np
+import torch
+
 from ..evaluation import Eval
 from ..extraction import Extract
 from ..files import elk_reporter_dir, memorably_named_dir
+from ..training.eigen_reporter import EigenReporterConfig
 from ..utils import colorize
 from .train import Elicit
 
@@ -19,6 +23,12 @@ class Sweep:
 
     add_pooled: InitVar[bool] = False
     """Whether to add a dataset that pools all of the other datasets together."""
+    hparam_step: float = -1.0
+    """The step size for hyperparameter sweeps. Performs a 2D sweep over a and b
+    in (var_weight, inv_weight, neg_cov_weight) = (a, 1 - b, b). If negative, no
+    hyperparameter sweep will be performed. Only valid for Eigen reporters."""
+    skip_transfer_eval: bool = False
+    """Whether to skip transfer eval on every pair of datasets."""
 
     name: str | None = None
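To make the hparam_step docstring concrete, here is a small standalone sketch (not part of the PR; the step value 0.25 is just an example) of the weight grid a given step induces, using the same np.arange construction that execute() uses further down:

    import numpy as np

    step = 0.25  # example value for hparam_step
    weights = np.arange(0.0, 1.0 + step, step)  # array([0.  , 0.25, 0.5 , 0.75, 1.  ])

    # Every (a, b) pair is tried, which per the docstring corresponds to
    # (var_weight, inv_weight, neg_cov_weight) = (a, 1 - b, b). Only var_weight
    # and neg_cov_weight are actually set on the EigenReporterConfig in execute().
    grid = [(a, 1.0 - b, b) for a in weights for b in weights]
    print(len(grid))  # 25 weight combinations for step = 0.25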

Expand All @@ -35,6 +45,13 @@ def __post_init__(self, add_pooled: bool):
raise ValueError("No datasets specified")
if not self.models:
raise ValueError("No models specified")
# can only use hparam_step if we're using an eigen net
if self.hparam_step > 0 and not isinstance(
self.run_template.net, EigenReporterConfig
):
raise ValueError("Can only use hparam_step with EigenReporterConfig")
elif self.hparam_step > 1:
raise ValueError("hparam_step must be in [0, 1]")

# Check for the magic dataset "burns" which is a shortcut for all of the
# datasets used in Burns et al., except Story Cloze, which is not available
@@ -89,38 +106,60 @@ def execute(self):
             }
         )
 
+        step = self.hparam_step
+        weights = np.arange(0.0, 1.0 + step, step) if step > 0 else [None]
+
         for i, model in enumerate(self.models):
             print(colorize(f"===== {model} ({i + 1} of {M}) =====", "magenta"))
 
             for dataset_str in self.datasets:
-                out_dir = sweep_dir / model / dataset_str
-
                 # Allow for multiple datasets to be specified in a single string with
                 # plus signs. This means we can pool datasets together inside of a
                 # single sweep.
                 train_datasets = tuple(ds.strip() for ds in dataset_str.split("+"))
 
-                data = replace(
-                    self.run_template.data, model=model, datasets=train_datasets
-                )
-                run = replace(self.run_template, data=data, out_dir=out_dir)
-                run.execute()
-
-                if len(eval_datasets) > 1:
-                    print(colorize("== Transfer eval ==", "green"))
-
-                # Now evaluate the reporter on the other datasets
-                for eval_dataset in eval_datasets:
-                    # We already evaluated on this one during training
-                    if eval_dataset in train_datasets:
-                        continue
-
-                    assert run.out_dir is not None
-                    eval = Eval(
-                        data=replace(run.data, model=model, datasets=(eval_dataset,)),
-                        source=run.out_dir,
-                        out_dir=out_dir / "transfer" / eval_dataset,
-                        num_gpus=run.num_gpus,
-                        min_gpu_mem=run.min_gpu_mem,
-                    )
-                    eval.execute(highlight_color="green")
+                for var_weight in weights:
+                    for neg_cov_weight in weights:
+                        out_dir = sweep_dir / model / dataset_str
+
+                        data = replace(
+                            self.run_template.data, model=model, datasets=train_datasets
+                        )
+                        run = replace(self.run_template, data=data, out_dir=out_dir)
+                        if var_weight is not None and neg_cov_weight is not None:
+                            assert isinstance(run.net, EigenReporterConfig)
+                            run.net.var_weight = var_weight
+                            run.net.neg_cov_weight = neg_cov_weight
+
+                            # Add hyperparameter values to output directory if needed
+                            out_dir /= f"var_weight={var_weight}"
+                            out_dir /= f"neg_cov_weight={neg_cov_weight}"
+
+                        try:
+                            run.execute()
+                        except torch._C._LinAlgError as e:  # type: ignore
+                            print(colorize(f"LinAlgError: {e}", "red"))
+                            continue
+
+                        if not self.skip_transfer_eval:
+                            if len(eval_datasets) > 1:
+                                print(colorize("== Transfer eval ==", "green"))
+
+                            # Now evaluate the reporter on the other datasets
+                            for eval_dataset in eval_datasets:
+                                # We already evaluated on this one during training
+                                if eval_dataset in train_datasets:
+                                    continue
+
+                                assert run.out_dir is not None
+                                eval = Eval(
+                                    data=replace(
+                                        run.data, model=model, datasets=(eval_dataset,)
+                                    ),
+                                    source=run.out_dir,
+                                    out_dir=out_dir / "transfer" / eval_dataset,
+                                    num_gpus=run.num_gpus,
+                                    min_gpu_mem=run.min_gpu_mem,
+                                    skip_supervised=run.supervised == "none",
+                                )
+                                eval.execute(highlight_color="green")
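Because the hyperparameter sweep nests two loops over the same weight grid inside the model and dataset loops, the number of training runs grows quadratically with the grid size. A standalone sketch (not part of the PR; the function and argument names are illustrative) for reasoning about the cost:

    import numpy as np


    def num_training_runs(n_models: int, n_dataset_strs: int, hparam_step: float) -> int:
        """Count the training runs Sweep.execute() will launch, mirroring its loop
        structure. Illustrative helper only, not part of the PR."""
        if hparam_step > 0:
            n_weights = len(np.arange(0.0, 1.0 + hparam_step, hparam_step))
        else:
            n_weights = 1  # weights == [None]: one run with the default loss weights
        return n_models * n_dataset_strs * n_weights * n_weights


    print(num_training_runs(2, 3, 0.25))  # 2 models x 3 dataset strings x 5 x 5 = 150

Note that a weight combination whose training run raises a LinAlgError is skipped rather than aborting the sweep, so the number of completed runs can be lower than this upper bound.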