Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

train probe per prompt #271

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove redundant method
  • Loading branch information
derpyplops committed Jul 20, 2023
commit 327d1eb17e112c79004f6041a5829d7203801628
33 changes: 2 additions & 31 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from collections import defaultdict
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Literal

import pandas as pd
import torch
from einops import rearrange, repeat
from simple_parsing import subgroups
from simple_parsing.helpers.serialization import save

from ..evaluation import Eval
from ..metrics import evaluate_preds, to_one_hot
Expand Down Expand Up @@ -124,20 +122,6 @@ class Elicit(Run):
cross-validation. Defaults to "single", which means to train a single classifier
on the training data. "cv" means to use cross-validation."""

def create_models_dir(self, out_dir: Path):
    """Create the model output sub-directories under ``out_dir``.

    Creates ``out_dir / "lr_models"`` and ``out_dir / "reporters"``
    (missing parents are created; already-existing directories are not an
    error) and writes ``self.net`` to ``reporters/cfg.yaml``.

    Args:
        out_dir: Directory under which both sub-directories are created.

    Returns:
        A ``(reporter_dir, lr_dir)`` tuple of the two directory paths.
    """
    reporter_dir = out_dir / "reporters"
    lr_dir = out_dir / "lr_models"

    lr_dir.mkdir(parents=True, exist_ok=True)
    reporter_dir.mkdir(parents=True, exist_ok=True)

    # Save the reporter config separately in the reporter directory
    # for convenient loading of reporters later.
    save(self.net, reporter_dir / "cfg.yaml", save_dc_types=True)

    return reporter_dir, lr_dir

def make_eval(self, model, eval_dataset):
assert self.out_dir is not None
return Eval(
Expand Down Expand Up @@ -254,9 +238,6 @@ def apply_to_layer(
(first_train_h, train_gt, _), *rest = train_dict.values()
(_, v, k, d) = first_train_h.shape

# TODO is this even needed
# reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))

if probe_per_prompt:
prompt_indices = self.prompt_indices if self.prompt_indices else range(v)
prompt_train_dicts = [
Expand Down Expand Up @@ -297,25 +278,15 @@ def apply_to_layer(
# TODO fix lr_models

else:
prompt_train_dict = {
ds_name: (
train_h[:, self.prompt_indices, ...],
train_gt,
lm_preds[:, self.prompt_indices, ...]
if lm_preds is not None
else None,
)
for ds_name, (train_h, _, lm_preds) in train_dict.items()
}
reporter_train_result = self.train_and_save_reporter(
device, layer, self.out_dir / "reporters", prompt_train_dict
device, layer, self.out_dir / "reporters", train_dict
)

maybe_multi_reporter = reporter_train_result.reporter
train_loss = reporter_train_result.train_loss

lr_models = self.train_lr_model(
prompt_train_dict, device, layer, self.out_dir / "lr_models"
train_dict, device, layer, self.out_dir / "lr_models"
)

return evaluate_and_save(
Expand Down