s2s
AdamScherlis committed May 28, 2024
1 parent f6d65d0 commit 6e568ca
Showing 3 changed files with 189 additions and 2 deletions.
3 changes: 1 addition & 2 deletions run.py
@@ -104,7 +104,7 @@ def get_model_and_run_name(model_name, current_name):
     weak_train_preds_ds = load_from_disk(str(weak_preds_root / "train"))
     weak_val_preds_ds = load_from_disk(str(weak_preds_root / "val"))
 
-    # train w2s with logconf, get predictions
+    # train w2s with logconf
     print("\n\033[32m===== Training w2s model =====\033[0m")
     model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "w2s")
     train_args["run_name"] = run_name
@@ -127,7 +127,6 @@ def get_model_and_run_name(model_name, current_name):
     )
     # assert (weak_train_preds_ds["id"] == w2s_ds_dict["train"]["id"])
     # assert (weak_val_preds_ds["id"] == w2s_ds_dict["val"]["id"])
-    w2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
     train(
         w2s_ds_dict,
         model_cfg,
186 changes: 186 additions & 0 deletions run_s2s.py
@@ -0,0 +1,186 @@
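"""Weak-to-strong pipeline with iterated strong-to-strong (s2s) refinement.

Trains a weak floor and a strong ceiling on gold labels, trains a w2s model on
the weak model's soft predictions, then runs cfg.s2s_iters rounds in which a
fresh strong model is trained on the previous round's predictions.
"""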
from pathlib import Path

import torch
from datasets import DatasetDict, load_from_disk
from simple_parsing import parse
from transformers import (
TrainingArguments,
)

from w2s.ds_registry import load_and_process_dataset
from w2s.model import ModelConfig
from w2s.sft import train
from w2s.sft_config import SFTConfig
from w2s.utils import get_config_foldername


def run_train(cfg: SFTConfig):
splits = load_and_process_dataset(
cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict
)

cols = ["hard_label", "txt"]
splits = splits.select_columns(cols).rename_column("hard_label", "labels")
print(
f"Example:\n\n{splits['train'][0]['txt']}\n\nLabel: {splits['train'][0]['labels']}"
)

root = Path(cfg.results_folder) / cfg.run_name
cfg_name = get_config_foldername(vars(cfg))
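    # TrainingArguments shared by every stage; run_name, output_dir, and
    # learning_rate are filled in per stage below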
train_args: dict = dict(
num_train_epochs=cfg.n_epochs,
adam_beta2=0.95,
gradient_accumulation_steps=cfg.batch_size // cfg.minibatch_size,
eval_strategy="steps",
label_names=["labels"],
load_best_model_at_end=cfg.load_best_model_at_end,
logging_steps=25,
metric_for_best_model=cfg.metric_for_best_model,
greater_is_better=cfg.greater_is_better,
per_device_train_batch_size=cfg.minibatch_size,
per_device_eval_batch_size=cfg.minibatch_size,
save_strategy="steps",
save_total_limit=cfg.save_total_limit,
tf32=True, # Use Tensor Cores even for fp32 matmuls
warmup_steps=cfg.n_warmup_steps,
weight_decay=cfg.weight_decay,
lr_scheduler_type=cfg.lr_schedule,
eval_steps=cfg.eval_every,
save_steps=cfg.save_every,
)

def get_model_and_run_name(model_name, current_name):
model_last = model_name.split("/")[-1]
model_cfg = ModelConfig(name=model_name, enable_lora=not cfg.disable_lora)
run_name = f"{current_name}-{cfg.run_name}-{cfg.dataset}-{model_last}"
return model_cfg, run_name

# train weak floor, get predictions
print("\n\033[32m===== Training weak model =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.weak_model_name, "weak")
train_args["run_name"] = run_name
train_args["output_dir"] = str(root / cfg_name / "weak")
train_args["learning_rate"] = cfg.weak_lr
weak_ds_dict = DatasetDict(
{
"train": splits["train"],
"val": splits["val"],
"test": splits["test"],
}
)
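    # predict on train/val so this run's soft predictions are saved to disk
    # (under weak/predictions) and can be reloaded as w2s training labels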
weak_predict_dict = {"train": splits["train"], "val": splits["val"]}
train(
weak_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=False,
predict_dict=weak_predict_dict,
)

# train strong ceil
print("\n\033[32m===== Training strong model =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "strong")
train_args["run_name"] = run_name
train_args["output_dir"] = str(root / cfg_name / "strong")
train_args["learning_rate"] = cfg.strong_lr
strong_ds_dict = DatasetDict(
{
"train": splits["train"],
"val": splits["val"],
"test": splits["test"],
}
)
train(
strong_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=False,
)

# load weak predictions
weak_preds_root = root / cfg_name / "weak" / "predictions"
weak_train_preds_ds = load_from_disk(str(weak_preds_root / "train"))
weak_val_preds_ds = load_from_disk(str(weak_preds_root / "val"))

# train w2s, get predictions
print("\n\033[32m===== Training w2s model =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "w2s")
train_args["run_name"] = run_name
train_args["output_dir"] = str(root / cfg_name / "w2s")
train_args["learning_rate"] = cfg.strong_lr
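    # w2s: replace the gold train/val labels with the weak model's soft
    # predictions; the test split keeps gold labels for evaluation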
w2s_ds_dict = DatasetDict(
{
"train": (
splits["train"]
.remove_columns("labels")
                .add_column("labels", weak_train_preds_ds["soft_pred"])  # type: ignore
            ),
            "val": (
                splits["val"]
                .remove_columns("labels")
                .add_column("labels", weak_val_preds_ds["soft_pred"])  # type: ignore
            ),
"test": splits["test"],
}
)
# assert (weak_train_preds_ds["id"] == w2s_ds_dict["train"]["id"])
# assert (weak_val_preds_ds["id"] == w2s_ds_dict["val"]["id"])
w2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
train(
w2s_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=True,
predict_dict=w2s_predict_dict,
)

    prev = "w2s"  # run whose saved predictions seed the next round

# strong-to-strong iterations
for s2s_iter in range(cfg.s2s_iters):

# load prev predictions
prev_preds_root = root / cfg_name / prev / "predictions"
prev_train_preds_ds = load_from_disk(str(prev_preds_root / "train"))
prev_val_preds_ds = load_from_disk(str(prev_preds_root / "val"))

# train s2s, get predictions
print(f"\n\033[32m===== Training s2s model iteration {s2s_iter} =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, f"s2s-{s2s_iter}")
train_args["run_name"] = run_name
train_args["output_dir"] = str(root / cfg_name / f"s2s-{s2s_iter}")
train_args["learning_rate"] = cfg.strong_lr
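        # relabel train/val with the previous round's soft predictions
        # (the w2s model on the first pass, then s2s-0, s2s-1, ...)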
s2s_ds_dict = DatasetDict(
{
"train": (
splits["train"]
.remove_columns("labels")
                    .add_column("labels", prev_train_preds_ds["soft_pred"])  # type: ignore
                ),
                "val": (
                    splits["val"]
                    .remove_columns("labels")
                    .add_column("labels", prev_val_preds_ds["soft_pred"])  # type: ignore
                ),
"test": splits["test"],
}
)
# assert (prev_train_preds_ds["id"] == s2s_ds_dict["train"]["id"])
# assert (prev_val_preds_ds["id"] == s2s_ds_dict["val"]["id"])
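        # predict on train/val again so the next iteration can load this
        # round's predictions from disk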
s2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
train(
s2s_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=True,
predict_dict=s2s_predict_dict,
)

prev = f"s2s-{s2s_iter}"

if __name__ == "__main__":
run_train(parse(SFTConfig))
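    # Example invocation (a sketch; simple_parsing derives the flags from
    # SFTConfig, so exact names depend on that config):
    #   python run_s2s.py --dataset <name> --s2s_iters 3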
2 changes: 2 additions & 0 deletions w2s/sft_config.py
@@ -75,6 +75,8 @@ class SFTConfig(Serializable):  # TODO: what is this for??
     greater_is_better: bool = field(init=False)
     loss_name: str = field(init=False)
 
+    s2s_iters: int = 0  # number of strong-to-strong rounds after w2s; 0 disables them
+
     def __post_init__(self):
         if "loss" in self.metric_for_best_model:
             self.greater_is_better = False
