-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f6d65d0
commit 6e568ca
Showing
3 changed files
with
189 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
from pathlib import Path | ||
|
||
import torch | ||
from datasets import DatasetDict, load_from_disk | ||
from simple_parsing import parse | ||
from transformers import ( | ||
TrainingArguments, | ||
) | ||
|
||
from w2s.ds_registry import load_and_process_dataset | ||
from w2s.model import ModelConfig | ||
from w2s.sft import train | ||
from w2s.sft_config import SFTConfig | ||
from w2s.utils import get_config_foldername | ||
|
||
|
||
def run_train(cfg: SFTConfig): | ||
splits = load_and_process_dataset( | ||
cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict | ||
) | ||
|
||
cols = ["hard_label", "txt"] | ||
splits = splits.select_columns(cols).rename_column("hard_label", "labels") | ||
print( | ||
f"Example:\n\n{splits['train'][0]['txt']}\n\nLabel: {splits['train'][0]['labels']}" | ||
) | ||
|
||
root = Path(cfg.results_folder) / cfg.run_name | ||
cfg_name = get_config_foldername(vars(cfg)) | ||
train_args: dict = dict( | ||
num_train_epochs=cfg.n_epochs, | ||
adam_beta2=0.95, | ||
gradient_accumulation_steps=cfg.batch_size // cfg.minibatch_size, | ||
eval_strategy="steps", | ||
label_names=["labels"], | ||
load_best_model_at_end=cfg.load_best_model_at_end, | ||
logging_steps=25, | ||
metric_for_best_model=cfg.metric_for_best_model, | ||
greater_is_better=cfg.greater_is_better, | ||
per_device_train_batch_size=cfg.minibatch_size, | ||
per_device_eval_batch_size=cfg.minibatch_size, | ||
save_strategy="steps", | ||
save_total_limit=cfg.save_total_limit, | ||
tf32=True, # Use Tensor Cores even for fp32 matmuls | ||
warmup_steps=cfg.n_warmup_steps, | ||
weight_decay=cfg.weight_decay, | ||
lr_scheduler_type=cfg.lr_schedule, | ||
eval_steps=cfg.eval_every, | ||
save_steps=cfg.save_every, | ||
) | ||
|
||
def get_model_and_run_name(model_name, current_name): | ||
model_last = model_name.split("/")[-1] | ||
model_cfg = ModelConfig(name=model_name, enable_lora=not cfg.disable_lora) | ||
run_name = f"{current_name}-{cfg.run_name}-{cfg.dataset}-{model_last}" | ||
return model_cfg, run_name | ||
|
||
# train weak floor, get predictions | ||
print("\n\033[32m===== Training weak model =====\033[0m") | ||
model_cfg, run_name = get_model_and_run_name(cfg.weak_model_name, "weak") | ||
train_args["run_name"] = run_name | ||
train_args["output_dir"] = str(root / cfg_name / "weak") | ||
train_args["learning_rate"] = cfg.weak_lr | ||
weak_ds_dict = DatasetDict( | ||
{ | ||
"train": splits["train"], | ||
"val": splits["val"], | ||
"test": splits["test"], | ||
} | ||
) | ||
weak_predict_dict = {"train": splits["train"], "val": splits["val"]} | ||
train( | ||
weak_ds_dict, | ||
model_cfg, | ||
TrainingArguments(**train_args), | ||
cfg.to_dict(), | ||
transfer=False, | ||
predict_dict=weak_predict_dict, | ||
) | ||
|
||
# train strong ceil | ||
print("\n\033[32m===== Training strong model =====\033[0m") | ||
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "strong") | ||
train_args["run_name"] = run_name | ||
train_args["output_dir"] = str(root / cfg_name / "strong") | ||
train_args["learning_rate"] = cfg.strong_lr | ||
strong_ds_dict = DatasetDict( | ||
{ | ||
"train": splits["train"], | ||
"val": splits["val"], | ||
"test": splits["test"], | ||
} | ||
) | ||
train( | ||
strong_ds_dict, | ||
model_cfg, | ||
TrainingArguments(**train_args), | ||
cfg.to_dict(), | ||
transfer=False, | ||
) | ||
|
||
# load weak predictions | ||
weak_preds_root = root / cfg_name / "weak" / "predictions" | ||
weak_train_preds_ds = load_from_disk(str(weak_preds_root / "train")) | ||
weak_val_preds_ds = load_from_disk(str(weak_preds_root / "val")) | ||
|
||
# train w2s, get predictions | ||
print("\n\033[32m===== Training w2s model =====\033[0m") | ||
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "w2s") | ||
train_args["run_name"] = run_name | ||
train_args["output_dir"] = str(root / cfg_name / "w2s") | ||
train_args["learning_rate"] = cfg.strong_lr | ||
w2s_ds_dict = DatasetDict( | ||
{ | ||
"train": ( | ||
splits["train"] | ||
.remove_columns("labels") | ||
.add_column("labels", weak_train_preds_ds["soft_pred"]) # type: ignore | ||
), | ||
"val": ( | ||
splits["val"] | ||
.remove_columns("labels") | ||
.add_column("labels", weak_val_preds_ds["soft_pred"]) | ||
), # type: ignore | ||
"test": splits["test"], | ||
} | ||
) | ||
# assert (weak_train_preds_ds["id"] == w2s_ds_dict["train"]["id"]) | ||
# assert (weak_val_preds_ds["id"] == w2s_ds_dict["val"]["id"]) | ||
w2s_predict_dict = {"train": splits["train"], "val": splits["val"]} | ||
train( | ||
w2s_ds_dict, | ||
model_cfg, | ||
TrainingArguments(**train_args), | ||
cfg.to_dict(), | ||
transfer=True, | ||
predict_dict=w2s_predict_dict, | ||
) | ||
|
||
prev = "w2s" | ||
|
||
# strong-to-strong iterations | ||
for s2s_iter in range(cfg.s2s_iters): | ||
|
||
# load prev predictions | ||
prev_preds_root = root / cfg_name / prev / "predictions" | ||
prev_train_preds_ds = load_from_disk(str(prev_preds_root / "train")) | ||
prev_val_preds_ds = load_from_disk(str(prev_preds_root / "val")) | ||
|
||
# train s2s, get predictions | ||
print(f"\n\033[32m===== Training s2s model iteration {s2s_iter} =====\033[0m") | ||
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, f"s2s-{s2s_iter}") | ||
train_args["run_name"] = run_name | ||
train_args["output_dir"] = str(root / cfg_name / f"s2s-{s2s_iter}") | ||
train_args["learning_rate"] = cfg.strong_lr | ||
s2s_ds_dict = DatasetDict( | ||
{ | ||
"train": ( | ||
splits["train"] | ||
.remove_columns("labels") | ||
.add_column("labels", prev_train_preds_ds["soft_pred"]) # type: ignore | ||
), | ||
"val": ( | ||
splits["val"] | ||
.remove_columns("labels") | ||
.add_column("labels", prev_val_preds_ds["soft_pred"]) | ||
), # type: ignore | ||
"test": splits["test"], | ||
} | ||
) | ||
# assert (prev_train_preds_ds["id"] == s2s_ds_dict["train"]["id"]) | ||
# assert (prev_val_preds_ds["id"] == s2s_ds_dict["val"]["id"]) | ||
s2s_predict_dict = {"train": splits["train"], "val": splits["val"]} | ||
train( | ||
s2s_ds_dict, | ||
model_cfg, | ||
TrainingArguments(**train_args), | ||
cfg.to_dict(), | ||
transfer=True, | ||
predict_dict=s2s_predict_dict, | ||
) | ||
|
||
prev = f"s2s-{s2s_iter}" | ||
|
||
if __name__ == "__main__": | ||
run_train(parse(SFTConfig)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters