Skip to content

Commit

Permalink
Merge pull request #2 from EleutherAI/repro
Browse files Browse the repository at this point in the history
Blog post version
  • Loading branch information
AdamScherlis committed Jun 15, 2024
2 parents 1921d47 + 209d5d0 commit cf99b7d
Show file tree
Hide file tree
Showing 19 changed files with 1,583 additions and 500 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
results/
results*/
wandb/
plots/
scratch/
venv*/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ repos:
rev: 'v0.0.262'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
args: [--fix, --exit-non-zero-on-fix, --line-length=100]
8 changes: 0 additions & 8 deletions .vscode/settings.json

This file was deleted.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ dependencies = [
"peft",
"scipy",
"simple-parsing",
"fire ~= 0.4",
"pynvml ~= 11.5",
"scikit-learn ~= 1.3.2",
# 4.0 introduced the breaking change of using return_dict=True by default
"transformers>=4.0.0",
"wandb",
Expand Down
202 changes: 202 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from pathlib import Path
import os

import torch
from datasets import DatasetDict, load_from_disk
from simple_parsing import parse
from transformers import (
TrainingArguments,
)

from w2s.ds_registry import load_and_process_dataset
from w2s.model import ModelConfig
from w2s.sft import train
from w2s.sft_config import SFTConfig
from w2s.utils import get_config_foldername


def _load_stage_predictions(preds_root: Path):
    """Load the saved ``train`` and ``val`` prediction datasets found under ``preds_root``."""
    train_preds = load_from_disk(str(preds_root / "train"))
    val_preds = load_from_disk(str(preds_root / "val"))
    return train_preds, val_preds


def _relabel_with_soft_preds(ds, preds_ds):
    """Return ``ds`` with its ``labels`` column replaced by ``preds_ds["soft_pred"]``.

    Rows are paired purely by position: nothing here verifies that ``ds`` and
    ``preds_ds`` are in the same order, so the caller must make sure the
    predictions were produced from the very same split.
    """
    return ds.remove_columns("labels").add_column("labels", preds_ds["soft_pred"])  # type: ignore


def _transfer_dataset(splits: DatasetDict, train_preds, val_preds) -> DatasetDict:
    """Build the train/val/test splits for a stage supervised by another model's predictions.

    ``train`` and ``val`` get soft labels from the given predictions; ``test``
    keeps the labels it already has.
    """
    return DatasetDict(
        {
            "train": _relabel_with_soft_preds(splits["strong_train"], train_preds),
            "val": _relabel_with_soft_preds(splits["val"], val_preds),
            "test": splits["test"],
        }
    )


def run_train(cfg: SFTConfig):
    """Run the weak -> strong -> w2s -> s2s training pipeline described by ``cfg``.

    Stages, in order:

    1. "weak": ``cfg.weak_model_name`` is trained on one half of the train split
       and asked to predict on the other half and on ``val``.
    2. "strong": ``cfg.strong_model_name`` is trained on the other half with the
       original labels.
    3. "w2s": ``cfg.strong_model_name`` is trained on that same half, with
       ``train``/``val`` labels replaced by the weak stage's ``soft_pred`` column.
    4. "s2s-<i>" for ``i`` in ``range(cfg.s2s_iters)``: as w2s, but labelled with
       the previous stage's predictions (w2s first, then the preceding s2s round).

    The weak and strong stages write under ``results_folder/shared_folder``; the
    w2s and s2s stages write under ``results_folder/run_name``. Each later stage
    reads ``<output_dir>/predictions/{train,val}`` of the stage before it, so
    ``train`` is presumably what writes them there (confirm in ``w2s.sft``).

    Args:
        cfg: Experiment configuration (dataset, model names, learning rates,
            batch sizes, output folders, number of s2s iterations, ...).
    """
    print(f"Loading and processing dataset {cfg.dataset}")
    splits = load_and_process_dataset(
        cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict
    )

    # Fixed-seed 50/50 split: one half trains the weak model, the other half is
    # what the weak model labels and the strong models train on.
    train_halves = splits["train"].train_test_split(test_size=0.5, seed=42)
    splits["weak_train"] = train_halves["train"]
    splits["strong_train"] = train_halves["test"]

    cols = ["hard_label", "txt"]
    splits = splits.select_columns(cols).rename_column("hard_label", "labels")
    for split in splits:
        # Keep a copy of the original labels; "labels" is overwritten with soft
        # predictions in the transfer stages.
        splits[split] = splits[split].add_column("gt_labels", splits[split]["labels"])

    example = splits["strong_train"][0]
    print(f"Example:\n\n{example['txt']}\n\nLabel: {example['labels']}")

    root = Path(cfg.results_folder) / cfg.run_name
    shared_root = Path(cfg.results_folder) / cfg.shared_folder
    cfg_name = cfg.dataset

    # Integer division yields 0 when batch_size < minibatch_size, which is not a
    # usable accumulation count; fall back to a single step in that case.
    grad_accum_steps = max(1, cfg.batch_size // cfg.minibatch_size)

    # Arguments shared by every stage; run_name / output_dir / learning_rate are
    # filled in per stage by `begin_stage` below.
    train_args: dict = dict(
        num_train_epochs=cfg.n_epochs,
        adam_beta2=0.95,
        gradient_accumulation_steps=grad_accum_steps,
        eval_strategy="steps",
        label_names=["labels"],
        load_best_model_at_end=cfg.load_best_model_at_end,
        logging_steps=25,
        metric_for_best_model=cfg.metric_for_best_model,
        greater_is_better=cfg.greater_is_better,
        per_device_train_batch_size=cfg.minibatch_size,
        per_device_eval_batch_size=cfg.minibatch_size,
        save_strategy="steps",
        save_total_limit=cfg.save_total_limit,
        tf32=True,  # Use Tensor Cores even for fp32 matmuls
        warmup_steps=cfg.n_warmup_steps,
        weight_decay=cfg.weight_decay,
        lr_scheduler_type=cfg.lr_schedule,
        eval_steps=cfg.eval_every,
        save_steps=cfg.save_every,
    )

    def get_model_and_run_name(model_name, current_name):
        """Return the ModelConfig and the run name for one stage."""
        model_last = model_name.split("/")[-1]
        model_cfg = ModelConfig(name=model_name, enable_lora=not cfg.disable_lora)
        run_name = f"{current_name}-{cfg.run_name}-{cfg.dataset}-{model_last}"
        return model_cfg, run_name

    def begin_stage(model_name, stage, output_dir, learning_rate):
        """Point the shared ``train_args`` at one stage and return its ModelConfig."""
        model_cfg, run_name = get_model_and_run_name(model_name, stage)
        train_args["run_name"] = run_name
        train_args["output_dir"] = str(output_dir)
        train_args["learning_rate"] = learning_rate
        return model_cfg

    # train weak floor, get predictions
    print("\n\033[32m===== Training weak model =====\033[0m")
    model_cfg = begin_stage(
        cfg.weak_model_name, "weak", shared_root / cfg_name / "weak", cfg.weak_lr
    )
    weak_ds_dict = DatasetDict(
        {
            "train": splits["weak_train"],
            "val": splits["val"],
            "test": splits["test"],
        }
    )
    weak_predict_dict = {"train": splits["strong_train"], "val": splits["val"]}
    train(
        weak_ds_dict,
        model_cfg,
        TrainingArguments(**train_args),
        cfg.to_dict(),
        transfer=False,
        predict_dict=weak_predict_dict,
    )

    # train strong ceil
    print("\n\033[32m===== Training strong model =====\033[0m")
    model_cfg = begin_stage(
        cfg.strong_model_name, "strong", shared_root / cfg_name / "strong", cfg.strong_lr
    )
    strong_ds_dict = DatasetDict(
        {
            "train": splits["strong_train"],
            "val": splits["val"],
            "test": splits["test"],
        }
    )
    train(
        strong_ds_dict,
        model_cfg,
        TrainingArguments(**train_args),
        cfg.to_dict(),
        transfer=False,
    )

    # load weak predictions
    weak_preds_root = shared_root / cfg_name / "weak" / "predictions"
    print(f"Loading weak predictions from {weak_preds_root}")
    weak_train_preds_ds, weak_val_preds_ds = _load_stage_predictions(weak_preds_root)

    # train w2s, get predictions
    print("\n\033[32m===== Training w2s model =====\033[0m")
    model_cfg = begin_stage(
        cfg.strong_model_name, "w2s", root / cfg_name / "w2s", cfg.strong_lr
    )
    w2s_ds_dict = _transfer_dataset(splits, weak_train_preds_ds, weak_val_preds_ds)
    w2s_predict_dict = {"train": splits["strong_train"], "val": splits["val"]}
    train(
        w2s_ds_dict,
        model_cfg,
        TrainingArguments(**train_args),
        cfg.to_dict(),
        transfer=True,
        predict_dict=w2s_predict_dict,
        save_activations=True,
        # NOTE(review): activations go under shared_root while this stage's
        # output_dir is under root -- confirm that is intentional.
        acts_dir=shared_root / cfg_name / "w2s" / "activations",
    )

    prev = "w2s"

    # strong-to-strong iterations
    for s2s_iter in range(cfg.s2s_iters):
        stage = f"s2s-{s2s_iter}"

        # load prev predictions
        prev_preds_root = root / cfg_name / prev / "predictions"
        print(f"Loading {prev} predictions from {prev_preds_root}")
        prev_train_preds_ds, prev_val_preds_ds = _load_stage_predictions(prev_preds_root)

        # train s2s, get predictions
        print(f"\n\033[32m===== Training s2s model iteration {s2s_iter} =====\033[0m")
        model_cfg = begin_stage(
            cfg.strong_model_name, stage, root / cfg_name / stage, cfg.strong_lr
        )
        s2s_ds_dict = _transfer_dataset(splits, prev_train_preds_ds, prev_val_preds_ds)
        s2s_predict_dict = {"train": splits["strong_train"], "val": splits["val"]}
        train(
            s2s_ds_dict,
            model_cfg,
            TrainingArguments(**train_args),
            cfg.to_dict(),
            transfer=True,
            predict_dict=s2s_predict_dict,
            # NOTE(review): acts_dir is passed without save_activations=True
            # here, unlike the w2s stage -- confirm whether activations are
            # meant to be saved for s2s rounds.
            acts_dir=shared_root / cfg_name / stage / "activations",
        )

        prev = stage

if __name__ == "__main__":
    # Provide a default Hugging Face datasets cache location, but let a value
    # that is already set in the environment win so the script does not force a
    # developer-specific path onto other machines.
    # NOTE(review): `datasets` is imported at the top of this file before this
    # line executes; confirm the library still honours the variable when it is
    # set this late, or export it in the shell instead.
    os.environ.setdefault("HF_DATASETS_CACHE", "/home/adam/hf_cache")
    run_train(parse(SFTConfig))
Loading

0 comments on commit cf99b7d

Please sign in to comment.