Now training strong ceiling models
norabelrose committed May 10, 2024
1 parent 7048f68 commit 77f9fce
Showing 2 changed files with 84 additions and 48 deletions.
w2s/ds_registry.py (4 changes: 4 additions & 0 deletions)
@@ -320,6 +320,10 @@ def format_ethics_utilitarianism(ex, rng):
)


def format_lichess(ex, rng):
return dict(txt=ex["txt"], hard_label=ex["hard_label"])


def format_mc_taco(ex, rng):
template = "{sentence}\n\nGiven the above, {question} Is the answer {answer}?"
return dict(txt=template.format(**ex), hard_label=ex["label"])
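
For context on the ds_registry.py hunk above: each format_* function in the registry maps a raw dataset example (plus an RNG that some formatters use) to a dict with a txt prompt and a hard_label. Below is a minimal sketch of exercising the new format_lichess on its own; the example record is hypothetical and only illustrates the expected field shape, not the real lichess schema.

import random

def format_lichess(ex, rng):
    # The lichess examples already carry a ready-made prompt and binary label,
    # so the formatter passes them through unchanged.
    return dict(txt=ex["txt"], hard_label=ex["hard_label"])

rng = random.Random(0)
example = {"txt": "1. e4 e5 2. Nf3 Nc6 3. Bb5 a6. Does White go on to win?", "hard_label": 1}
print(format_lichess(example, rng))
# {'txt': '1. e4 e5 2. Nf3 Nc6 3. Bb5 a6. Does White go on to win?', 'hard_label': 1}
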
w2s/train.py (128 changes: 80 additions & 48 deletions)
@@ -3,7 +3,11 @@

import numpy as np
import torch
from peft import LoraConfig, get_peft_model
from peft import (
AutoPeftModelForSequenceClassification,
LoraConfig,
get_peft_model,
)
from simple_parsing import Serializable, field, parse
from sklearn.metrics import roc_auc_score
from transformers import (
Expand Down Expand Up @@ -38,8 +42,36 @@ def compute_loss(self, model, inputs, return_outputs=False):
return (loss, outputs) if return_outputs else loss


# Works for both Llama and Qwen architectures
LORA_MODULES = [
"gate_proj",
"down_proj",
"up_proj",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
]


def move_best_ckpt(trainer: Trainer):
path = trainer.state.best_model_checkpoint
perf = trainer.state.best_metric
assert path is not None, "No best checkpoint found"
assert perf is not None, "No best metric"

src = Path(path)
dest = src.parent / "best-ckpt"
src.rename(dest)
print(f"Best model (loss {perf:.3f}) saved at: {dest}")


def main():
cfg = parse(TrainConfig)
lora_cfg = LoraConfig(target_modules=LORA_MODULES)

STRONG_NAME = "meta-llama/Meta-Llama-3-8B"
strong_tokenizer = AutoTokenizer.from_pretrained(STRONG_NAME)
weak_tokenizer = AutoTokenizer.from_pretrained(cfg.weak_name)

def weak_processor(examples):
@@ -57,40 +89,50 @@ def compute_metrics(eval_pred):
splits = load_and_process_dataset(
cfg.dataset, split_sizes=dict(train=20_000, test=1_000)
)
weak_test = splits["test"].map(weak_processor, batched=True)
weak_train = splits["train"].map(weak_processor, batched=True)
test = splits["test"].select_columns(["hard_label", "txt"])
train = splits["train"].select_columns(["hard_label", "txt"])
weak_test = test.map(weak_processor, batched=True)
weak_train = train.map(weak_processor, batched=True)

root = Path("results") / cfg.dataset
training_args = TrainingArguments(
str(root / "weak"),
str(root / "floor"),
adam_beta2=0.95,
evaluation_strategy="epoch",
label_names=["labels"],
learning_rate=2e-5,
load_best_model_at_end=True,
logging_steps=50,
metric_for_best_model="auroc",
per_device_train_batch_size=16,
save_strategy="epoch",
save_total_limit=1,
tf32=True, # Use Tensor Cores even for fp32 matmuls
weight_decay=0.1,
weight_decay=0.01,
)

# Gather weak labels
label_dir = root / "weak/preds"
label_dir = root / "floor/preds"
if label_dir.exists():
print(f"Loading weak labels from {label_dir}")
train_probs = np.load(label_dir / "train.npy")
test_probs = np.load(label_dir / "test.npy")
else:
should_train = True
weak_path = root / "weak/best-ckpt"
weak_path = root / "floor/best-ckpt"
if not weak_path.exists():
weak_path = cfg.weak_name
weak_model = AutoModelForSequenceClassification.from_pretrained(
cfg.weak_name, torch_dtype="auto"
)
weak_model = get_peft_model(weak_model, lora_cfg)
else:
print("Loading weak model from:", weak_path)
should_train = False

weak_model = AutoModelForSequenceClassification.from_pretrained(
weak_path, torch_dtype="auto"
)
weak_model = AutoPeftModelForSequenceClassification.from_pretrained(
weak_path, torch_dtype="auto"
)

# Make sure the pad token is set
weak_model.config.pad_token_id = (
weak_tokenizer.pad_token_id
@@ -106,16 +148,9 @@ def compute_metrics(eval_pred):
train_dataset=weak_train,
)
if should_train:
print("Training weak model")
print("\033[32m===== Training weak model =====\033[0m")
trainer.train()

path = trainer.state.best_model_checkpoint
assert path, "No best checkpoint found"

src = Path(path)
dest = src.parent / "best-ckpt"
src.rename(dest)
print(f"Best model saved at: {dest}")
move_best_ckpt(trainer)

print("Gathering weak labels")
train_logits = trainer.predict(weak_train).predictions
@@ -130,32 +165,38 @@ def strong_processor(examples):
np.save(label_dir / "train.npy", train_probs)
np.save(label_dir / "test.npy", test_probs)

# Train the strong model
STRONG_NAME = "meta-llama/Meta-Llama-3-8B"
print("\033[32m===== Training strong ceiling model =====\033[0m")
strong_model = AutoModelForSequenceClassification.from_pretrained(
STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
)
strong_tokenizer = AutoTokenizer.from_pretrained(STRONG_NAME)
strong_model.config.pad_token_id = (
strong_tokenizer.pad_token_id
) = strong_tokenizer.eos_token_id

# Make sure that we use CrossEntropyLoss
strong_model.config.problem_type = "single_label_classification"

def strong_processor(examples):
return strong_tokenizer(examples["txt"], truncation=True)

strong_train = (
splits["train"].remove_columns("label").add_column("labels", train_probs)
)
strong_test = (
    splits["test"].remove_columns("label").add_column("labels", test_probs)
)
strong_train = splits["train"].map(strong_processor, batched=True)
ceil_test = splits["test"].map(strong_processor, batched=True)
ceil_test = ceil_test.rename_column("hard_label", "labels")

training_args.output_dir = str(root / "ceil")
trainer = Trainer(
args=training_args,
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=ceil_test,
model=get_peft_model(strong_model, lora_cfg),
tokenizer=strong_tokenizer,
train_dataset=strong_train.rename_column("hard_label", "labels"),
)
trainer.train()
move_best_ckpt(trainer)

strong_train = strong_train.map(strong_processor, batched=True)
strong_test = strong_test.map(strong_processor, batched=True)
# Make sure that we use CrossEntropyLoss
strong_model.config.problem_type = "single_label_classification"

# Weak to strong generalization
acts_path = root / "strong/acts.pt"
if acts_path.exists():
print(f"Loading strong activations from {acts_path}")
@@ -171,28 +212,19 @@ def strong_processor(examples):
strong_train = strong_train.select(top.tolist())

training_args.label_names = ["labels"]
training_args.output_dir = str(root / "strong")

lora_cfg = LoraConfig(
target_modules=[
"gate_proj",
"down_proj",
"up_proj",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
]
)
training_args.output_dir = str(root / "w2s")

w2s_train = strong_train.add_column("labels", train_probs)
trainer = DistillationTrainer(
args=training_args,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=strong_test,
eval_dataset=ceil_test,
model=get_peft_model(strong_model, lora_cfg),
tokenizer=strong_tokenizer,
train_dataset=strong_train,
train_dataset=w2s_train,
)
trainer.train()
move_best_ckpt(trainer)


if __name__ == "__main__":
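
The body of DistillationTrainer.compute_loss lies outside the hunks shown above; only its closing return appears as context. Below is a sketch, under assumptions, of what a weak-to-strong distillation loss of this shape could look like; it is not the repository's actual implementation. It assumes the labels column carries the weak floor model's positive-class probabilities, as the saved train.npy / test.npy soft labels suggest.

import torch
import torch.nn.functional as F
from transformers import Trainer

class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Soft labels: the weak "floor" model's P(label = 1) per example (assumption).
        labels = inputs.pop("labels").float()

        outputs = model(**inputs)
        logits = outputs.logits  # (batch, 2) for binary sequence classification

        # Cross-entropy against the weak model's probability distribution.
        target = torch.stack([1.0 - labels, labels], dim=-1)
        loss = F.cross_entropy(logits, target)

        return (loss, outputs) if return_outputs else loss

Only the final return line is taken from the diff's context; everything above it is an assumption about how the soft labels are consumed.
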
