s2s initial commit
norabelrose committed May 15, 2024
1 parent 1921d47 commit add1b42
Showing 2 changed files with 101 additions and 41 deletions.
140 changes: 100 additions & 40 deletions train.py
@@ -2,7 +2,7 @@
from pathlib import Path

import torch
from datasets import Value
from datasets import Value, disable_caching
from peft import (
AutoPeftModelForSequenceClassification,
LoraConfig,
@@ -20,9 +20,10 @@
import wandb
from w2s.ds_registry import load_and_process_dataset
from w2s.knn import gather_hiddens, topofilter
from w2s.loss import log_confidence_loss
from w2s.roc_auc import roc_auc

# Recompute dataset transforms every run instead of reusing the on-disk cache
disable_caching()


@dataclass
class TrainConfig(Serializable):
@@ -44,15 +45,20 @@ class TrainConfig(Serializable):
run_name: str = ""
"""Name of the run."""

s2s_iter: int = 0
"""Number of strong-to-strong iterations to perform."""


class DistillationTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
labels = inputs.pop("labels").float()

outputs = model(**inputs)
frac = self.state.global_step / self.state.max_steps
loss = log_confidence_loss(outputs.logits, labels, frac)
# frac = self.state.global_step / self.state.max_steps
# loss = log_confidence_loss(outputs.logits, labels, frac)

labels = torch.stack([1.0 - labels, labels], dim=-1)
loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
return (loss, outputs) if return_outputs else loss
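
The loss above is a plain soft-label cross-entropy replacing the annealed log-confidence objective: stacking [1.0 - labels, labels] turns each weak probability p into a two-class target distribution, and torch.nn.functional.cross_entropy accepts probability targets directly (PyTorch 1.10+). A minimal self-contained sketch of the same computation:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 2)                   # model outputs for 4 examples
    p = torch.tensor([0.9, 0.2, 0.7, 0.5])       # weak-teacher P(positive)

    targets = torch.stack([1.0 - p, p], dim=-1)  # soft two-class targets
    loss = F.cross_entropy(logits, targets)

    # equivalent by hand: -sum_c target_c * log_softmax(logits)_c, averaged over examples
    manual = -(targets * logits.log_softmax(-1)).sum(-1).mean()
    assert torch.allclose(loss, manual)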


@@ -97,6 +103,14 @@ def train(cfg: TrainConfig):
cfg.dataset, split_sizes=dict(train=20_000, test=1_000)
)

def init_strong_model():
model = AutoModelForSequenceClassification.from_pretrained(
STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
)
model.config.pad_token_id = strong_tokenizer.pad_token_id
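# HuggingFace init for the head is too large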
model.score.weight.data *= 0.01
return model

def weak_processor(examples):
out = weak_tokenizer(examples["txt"], truncation=True)
out["labels"] = examples["hard_label"]
@@ -112,6 +126,7 @@ def compute_metrics(eval_pred):
cols = ["hard_label", "txt"]
test = splits["test"].select_columns(cols)
train = splits["train"].select_columns(cols)
print(f"Train example:\n\n{train[0]['txt']}\n\nLabel: {train[0]['hard_label']}")

weak_test = test.map(weak_processor, batched=True).cast_column(
"labels", Value("int64")
@@ -187,8 +202,8 @@ def compute_metrics(eval_pred):
test_logits = trainer.predict(weak_test).predictions

# Convert to probabilities, then keep only the positive probs
_, train_probs = torch.from_numpy(train_logits).softmax(-1).unbind(-1)
_, test_probs = torch.from_numpy(test_logits).softmax(-1).unbind(-1)
train_probs = torch.from_numpy(train_logits).softmax(-1)[:, 1]
test_probs = torch.from_numpy(test_logits).softmax(-1)[:, 1]

label_dir.mkdir(parents=True, exist_ok=True)
torch.save(train_probs, label_dir / "train.pt")
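
For two-class logits, softmax(-1)[:, 1] keeps the positive-class probability per example; it matches the second tensor of unbind(-1) but reads as plain column indexing. For instance:

    import torch

    logits = torch.tensor([[0.0, 0.0], [1.0, 3.0]])
    probs = logits.softmax(-1)   # each row sums to 1
    pos = probs[:, 1]            # P(positive): tensor([0.5000, 0.8808])
    assert torch.allclose(pos, probs.unbind(-1)[1])
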
@@ -213,13 +228,6 @@ def strong_processor(examples):
print(f"Strong ceiling model already exists at {strong_ckpt}")
else:
print("\n\033[32m===== Training strong ceiling model =====\033[0m")
strong_model = AutoModelForSequenceClassification.from_pretrained(
STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
)
# HuggingFace init for the head is too large
strong_model.score.weight.data *= 0.01
strong_model.config.pad_token_id = strong_tokenizer.pad_token_id

training_args.output_dir = str(root / "ceil")
training_args.run_name = cfg.dataset + "/ceil" + cfg.run_name

@@ -228,40 +236,18 @@ def strong_processor(examples):
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=ceil_test,
model=get_peft_model(strong_model, lora_cfg),
model=get_peft_model(init_strong_model(), lora_cfg),
tokenizer=strong_tokenizer,
train_dataset=strong_train,
)
trainer.train()
wandb.finish()
move_best_ckpt(trainer)

print("\n\033[32m===== Training w2s model =====\033[0m")
strong_model = AutoModelForSequenceClassification.from_pretrained(
STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
)
# HuggingFace init for the head is too large
strong_model.score.weight.data *= 0.01
strong_model.config.pad_token_id = strong_tokenizer.pad_token_id

# Weak to strong generalization
acts_path = root / "ceil/acts.pt"
if acts_path.exists():
print(f"Loading strong activations from {acts_path}")
train_acts = torch.load(acts_path, map_location=strong_model.device)
else:
print("Gathering strong activations")
train_acts = gather_hiddens(strong_model, strong_train)
torch.save(train_acts, acts_path)

w2s_train = strong_train.remove_columns("labels")
w2s_train = w2s_train.add_column("labels", train_probs.numpy())

if cfg.contamination > 0.0:
y = train_probs.to(train_acts.device)
indices = topofilter(train_acts, y, cfg.contamination, k=cfg.outlier_k)
w2s_train = w2s_train.select(indices)

# Check gt metrics every 100 steps during w2s training.
# We can overfit to the weak labels before a single epoch.
training_args.evaluation_strategy = "steps"
@@ -272,17 +258,91 @@ def strong_processor(examples):
training_args.output_dir = str(root / "w2s") + cfg.run_name
training_args.run_name = cfg.dataset + "/w2s" + cfg.run_name

should_train = True
w2s_ckpt = Path(str(root / "w2s") + cfg.run_name) / "best-ckpt"
if w2s_ckpt.exists():
print(f"W2S model already exists at {strong_ckpt}")

w2s_model = AutoPeftModelForSequenceClassification.from_pretrained(
w2s_ckpt, torch_dtype="auto", device_map={"": "cuda"}
)
should_train = False
else:
print("\n\033[32m===== Training w2s model =====\033[0m")

strong_model = init_strong_model()
if cfg.contamination > 0.0:
acts_path = root / "ceil/acts.pt"
if acts_path.exists():
print(f"Loading strong activations from {acts_path}")
train_acts = torch.load(acts_path, map_location=strong_model.device)
else:
print("Gathering strong activations")
train_acts = gather_hiddens(strong_model, strong_train)
torch.save(train_acts, acts_path)

y = train_probs.to(train_acts.device)
indices = topofilter(train_acts, y, cfg.contamination, k=cfg.outlier_k)
w2s_train = w2s_train.select(indices)

w2s_model = get_peft_model(strong_model, lora_cfg)

w2s_model.config.pad_token_id = strong_tokenizer.pad_token_id
trainer = DistillationTrainer(
args=training_args,
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=ceil_test,
model=get_peft_model(strong_model, lora_cfg),
model=w2s_model,
tokenizer=strong_tokenizer,
train_dataset=w2s_train,
)
trainer.train()
wandb.finish()
if should_train:
trainer.train()
wandb.finish()
move_best_ckpt(trainer)
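
Above, cfg.contamination sets the fraction of weak-labeled examples dropped before w2s training, and topofilter (from w2s.knn, not shown in this diff) picks which ones using the strong model's activations. As a rough sketch of the idea only — a hypothetical KNN-disagreement filter, not the actual topofilter implementation:

    import torch

    def knn_filter(acts, probs, contamination, k=10):
        # Hypothetical stand-in for w2s.knn.topofilter: drop the examples whose
        # weak label disagrees most with the average label of their k nearest
        # neighbors in activation space.
        acts = acts / acts.norm(dim=-1, keepdim=True)
        sims = acts @ acts.T                            # cosine similarities
        knn = sims.topk(k + 1, dim=-1).indices[:, 1:]   # drop the self-match
        disagreement = (probs - probs[knn].mean(-1)).abs()
        n_keep = int(len(probs) * (1 - contamination))
        return disagreement.topk(n_keep, largest=False).indices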

# Save memory
del w2s_model
preds_path = root / "ceil/preds.pt"

# Strong to strong generalization
for i in range(cfg.s2s_iter):
print(f"\n\033[32m===== Self-distillation iter {i + 1} =====\033[0m")

# Load cached predictions from the previous strong model, or compute them
if preds_path.exists():
print(f"Loading strong preds from {preds_path}")
train_probs = torch.load(preds_path)
else:
train_logits = trainer.predict(w2s_train).predictions
train_probs = torch.from_numpy(train_logits).softmax(-1)[:, 1]

torch.save(train_probs, preds_path)

del trainer

w2s_train = w2s_train.remove_columns("labels").add_column(
"labels", train_probs.numpy()
)

name = f"s2s_iter{i + 1}" + cfg.run_name
training_args.output_dir = str(root / name)
training_args.run_name = cfg.dataset + "/" + name

trainer = DistillationTrainer(
args=training_args,
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=ceil_test,
model=get_peft_model(init_strong_model(), lora_cfg),
tokenizer=strong_tokenizer,
train_dataset=w2s_train,
)
trainer.train()
wandb.finish()

preds_path = root / name / "preds.pt"
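
Each pass of this loop relabels the training set with the previous student's positive-class probabilities and trains a fresh LoRA student on them. The same recurrence as a self-contained toy (logistic-regression students on random features, purely illustrative):

    import torch

    torch.manual_seed(0)
    x = torch.randn(256, 16)
    probs = torch.rand(256)                      # stand-in for the weak/w2s labels

    for i in range(3):                           # cfg.s2s_iter rounds
        w = torch.zeros(16, requires_grad=True)  # fresh student each round
        opt = torch.optim.Adam([w], lr=0.1)
        for _ in range(200):
            logits = x @ w
            two_class = torch.stack([-logits, logits], dim=-1)
            targets = torch.stack([1 - probs, probs], dim=-1)
            loss = torch.nn.functional.cross_entropy(two_class, targets)
            opt.zero_grad()
            loss.backward()
            opt.step()
        probs = (x @ w).sigmoid().detach()       # next round's labels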


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion w2s/ds_registry.py
@@ -253,7 +253,7 @@ def format_dream(ex, rng):

ans = rng.choice(distractors)

txt = f"{ex['dialogue']}\n\nQ: {ex['question']} A: {ans}"
txt = f"{'\n'.join(ex['dialogue'])}\n\nQ: {ex['question']} A: {ans}"
return dict(txt=txt, hard_label=hard_label)
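
The fix joins the dialogue turns with newlines instead of interpolating the raw list. A hypothetical example of the resulting prompt (made-up dialogue, not from the dataset):

    ex = {"dialogue": ["M: Hi.", "W: Hello."], "question": "Who speaks first?"}
    ans = "the man"
    txt = "\n".join(ex["dialogue"]) + f"\n\nQ: {ex['question']} A: {ans}"
    # txt == "M: Hi.\nW: Hello.\n\nQ: Who speaks first? A: the man"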


