mistral and llama3; instantiate new configs; lower learning rate; 2 epochs; longer logconf warmup; fix model loading
AlexTMallen committed May 16, 2024
1 parent add1b42 commit f5991ea
Showing 3 changed files with 106 additions and 77 deletions.
141 changes: 80 additions & 61 deletions train.py
@@ -20,6 +20,7 @@
import wandb
from w2s.ds_registry import load_and_process_dataset
from w2s.knn import gather_hiddens, topofilter
from w2s.loss import log_confidence_loss
from w2s.roc_auc import roc_auc

disable_caching()
@@ -54,15 +55,14 @@ def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels").float()

outputs = model(**inputs)
# frac = self.state.global_step / self.state.max_steps
# loss = log_confidence_loss(outputs.logits, labels, frac)
loss = log_confidence_loss(outputs.logits, labels, self.state.global_step)

labels = torch.stack([1.0 - labels, labels], dim=-1)
loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
# labels = torch.stack([1.0 - labels, labels], dim=-1)
# loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
return (loss, outputs) if return_outputs else loss
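The commented-out branch above is the old objective: plain cross-entropy against the weak model's soft probabilities. As a point of reference, a minimal runnable sketch of that soft-label loss (the values are made up); `F.cross_entropy` has accepted class-probability targets since PyTorch 1.10:

```python
import torch
import torch.nn.functional as F

# Soft-label cross-entropy, as in the commented-out branch (made-up values).
logits = torch.randn(4, 2)
p1 = torch.tensor([0.9, 0.1, 0.7, 0.4])        # weak-model P(class 1)
targets = torch.stack([1.0 - p1, p1], dim=-1)  # shape (batch, 2)
loss = F.cross_entropy(logits, targets)        # prob targets OK since 1.10
print(loss)
```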


# Works for both Llama and Qwen architectures
# Works for Llama, Mistral, and Qwen architectures
LORA_MODULES = [
"gate_proj",
"down_proj",
@@ -89,25 +89,28 @@ def move_best_ckpt(trainer: Trainer):
def train(cfg: TrainConfig):
lora_cfg = LoraConfig(target_modules=LORA_MODULES)

STRONG_NAME = "meta-llama/Meta-Llama-3-8B"
strong_tokenizer = AutoTokenizer.from_pretrained(STRONG_NAME)
# for 2 strong models we do WEAK -> STRONG[0] -> STRONG[1] -> STRONG[0] -> ...
STRONG_NAMES = ["meta-llama/Meta-Llama-3-8B", "mistralai/Mistral-7B-v0.1"]
MAIN_STRONG_NAME = STRONG_NAMES[0]
strong_tokenizers = {k: AutoTokenizer.from_pretrained(k) for k in STRONG_NAMES}
weak_tokenizer = AutoTokenizer.from_pretrained(cfg.weak_name)

# Make sure that the pad token is set
if strong_tokenizer.pad_token_id is None:
strong_tokenizer.pad_token = strong_tokenizer.eos_token
for t in strong_tokenizers.values():
if t.pad_token_id is None:
t.pad_token = t.eos_token
if weak_tokenizer.pad_token_id is None:
weak_tokenizer.pad_token = weak_tokenizer.eos_token

splits = load_and_process_dataset(
cfg.dataset, split_sizes=dict(train=20_000, test=1_000)
)

def init_strong_model():
def init_strong_model(name):
model = AutoModelForSequenceClassification.from_pretrained(
STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
name, torch_dtype="auto", device_map={"": "cuda"}
)
model.config.pad_token_id = strong_tokenizer.pad_token_id
model.config.pad_token_id = strong_tokenizers[name].pad_token_id # type: ignore
model.score.weight.data *= 0.01
return model

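The `score.weight` scaling above shrinks the freshly initialized classification head so the model starts out near-uniform instead of emitting large random logits. A self-contained toy illustration of the effect (the dimensions are arbitrary):

```python
import torch

# Toy stand-in for the `score` head: scaling the weights by 0.01 shrinks
# the initial logits by the same factor, so training starts near-uniform.
head = torch.nn.Linear(4096, 2, bias=False)
x = torch.randn(8, 4096)
print(head(x).abs().mean())   # order-1 logits from the default init
head.weight.data *= 0.01
print(head(x).abs().mean())   # roughly 100x smaller
```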
@@ -136,8 +139,9 @@ def compute_metrics(eval_pred):
)

root = Path("results") / cfg.dataset
training_args = TrainingArguments(
str(root / "floor"),
train_cfg = dict(
output_dir=str(root / "floor"),
num_train_epochs=2,
adam_beta2=0.95,
gradient_accumulation_steps=8 // cfg.minibatch_size,
evaluation_strategy="epoch",
@@ -153,6 +157,7 @@ def compute_metrics(eval_pred):
tf32=True, # Use Tensor Cores even for fp32 matmuls
warmup_steps=100,
weight_decay=0.01,
learning_rate=3e-5,
)

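Keeping the settings in a plain dict and constructing a fresh `TrainingArguments` per stage, instead of mutating one shared instance as the old code did, gives each stage an independently validated config. A minimal sketch of the pattern with hypothetical paths:

```python
from transformers import TrainingArguments

# Each stage builds its own TrainingArguments from the shared dict, so
# mutating `train_cfg` for a later stage cannot leak state backwards.
train_cfg = dict(output_dir="results/demo/floor", learning_rate=3e-5)
floor_args = TrainingArguments(**train_cfg)

train_cfg["output_dir"] = "results/demo/ceil"
ceil_args = TrainingArguments(**train_cfg)  # a fresh, independent instance
assert floor_args.output_dir != ceil_args.output_dir
```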
# Gather weak labels
@@ -181,9 +186,9 @@ def compute_metrics(eval_pred):
weak_path, torch_dtype="auto"
)

weak_model.config.pad_token_id = weak_tokenizer.pad_token_id
weak_model.config.pad_token_id = weak_tokenizer.pad_token_id # type: ignore
trainer = Trainer(
args=training_args,
args=TrainingArguments(**train_cfg), # type: ignore
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(weak_tokenizer),
eval_dataset=weak_test,
@@ -209,59 +214,69 @@ def compute_metrics(eval_pred):
torch.save(train_probs, label_dir / "train.pt")
torch.save(test_probs, label_dir / "test.pt")

def strong_processor(examples):
return strong_tokenizer(examples["txt"], truncation=True)
def strong_processor(examples, tokenizer):
return tokenizer(examples["txt"], truncation=True)

strong_train = (
train.map(strong_processor, batched=True)
strong_trains = {
name: train.map(
strong_processor,
batched=True,
fn_kwargs={"tokenizer": strong_tokenizers[name]},
)
.rename_column("hard_label", "labels")
.cast_column("labels", Value("int64"))
)
ceil_test = (
test.map(strong_processor, batched=True)
for name in STRONG_NAMES
}
ceil_tests = {
name: test.map(
strong_processor,
batched=True,
fn_kwargs={"tokenizer": strong_tokenizers[name]},
)
.rename_column("hard_label", "labels")
.cast_column("labels", Value("int64"))
)
for name in STRONG_NAMES
}

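`fn_kwargs` is how `Dataset.map` threads extra arguments into the processing function, which lets a single `strong_processor` serve both tokenizers. A toy runnable illustration:

```python
from datasets import Dataset

ds = Dataset.from_dict({"txt": ["a b", "c d"]})

def proc(examples, marker):
    # `marker` arrives via fn_kwargs, alongside the batched examples
    return {"tagged": [f"{marker}:{t}" for t in examples["txt"]]}

tagged = ds.map(proc, batched=True, fn_kwargs={"marker": "llama"})
print(tagged["tagged"])  # ['llama:a b', 'llama:c d']
```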
strong_ckpt = root / "ceil" / "best-ckpt"
if strong_ckpt.exists():
print(f"Strong ceiling model already exists at {strong_ckpt}")
else:
print("\n\033[32m===== Training strong ceiling model =====\033[0m")
training_args.output_dir = str(root / "ceil")
training_args.run_name = cfg.dataset + "/ceil" + cfg.run_name
train_cfg["output_dir"] = str(root / "ceil")
train_cfg["run_name"] = cfg.dataset + "/ceil" + cfg.run_name

trainer = Trainer(
args=training_args,
args=TrainingArguments(**train_cfg), # type: ignore
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=ceil_test,
model=get_peft_model(init_strong_model(), lora_cfg),
tokenizer=strong_tokenizer,
train_dataset=strong_train,
data_collator=DataCollatorWithPadding(strong_tokenizers[MAIN_STRONG_NAME]),
eval_dataset=ceil_tests[MAIN_STRONG_NAME],
model=get_peft_model(init_strong_model(MAIN_STRONG_NAME), lora_cfg),
tokenizer=strong_tokenizers[MAIN_STRONG_NAME],
train_dataset=strong_trains[MAIN_STRONG_NAME],
)
trainer.train()
wandb.finish()
move_best_ckpt(trainer)

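`get_peft_model` wraps the base model so that only the LoRA adapters injected into the listed projection modules receive gradients. A sketch of the call (assumes access to the gated Llama 3 checkpoint; the module list here is abbreviated, not the repo's full `LORA_MODULES`):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

# Assumes access to the gated checkpoint; target list abbreviated.
base = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype="auto"
)
peft_model = get_peft_model(base, LoraConfig(target_modules=["q_proj", "v_proj"]))
peft_model.print_trainable_parameters()  # adapters are a tiny fraction of params
```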
# Weak to strong generalization
w2s_train = strong_train.remove_columns("labels")
w2s_train = strong_trains[MAIN_STRONG_NAME].remove_columns("labels")
w2s_train = w2s_train.add_column("labels", train_probs.numpy())

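Swapping the gold `labels` column for the weak model's probabilities is the core weak-to-strong move. A toy version of the column swap with made-up predictions:

```python
import torch
from datasets import Dataset

ds = Dataset.from_dict({"txt": ["x", "y"], "labels": [0, 1]})
weak_probs = torch.tensor([0.2, 0.8])  # made-up weak predictions
ds = ds.remove_columns("labels").add_column("labels", weak_probs.numpy())
print(ds["labels"])  # soft targets replace the hard ones
```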
# Check gt metrics every 100 steps during w2s training.
# We can overfit to the weak labels before a single epoch.
training_args.evaluation_strategy = "steps"
training_args.eval_steps = 100
training_args.save_steps = 100

training_args.label_names = ["labels"]
training_args.output_dir = str(root / "w2s") + cfg.run_name
training_args.run_name = cfg.dataset + "/w2s" + cfg.run_name
train_cfg["evaluation_strategy"] = "steps"
train_cfg["save_strategy"] = "steps"
train_cfg["eval_steps"] = 100
train_cfg["save_steps"] = 100
train_cfg["label_names"] = ["labels"]
train_cfg["output_dir"] = str(root / ("w2s" + cfg.run_name))
train_cfg["run_name"] = cfg.dataset + "/w2s" + cfg.run_name

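Adding `save_strategy` alongside `evaluation_strategy` is not cosmetic: Transformers requires the two strategies to match whenever `load_best_model_at_end` is enabled, which the best-checkpoint handling above suggests the elided part of the config does. A hedged sketch; `auroc` is a stand-in for whatever `metric_for_best_model` the repo actually uses:

```python
from transformers import TrainingArguments

# Hedged sketch: when load_best_model_at_end=True, evaluation and save
# strategies (and step counts) have to line up or Transformers raises.
args = TrainingArguments(
    output_dir="results/demo/w2s",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=100,
    save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="auroc",  # hypothetical metric name
)
```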
should_train = True
w2s_ckpt = root / "ceil" / "best-ckpt"
w2s_ckpt = root / ("w2s" + cfg.run_name) / "best-ckpt"
if w2s_ckpt.exists():
print(f"W2S model already exists at {strong_ckpt}")
print(f"W2S model already exists at {w2s_ckpt}")

w2s_model = AutoPeftModelForSequenceClassification.from_pretrained(
w2s_ckpt, torch_dtype="auto", device_map={"": "cuda"}
@@ -270,15 +285,15 @@ def strong_processor(examples):
else:
print("\n\033[32m===== Training w2s model =====\033[0m")

strong_model = init_strong_model()
strong_model = init_strong_model(MAIN_STRONG_NAME)
if cfg.contamination > 0.0:
acts_path = root / "ceil/acts.pt"
if acts_path.exists():
print(f"Loading strong activations from {acts_path}")
train_acts = torch.load(acts_path, map_location=strong_model.device)
else:
print("Gathering strong activations")
train_acts = gather_hiddens(strong_model, strong_train)
train_acts = gather_hiddens(strong_model, w2s_train)
torch.save(train_acts, acts_path)

y = train_probs.to(train_acts.device)
@@ -287,14 +302,14 @@ def strong_processor(examples):

w2s_model = get_peft_model(strong_model, lora_cfg)

w2s_model.config.pad_token_id = strong_tokenizer.pad_token_id
w2s_model.config.pad_token_id = strong_tokenizers[MAIN_STRONG_NAME].pad_token_id # type: ignore # noqa
trainer = DistillationTrainer(
args=training_args,
args=TrainingArguments(**train_cfg), # type: ignore
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=ceil_test,
data_collator=DataCollatorWithPadding(strong_tokenizers[MAIN_STRONG_NAME]),
eval_dataset=ceil_tests[MAIN_STRONG_NAME],
model=w2s_model,
tokenizer=strong_tokenizer,
tokenizer=strong_tokenizers[MAIN_STRONG_NAME],
train_dataset=w2s_train,
)
if should_train:
@@ -304,13 +319,15 @@ def strong_processor(examples):

# Save memory
del w2s_model
preds_path = root / "ceil/preds.pt"

preds_path = root / ("w2s" + cfg.run_name) / "preds.pt"

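The loop below alternates between the two checkpoints, so each round trains one model on the other's predictions. Tracing the index arithmetic makes the schedule concrete:

```python
STRONG_NAMES = ["meta-llama/Meta-Llama-3-8B", "mistralai/Mistral-7B-v0.1"]

# i = 0 starts at STRONG_NAMES[1], since STRONG_NAMES[0] was already
# trained during the weak-to-strong stage.
for i in range(4):
    print(f"iter {i + 1}: {STRONG_NAMES[(i + 1) % len(STRONG_NAMES)]}")
# iter 1: mistralai/Mistral-7B-v0.1
# iter 2: meta-llama/Meta-Llama-3-8B
# iter 3: mistralai/Mistral-7B-v0.1
# iter 4: meta-llama/Meta-Llama-3-8B
```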
# Strong to strong generalization
for i in range(cfg.s2s_iter):
print(f"\n\033[32m===== Self-distillation iter {i + 1} =====\033[0m")
strong_name = STRONG_NAMES[(i + 1) % len(STRONG_NAMES)]
print(f"\n\033[32m===== S2S-distillation {i + 1} ({strong_name}) =====\033[0m")

# Gather strong activations
# Gather strong labels
if preds_path.exists():
print(f"Loading strong preds from {preds_path}")
train_probs = torch.load(preds_path)
@@ -322,21 +339,23 @@ def strong_processor(examples):

del trainer

w2s_train = w2s_train.remove_columns("labels").add_column(
"labels", train_probs.numpy()
w2s_train = (
strong_trains[strong_name]
.remove_columns("labels")
.add_column("labels", train_probs.numpy())
)

name = f"s2s_iter{i + 1}" + cfg.run_name
training_args.output_dir = str(root / name)
training_args.run_name = cfg.dataset + "/" + name
train_cfg["output_dir"] = str(root / name)
train_cfg["run_name"] = cfg.dataset + "/" + name

trainer = DistillationTrainer(
args=training_args,
args=TrainingArguments(**train_cfg), # type: ignore
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(strong_tokenizer),
eval_dataset=ceil_test,
model=get_peft_model(init_strong_model(), lora_cfg),
tokenizer=strong_tokenizer,
data_collator=DataCollatorWithPadding(strong_tokenizers[strong_name]),
eval_dataset=ceil_tests[strong_name],
model=get_peft_model(init_strong_model(strong_name), lora_cfg),
tokenizer=strong_tokenizers[strong_name],
train_dataset=w2s_train,
)
trainer.train()
36 changes: 23 additions & 13 deletions w2s/ds_registry.py
@@ -189,19 +189,27 @@ def hf_loader(*hf_name, split_names=None, n_test=None):
If `n_test` is provided, it will concatenate all splits together
and then take a deterministic test set of size `n_test` from it.
"""
if n_test is not None:
assert split_names is None
ds = hf_load_dataset(*hf_name)
if isinstance(ds, HfDatasetDict):
ds = concatenate_datasets(ds.values()) # type: ignore
assert isinstance(ds, HfDataset)
splits = ds.train_test_split(test_size=n_test, seed=0)
return lambda split: splits[split]

if split_names is None:
split_names = dict()
# this thunk avoids loading datasets at import time
def thunk(split):
nonlocal split_names
if n_test is not None:
assert split_names is None
ds = hf_load_dataset(*hf_name)
if isinstance(ds, HfDatasetDict):
ds = concatenate_datasets(ds.values()) # type: ignore
assert isinstance(ds, HfDataset)
# the seed is fixed so that all runs use the same test pool
splits = ds.train_test_split(test_size=n_test, seed=0)

return lambda split: hf_load_dataset(*hf_name, split=split_names.get(split, split))
return splits[split]

if split_names is None:
split_names = dict()

return hf_load_dataset(*hf_name, split=split_names.get(split, split))

return thunk

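The rewrite moves every `hf_load_dataset` call inside the returned closure, so importing the registry no longer downloads anything. The same pattern in miniature:

```python
def make_loader(name):
    # nothing is fetched here; the expensive work waits for the first call
    def thunk(split):
        print(f"loading {name}/{split} on demand")
        return f"<{name}/{split}>"
    return thunk

loader = make_loader("dream")  # cheap: just builds the closure
loader("train")                # the load happens here, not at import
```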

##########
@@ -253,7 +261,8 @@ def format_dream(ex, rng):

ans = rng.choice(distractors)

txt = f"{'\n'.join(ex['dialogue'])}\n\nQ: {ex['question']} A: {ans}"
joined = "\n".join(ex["dialogue"])
txt = f"{joined}\n\nQ: {ex['question']} A: {ans}"
return dict(txt=txt, hard_label=hard_label)

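The join-then-format change here (and the matching one in `format_openbookqa` below) exists because interpreters before Python 3.12 reject backslashes inside f-string replacement fields:

```python
dialogue = ["A: hi", "B: hello"]
# txt = f"{'\n'.join(dialogue)}..."  # SyntaxError before Python 3.12
joined = "\n".join(dialogue)
txt = f"{joined}\n\nQ: who speaks first? A: A"
```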

@@ -413,7 +422,8 @@ def format_openbookqa(ex, rng):
choices = [
f"{a}) {t}" for a, t in zip(ex["choices"]["label"], ex["choices"]["text"])
]
txt = f"Q: {ex['question_stem']}\n\nChoices:\n{'\n'.join(choices)}\n\nAnswer: {ans}"
joined = "\n".join(choices)
txt = f"Q: {ex['question_stem']}\n\nChoices:\n{joined}\n\nAnswer: {ans}"
return dict(txt=txt, hard_label=hard_label)


6 changes: 3 additions & 3 deletions w2s/loss.py
@@ -4,14 +4,14 @@
def log_confidence_loss(
logits,
labels,
step_frac: float,
warmup_frac: float = 0.1,
step: int,
warmup_steps: int = 200,
aux_coef: float = 0.5,
):
logits = logits.float()
labels = labels.float()

coef = aux_coef * min(1.0, step_frac / warmup_frac)
coef = aux_coef * min(1.0, step / warmup_steps)
preds = torch.softmax(logits, dim=-1)

threshold = torch.quantile(preds[:, 0], labels.mean())
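The diff truncates `log_confidence_loss` just after the `threshold` line. For context, a hedged reconstruction of the complete function: the visible lines are copied from the diff, while everything after the threshold follows the standard weak-to-strong log-confidence objective and is an assumption, not the repo's verbatim code:

```python
import torch

def log_confidence_loss(logits, labels, step: int,
                        warmup_steps: int = 200, aux_coef: float = 0.5):
    logits = logits.float()
    labels = labels.float()

    # ramp the auxiliary self-confidence term in over `warmup_steps`
    coef = aux_coef * min(1.0, step / warmup_steps)
    preds = torch.softmax(logits, dim=-1)

    # visible in the diff: threshold preds so the hardened predictions
    # keep the batch's label balance
    threshold = torch.quantile(preds[:, 0], labels.mean())
    # assumed completion from here down
    strong_preds = torch.cat(
        [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
        dim=1,
    ).float()
    soft_labels = torch.stack([1.0 - labels, labels], dim=1)
    target = soft_labels * (1 - coef) + strong_preds.detach() * coef
    return torch.nn.functional.cross_entropy(logits, target)
```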
