Skip to content

Commit

Permalink
s2s to default run
Browse files (browse the repository at this point in the history)
  • Loading branch information
AdamScherlis committed May 28, 2024
1 parent 6e568ca commit 8181887
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 187 deletions.
48 changes: 47 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_model_and_run_name(model_name, current_name):
weak_train_preds_ds = load_from_disk(str(weak_preds_root / "train"))
weak_val_preds_ds = load_from_disk(str(weak_preds_root / "val"))

# train w2s with logconf
# train w2s, get predictions
print("\n\033[32m===== Training w2s model =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "w2s")
train_args["run_name"] = run_name
Expand All @@ -127,14 +127,60 @@ def get_model_and_run_name(model_name, current_name):
)
# assert (weak_train_preds_ds["id"] == w2s_ds_dict["train"]["id"])
# assert (weak_val_preds_ds["id"] == w2s_ds_dict["val"]["id"])
w2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
train(
w2s_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=True,
predict_dict=w2s_predict_dict,
)

prev = "w2s"

# strong-to-strong iterations
# Strong-to-strong (s2s) iterations: each round trains cfg.strong_model_name
# on the "soft_pred" column saved by the previous round. `prev` names that
# previous round's directory: it is "w2s" before the loop and becomes
# f"s2s-{s2s_iter}" at the end of every round.
# NOTE(review): indentation was lost in this view of the file; presumably every
# statement through `prev = f"s2s-{s2s_iter}"` is the loop body — confirm.
for s2s_iter in range(cfg.s2s_iters):

# load prev predictions
# (read from <root>/<cfg_name>/<prev>/predictions/{train,val})
prev_preds_root = root / cfg_name / prev / "predictions"
prev_train_preds_ds = load_from_disk(str(prev_preds_root / "train"))
prev_val_preds_ds = load_from_disk(str(prev_preds_root / "val"))

# train s2s, get predictions
print(f"\n\033[32m===== Training s2s model iteration {s2s_iter} =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, f"s2s-{s2s_iter}")
# train_args is a shared dict that is updated in place for this round:
# run name, per-round output directory, and the strong-model learning rate.
train_args["run_name"] = run_name
train_args["output_dir"] = str(root / cfg_name / f"s2s-{s2s_iter}")
train_args["learning_rate"] = cfg.strong_lr
# Build this round's dataset: the "labels" column of the train and val splits
# is replaced by the previous round's "soft_pred" values; the test split is
# passed through unchanged.
# NOTE(review): add_column pairs values with rows by position, so this assumes
# the saved predictions are in the same row order as splits[...] — confirm.
s2s_ds_dict = DatasetDict(
{
"train": (
splits["train"]
.remove_columns("labels")
.add_column("labels", prev_train_preds_ds["soft_pred"]) # type: ignore
),
"val": (
splits["val"]
.remove_columns("labels")
.add_column("labels", prev_val_preds_ds["soft_pred"])
), # type: ignore
"test": splits["test"],
}
)
# Disabled checks comparing the "id" columns of the loaded predictions and the
# rebuilt splits (left commented out in the original).
# assert (prev_train_preds_ds["id"] == s2s_ds_dict["train"]["id"])
# assert (prev_val_preds_ds["id"] == s2s_ds_dict["val"]["id"])
# The train/val splits taken directly from `splits` (not the relabeled copies
# above) are handed to train() as predict_dict.
# NOTE(review): presumably train() writes predictions for these splits under
# <output_dir>/predictions, which is what the next round loads — verify in train().
s2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
train(
s2s_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=True,
predict_dict=s2s_predict_dict,
)

# The next round reads its training labels from this round's directory.
prev = f"s2s-{s2s_iter}"

# Script entry point: when the file is executed directly, build the
# configuration with parse(SFTConfig) and pass it to run_train.
if __name__ == "__main__":
run_train(parse(SFTConfig))
186 changes: 0 additions & 186 deletions run_s2s.py

This file was deleted.

0 comments on commit 8181887

Please sign in to comment.