Log confidence loss
norabelrose committed May 11, 2024
1 parent 1d0ee29 commit f57cb4c
Showing 2 changed files with 88 additions and 51 deletions.
24 changes: 24 additions & 0 deletions w2s/loss.py
@@ -0,0 +1,24 @@
import torch


def log_confidence_loss(
    logits,
    labels,
    step_frac: float,
    warmup_frac: float = 0.1,
    aux_coef: float = 0.5,
):
    logits = logits.float()
    labels = labels.float()

    coef = aux_coef * min(1.0, step_frac / warmup_frac)
    preds = torch.softmax(logits, dim=-1)

    threshold = torch.quantile(preds[:, 0], labels.mean())
    strong_preds = torch.cat(
        [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
        dim=1,
    )
    labels_binary = torch.stack([1.0 - labels, labels], dim=1)
    target = labels_binary * (1 - coef) + strong_preds.detach() * coef
    return torch.nn.functional.cross_entropy(logits, target)
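A minimal standalone sketch of calling the new loss, for readers skimming the diff; it is not part of the commit. The import path assumes the repository layout above (w2s/loss.py), the logits are assumed to be two-class tensors of shape (N, 2) as the indexing in the function implies, and the dummy tensors and step_frac value are made up:

# Standalone sketch (not from the repository): exercising log_confidence_loss on dummy data.
import torch

from w2s.loss import log_confidence_loss

logits = torch.randn(8, 2, requires_grad=True)  # raw (N, 2) classifier outputs
labels = torch.randint(0, 2, (8,)).float()      # weak labels in {0, 1}

# At 5% of training (below warmup_frac=0.1) the self-confidence term is only
# partially ramped in: coef = 0.5 * min(1, 0.05 / 0.1) = 0.25.
loss = log_confidence_loss(logits, labels, step_frac=0.05)
loss.backward()
print(loss.item())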
115 changes: 64 additions & 51 deletions w2s/train.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from peft import (
    AutoPeftModelForSequenceClassification,
@@ -19,6 +18,7 @@

from .ds_registry import load_and_process_dataset
from .knn import gather_hiddens, knn_average
from .loss import log_confidence_loss
from .roc_auc import roc_auc


@@ -36,9 +36,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")

        outputs = model(**inputs)
        loss = torch.nn.functional.cross_entropy(
            outputs.logits, torch.stack([1.0 - labels, labels], dim=-1)
        )
        frac = self.state.global_step / self.state.max_steps
        loss = log_confidence_loss(outputs.logits, labels, frac)

        return (loss, outputs) if return_outputs else loss


@@ -73,7 +73,7 @@ def lolcat(lol1, lol2):


def lolconst(lol, const):
    return [[const for _ in l] for l in lol]
    return [[const for _ in l_] for l_ in lol]


def train(cfg: TrainConfig):
@@ -83,26 +83,40 @@ def train(cfg: TrainConfig):
    strong_tokenizer = AutoTokenizer.from_pretrained(STRONG_NAME)
    weak_tokenizer = AutoTokenizer.from_pretrained(cfg.weak_name)

    # Make sure that the pad token is set
    if strong_tokenizer.pad_token_id is None:
        strong_tokenizer.pad_token = strong_tokenizer.eos_token
    if weak_tokenizer.pad_token_id is None:
        weak_tokenizer.pad_token = weak_tokenizer.eos_token

    splits = load_and_process_dataset(
        cfg.dataset, split_sizes=dict(train=20_000, test=1_000)
    )

    if 'txt' in splits["train"].column_names:
    if "txt" in splits["train"].column_names:
        task = "classify"
    elif 'ctx' in splits["train"].column_names:
    elif "ctx" in splits["train"].column_names:
        task = "generate"
    else:
        raise ValueError(f"Unrecognized dataset columns: {splits['train'].column_names}")
        raise ValueError(
            f"Unrecognized dataset columns: {splits['train'].column_names}"
        )

    def weak_processor(examples):
        if task == "generate":
            ctx_out = weak_tokenizer(examples["ctx"], truncation=True)
            trg_out = weak_tokenizer(examples["target"], truncation=True, add_special_tokens=False)
            trg_out = weak_tokenizer(
                examples["target"], truncation=True, add_special_tokens=False
            )
            breakpoint()
            return dict(
                input_ids=lolcat(ctx_out["input_ids"], trg_out["input_ids"]),
                attention_mask=lolcat(ctx_out["attention_mask"], trg_out["attention_mask"]),
                labels=lolcat(lolconst(ctx_out["input_ids"], -100), trg_out["input_ids"]),
                attention_mask=lolcat(
                    ctx_out["attention_mask"], trg_out["attention_mask"]
                ),
                labels=lolcat(
                    lolconst(ctx_out["input_ids"], -100), trg_out["input_ids"]
                ),
            )

        out = weak_tokenizer(examples["txt"], truncation=True)
@@ -120,12 +134,10 @@ def compute_metrics(eval_pred):
            auroc=roc_auc(labels, predictions[:, 1]),
        )

    if task == "generate":
        test = splits["test"].select_columns(["ctx", "target"])
        train = splits["train"].select_columns(["ctx", "target"])
    else:
        test = splits["test"].select_columns(["hard_label", "txt"])
        train = splits["train"].select_columns(["hard_label", "txt"])
    cols = ["ctx", "target"] if task == "generate" else ["hard_label", "txt"]
    test = splits["test"].select_columns(cols)
    train = splits["train"].select_columns(cols)

    weak_test = test.map(weak_processor, batched=True)
    weak_train = train.map(weak_processor, batched=True)

@@ -149,8 +161,8 @@ def compute_metrics(eval_pred):
    label_dir = root / "floor/preds"
    if label_dir.exists():
        print(f"Loading weak labels from {label_dir}")
        train_probs = np.load(label_dir / "train.npy")
        test_probs = np.load(label_dir / "test.npy")
        train_probs = torch.load(label_dir / "train.pt")
        test_probs = torch.load(label_dir / "test.pt")
    else:
        should_train = True
        weak_path = root / "floor/best-ckpt"
@@ -170,11 +182,7 @@ def compute_metrics(eval_pred):
                weak_path, torch_dtype="auto"
            )

        # Make sure the pad token is set
        weak_model.config.pad_token_id = (
            weak_tokenizer.pad_token_id
        ) = weak_tokenizer.eos_token_id

        weak_model.config.pad_token_id = weak_tokenizer.pad_token_id
        trainer = Trainer(
            args=training_args,
            compute_metrics=compute_metrics,
@@ -185,7 +193,7 @@ def compute_metrics(eval_pred):
            train_dataset=weak_train,
        )
        if should_train:
            print("\033[32m===== Training weak model =====\033[0m")
            print("\n\033[32m===== Training weak model =====\033[0m")
            trainer.train()
            move_best_ckpt(trainer)

@@ -196,34 +204,32 @@ def compute_metrics(eval_pred):
        # Convert to probabilities, then keep only the positive probs
        _, train_probs = torch.from_numpy(train_logits).softmax(-1).unbind(-1)
        _, test_probs = torch.from_numpy(test_logits).softmax(-1).unbind(-1)
        train_probs, test_probs = train_probs.numpy(), test_probs.numpy()

        label_dir.mkdir(parents=True, exist_ok=True)
        np.save(label_dir / "train.npy", train_probs)
        np.save(label_dir / "test.npy", test_probs)
        torch.save(train_probs, label_dir / "train.pt")
        torch.save(test_probs, label_dir / "test.pt")

    def strong_processor(examples):
        return strong_tokenizer(examples["txt"], truncation=True)

    strong_train = train.map(strong_processor, batched=True).rename_column(
        "hard_label", "labels"
    )
    ceil_test = test.map(strong_processor, batched=True).rename_column(
        "hard_label", "labels"
    )

    strong_ckpt = root / "ceil" / "best-ckpt"
    if strong_ckpt.exists():
        print(f"Strong ceiling model already exists at {strong_ckpt}")
    else:
        print("\033[32m===== Training strong ceiling model =====\033[0m")
        print("\n\033[32m===== Training strong ceiling model =====\033[0m")
        strong_model = AutoModelForSequenceClassification.from_pretrained(
            STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
        )
        # HuggingFace init for the head is too large
        strong_model.score.weight.data *= 0.01

        strong_model.config.pad_token_id = (
            strong_tokenizer.pad_token_id
        ) = strong_tokenizer.eos_token_id

        def strong_processor(examples):
            return strong_tokenizer(examples["txt"], truncation=True)

        strong_train = train.map(strong_processor, batched=True)
        ceil_test = test.map(strong_processor, batched=True).rename_column(
            "hard_label", "labels"
        )
        strong_model.config.pad_token_id = strong_tokenizer.pad_token_id

        training_args.output_dir = str(root / "ceil")
        trainer = Trainer(
@@ -233,23 +239,21 @@ def strong_processor(examples):
            eval_dataset=ceil_test,
            model=get_peft_model(strong_model, lora_cfg),
            tokenizer=strong_tokenizer,
            train_dataset=strong_train.rename_column("hard_label", "labels"),
            train_dataset=strong_train,
        )
        trainer.train()
        move_best_ckpt(trainer)

    print("\033[32m===== Training w2s model =====\033[0m")
    print("\n\033[32m===== Training w2s model =====\033[0m")
    strong_model = AutoModelForSequenceClassification.from_pretrained(
        STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
    )
    # HuggingFace init for the head is too large
    strong_model.score.weight.data *= 0.01

    # Make sure that we use CrossEntropyLoss
    strong_model.config.problem_type = "single_label_classification"
    strong_model.config.pad_token_id = strong_tokenizer.pad_token_id

    # Weak to strong generalization
    acts_path = root / "strong/acts.pt"
    acts_path = root / "ceil/acts.pt"
    if acts_path.exists():
        print(f"Loading strong activations from {acts_path}")
        train_acts = torch.load(acts_path, map_location=strong_model.device)
@@ -258,17 +262,26 @@ def strong_processor(examples):
        train_acts = gather_hiddens(strong_model, strong_train)
        torch.save(train_acts, acts_path)

    y = torch.tensor(strong_train["labels"], device=train_acts.device)
    labels = knn_average(train_acts, y, 200)
    labels = knn_average(train_acts, train_probs.to(train_acts.device), 200)
    top = torch.abs(labels - 0.5).topk(len(labels) // 2).indices
    strong_train = strong_train.select(top.tolist())

    # Check gt metrics every 50 steps during w2s training.
    # We can overfit to the weak labels before a single epoch.
    training_args.evaluation_strategy = "steps"
    training_args.eval_steps = 50
    training_args.save_steps = 50

    training_args.label_names = ["labels"]
    training_args.output_dir = str(root / "w2s")

    w2s_train = strong_train.add_column("labels", train_probs)
    w2s_train = strong_train.remove_columns("labels").add_column(
        "labels", train_probs.numpy()
    )
    w2s_train = w2s_train.select(top.tolist())

    trainer = DistillationTrainer(
        args=training_args,
        compute_metrics=compute_metrics,
        data_collator=DataCollatorWithPadding(strong_tokenizer),
        eval_dataset=ceil_test,
        model=get_peft_model(strong_model, lora_cfg),

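One hedged illustration of the data-selection change in train.py, with made-up probabilities rather than repository code: the w2s step above keeps only the half of the training set whose kNN-averaged weak probabilities are farthest from 0.5, i.e. the examples the weak supervision is most confident about.

import torch

# Keep the half of the examples whose (kNN-averaged) weak probabilities are
# most confident, i.e. farthest from the 0.5 decision boundary.
probs = torch.tensor([0.90, 0.55, 0.12, 0.48, 0.70, 0.35])
top = torch.abs(probs - 0.5).topk(len(probs) // 2).indices
print(top)  # tensor([0, 2, 4])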