Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hyperparameter sweep to sweep.py; Fall back to eig when eigh fails #235

Merged
merged 15 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use regular eig as backup for eigh
  • Loading branch information
AlexTMallen committed Apr 29, 2023
commit 4352986be22d998369bf492bd590f5ede1df5ed3
24 changes: 14 additions & 10 deletions elk/training/eigen_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,20 @@ def fit_streaming(self, truncated: bool = False) -> float:
else:
try:
L, Q = torch.linalg.eigh(A)
except torch.linalg.LinAlgError as e:
# Check if the matrix has non-finite values
if not A.isfinite().all():
raise ValueError(
"Fitting the reporter failed because the VINC matrix has "
"non-finite entries. Usually this means the hidden states "
"themselves had non-finite values."
) from e
else:
raise e
except torch.linalg.LinAlgError:
try:
L, Q = torch.linalg.eig(A)
L, Q = L.real, Q.real
except torch.linalg.LinAlgError as e:
# Check if the matrix has non-finite values
if not A.isfinite().all():
raise ValueError(
"Fitting the reporter failed because the VINC matrix has "
"non-finite entries. Usually this means the hidden states "
"themselves had non-finite values."
) from e
else:
raise e

L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :]

Expand Down
52 changes: 30 additions & 22 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import InitVar, dataclass

import numpy as np
import torch

from ..evaluation.evaluate import Eval
from ..extraction import Extract, PromptConfig
Expand All @@ -25,6 +26,8 @@ class Sweep:
"""The step size for hyperparameter sweeps. Performs a 2D
sweep over a and b in (var_weight, inv_weight, neg_cov_weight) = (a, 1 - b, b)
If negative, no hyperparameter sweeps will be performed. Only valid for Eigen."""
skip_transfer_eval: bool = False
"""Whether to skip transfer eval on every pair of datasets."""

name: str | None = None

Expand Down Expand Up @@ -108,25 +111,30 @@ def execute(self):
)

run.out_dir = out_dir
run.execute()

if len(eval_datasets) > 1:
print(colorize("== Transfer eval ==", "green"))

# Now evaluate the reporter on the other datasets
for eval_dataset in eval_datasets:
# We already evaluated on this one during training
if eval_dataset in train_datasets:
continue

data = deepcopy(run.data)
data.model = model_str
data.prompts.datasets = [eval_dataset]

eval = Eval(
data=data,
source=str(run.out_dir),
out_dir=out_dir,
skip_supervised=run.supervised == "none",
)
eval.execute(highlight_color="green")
try:
run.execute()
except torch._C._LinAlgError as e: # type: ignore
print(colorize(f"LinAlgError: {e}", "red"))
continue

if not self.skip_transfer_eval:
if len(eval_datasets) > 1:
print(colorize("== Transfer eval ==", "green"))

# Now evaluate the reporter on the other datasets
for eval_dataset in eval_datasets:
# We already evaluated on this one during training
if eval_dataset in train_datasets:
continue

data = deepcopy(run.data)
data.model = model_str
data.prompts.datasets = [eval_dataset]

eval = Eval(
data=data,
source=str(run.out_dir),
out_dir=out_dir,
skip_supervised=run.supervised == "none",
)
eval.execute(highlight_color="green")