Skip to content

Commit

Permalink
WIP add multiprobe training
Browse files (browse the repository at this point in the history)
  • Loading branch information
derpyplops committed Jul 14, 2023
1 parent a50fe57 commit 52b1394
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions elk/training/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -76,9 +76,9 @@ def make_eval(self, model, eval_dataset):
)

# Create a separate function to handle the reporter training.
def train_reporter(self, device, layer, out_dir) -> ReporterTrainResult:
train_dict = self.prepare_data(device, layer, "train")

def train_and_save_reporter(
self, device, layer, out_dir, train_dict
) -> ReporterTrainResult:
(first_train_h, train_gt, _), *rest = train_dict.values() # TODO can remove?
(_, v, k, d) = first_train_h.shape
if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
Expand Down Expand Up @@ -128,6 +128,7 @@ def train_reporter(self, device, layer, out_dir) -> ReporterTrainResult:
raise ValueError(f"Unknown reporter config type: {type(self.net)}")

# Save reporter checkpoint to disk
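# TODO: change the checkpoint filename. It depends only on the layer, so when
# several reporters are trained for the same layer (one per prompt variant),
# each save overwrites the previous checkpoint.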
torch.save(reporter, out_dir / f"layer_{layer}.pt")

return ReporterTrainResult(reporter, train_loss)
Expand Down Expand Up @@ -166,13 +167,36 @@ def apply_to_layer(

reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))

reporter_train_result = self.train_reporter(device, layer, reporter_dir)
probe_per_prompt = True
if probe_per_prompt:
train_dicts = [
{
ds_name: (
train_h[:, i : i + 1, ...],
train_gt,
lm_preds[:, i : i + 1, ...],
)
}
for ds_name, (train_h, _, lm_preds) in train_dict.items()
for i in range(v) # v is number of variants
]

[
self.train_and_save_reporter(device, layer, reporter_dir, train_dict)
for train_dict in train_dicts
]
else:
reporter_train_result = self.train_and_save_reporter(
device, layer, reporter_dir, train_dict
)

reporter = reporter_train_result.reporter
train_loss = reporter_train_result.train_loss

lr_models = self.train_lr_model(train_dict, device, layer, lr_dir)

row_bufs = defaultdict(list)

for ds_name in val_dict:
val_h, val_gt, val_lm_preds = val_dict[ds_name]
train_h, train_gt, train_lm_preds = train_dict[ds_name]
Expand Down

0 comments on commit 52b1394

Please sign in to comment.