Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix argument passthrough for sweep #266

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
29 changes: 9 additions & 20 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from datasets import get_dataset_config_info
from transformers import AutoConfig

from ..evaluation import Eval
from ..extraction import Extract
from ..files import memorably_named_dir, sweeps_dir
from ..plotting.visualize import visualize_sweep
Expand Down Expand Up @@ -134,19 +133,19 @@ def execute(self):
data = replace(
self.run_template.data, model=model, datasets=train_datasets
)
run = replace(self.run_template, data=data, out_dir=out_dir)
elicit = replace(self.run_template, data=data, out_dir=out_dir)
if var_weight is not None and neg_cov_weight is not None:
assert isinstance(run.net, EigenFitterConfig)
run.net.var_weight = var_weight
run.net.neg_cov_weight = neg_cov_weight
assert isinstance(elicit.net, EigenFitterConfig)
elicit.net.var_weight = var_weight
elicit.net.neg_cov_weight = neg_cov_weight

# Add hyperparameter values to output directory if needed
assert run.out_dir is not None
run.out_dir /= f"var_weight={var_weight:.2f}"
run.out_dir /= f"neg_cov_weight={neg_cov_weight:.2f}"
assert elicit.out_dir is not None
elicit.out_dir /= f"var_weight={var_weight:.2f}"
elicit.out_dir /= f"neg_cov_weight={neg_cov_weight:.2f}"

try:
run.execute()
elicit.execute()
except torch.linalg.LinAlgError as e:
print(colorize(f"LinAlgError: {e}", "red"))
continue
Expand All @@ -161,17 +160,7 @@ def execute(self):
if eval_dataset in train_datasets:
continue

assert run.out_dir is not None
eval = Eval(
data=replace(
run.data, model=model, datasets=(eval_dataset,)
),
source=run.out_dir,
out_dir=run.out_dir / "transfer" / eval_dataset,
num_gpus=run.num_gpus,
min_gpu_mem=run.min_gpu_mem,
skip_supervised=run.supervised == "none",
)
eval = elicit.make_eval(model, eval_dataset)
eval.execute(highlight_color="green")

if self.visualize:
Expand Down
23 changes: 22 additions & 1 deletion elk/training/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Main training loop."""

from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Literal

Expand All @@ -11,6 +11,7 @@
from simple_parsing import subgroups
from simple_parsing.helpers.serialization import save

from ..evaluation import Eval
from ..metrics import evaluate_preds, to_one_hot
from ..run import Run
from ..training.supervised import train_supervised
Expand Down Expand Up @@ -48,6 +49,26 @@ def create_models_dir(self, out_dir: Path):

return reporter_dir, lr_dir

def make_eval(self, model: str, eval_dataset: str) -> Eval:
    """Build a transfer-evaluation run from this elicit run's config.

    Creates an ``Eval`` that loads the reporters trained by this run (from
    ``self.out_dir``) and evaluates them on ``eval_dataset``, writing results
    under ``<out_dir>/transfer/<eval_dataset>``. All run settings (GPUs,
    prompt indices, debug flags, …) are carried over from ``self`` so the
    evaluation matches the training configuration.

    Args:
        model: Model name/path to evaluate with.
        eval_dataset: Name of the dataset to transfer-evaluate on.

    Returns:
        A configured ``Eval``, ready for ``execute()``.
    """
    # Training must have produced an output directory before we can
    # point an evaluation at it.
    assert self.out_dir is not None
    return Eval(
        data=replace(
            self.data,
            model=model,
            # The inherited `datasets` value is irrelevant: it is
            # overwritten here with the single transfer target.
            datasets=(eval_dataset,),
        ),
        source=self.out_dir,
        out_dir=self.out_dir / "transfer" / eval_dataset,
        num_gpus=self.num_gpus,
        min_gpu_mem=self.min_gpu_mem,
        skip_supervised=self.supervised == "none",
        prompt_indices=self.prompt_indices,
        concatenated_layer_offset=self.concatenated_layer_offset,
        debug=self.debug,
        disable_cache=self.disable_cache,
    )

def apply_to_layer(
self,
layer: int,
Expand Down