Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix argument passthrough for sweep #266

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
move to make_eval
  • Loading branch information
derpyplops committed Jul 12, 2023
commit 0acd41c457c3d5fa8e31b93949ef645a55a0d455
20 changes: 1 addition & 19 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,26 +165,8 @@ def execute(self):

assert run.out_dir is not None
# TODO we should fix this so that this isn't needed
eval = Eval(
extract=replace(
run.extract,
model=model,
datasets=(eval_dataset,),
),
source=run.out_dir,
out_dir=run.out_dir / "transfer" / eval_dataset,
num_gpus=run.num_gpus,
min_gpu_mem=run.min_gpu_mem,
skip_supervised=run.supervised == "none",
prompt_indices=run.prompt_indices,
concatenated_layer_offset=run.concatenated_layer_offset,
# datasets=run.datasets,
# this isn't needed because it's
# immediately overwritten
debug=run.debug,
disable_cache=run.disable_cache,
)

eval = run.make_eval(model, eval_dataset)
eval.execute(highlight_color="green")

if self.visualize:
Expand Down
24 changes: 23 additions & 1 deletion elk/training/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Main training loop."""

from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Literal

Expand All @@ -11,6 +11,7 @@
from simple_parsing import subgroups
from simple_parsing.helpers.serialization import save

from ..evaluation import Eval
from ..metrics import evaluate_preds, to_one_hot
from ..run import Run
from ..training.supervised import train_supervised
Expand Down Expand Up @@ -48,6 +49,27 @@ def create_models_dir(self, out_dir: Path):

return reporter_dir, lr_dir

def make_eval(self, model, eval_dataset):
Copy link
Collaborator

@lauritowal lauritowal Jul 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if that's the best place to move the code to, everything else looks good to me. 🟢

return Eval(
extract=replace(
self.extract,
model=model,
datasets=(eval_dataset,),
),
source=self.out_dir,
out_dir=self.out_dir / "transfer" / eval_dataset,
num_gpus=self.num_gpus,
min_gpu_mem=self.min_gpu_mem,
skip_supervised=self.supervised == "none",
prompt_indices=self.prompt_indices,
concatenated_layer_offset=self.concatenated_layer_offset,
# datasets=run.datasets isn't needed because it's
# immediately overwritten
debug=self.debug,
disable_cache=self.disable_cache,
)


def apply_to_layer(
self,
layer: int,
Expand Down