
Commit

finetuning base model with weak plus hardened w2s preds
AlexTMallen committed May 21, 2024
1 parent e57e4f4 commit b36455e
Showing 10 changed files with 273 additions and 417 deletions.
155 changes: 141 additions & 14 deletions run.py
@@ -1,5 +1,7 @@
 from pathlib import Path
 
+import torch
+from datasets import DatasetDict, load_from_disk
 from simple_parsing import parse
 from transformers import (
     TrainingArguments,
@@ -25,21 +27,18 @@ def run_train(cfg: SFTConfig):
 
     root = Path(cfg.results_folder) / cfg.run_name
     cfg_name = get_config_foldername(vars(cfg))
-    model_last = cfg.model_name.split("/")[-1]
-    train_args = TrainingArguments(
-        output_dir=str(root / cfg_name),
+    train_args: dict = dict(
         num_train_epochs=cfg.n_epochs,
         adam_beta2=0.95,
         gradient_accumulation_steps=cfg.batch_size // cfg.minibatch_size,
-        evaluation_strategy="steps",
+        eval_strategy="steps",
         label_names=["labels"],
-        load_best_model_at_end=True,
+        load_best_model_at_end=cfg.load_best_model_at_end,
         logging_steps=25,
-        metric_for_best_model="eval_loss",
-        greater_is_better=False,
+        metric_for_best_model=cfg.metric_for_best_model,
+        greater_is_better=cfg.greater_is_better,
         per_device_train_batch_size=cfg.minibatch_size,
         per_device_eval_batch_size=cfg.minibatch_size,
-        run_name=f"{cfg.run_name}-{cfg.dataset}-{model_last}",
         save_strategy="steps",
         save_total_limit=cfg.save_total_limit,
         tf32=True,  # Use Tensor Cores even for fp32 matmuls
@@ -51,15 +50,143 @@ def run_train(cfg: SFTConfig):
         save_steps=cfg.save_every,
     )
 
-    model_cfg = ModelConfig(name=cfg.model_name, enable_lora=not cfg.disable_lora)
+    def get_model_and_run_name(model_name, current_name):
+        model_last = model_name.split("/")[-1]
+        model_cfg = ModelConfig(name=model_name, enable_lora=not cfg.disable_lora)
+        run_name = f"{current_name}-{cfg.run_name}-{cfg.dataset}-{model_last}"
+        return model_cfg, run_name
+
+    # train weak floor, get predictions
+    print("\n\033[32m===== Training weak model =====\033[0m")
+    model_cfg, run_name = get_model_and_run_name(cfg.weak_model_name, "weak")
+    train_args["run_name"] = run_name
+    train_args["output_dir"] = str(root / cfg_name / "weak")
+    weak_ds_dict = DatasetDict(
+        {
+            "train": splits["train"],
+            "val": splits["val"],
+            "test": splits["test"],
+        }
+    )
+    weak_predict_dict = {"train": splits["train"], "val": splits["val"]}
+    train(
+        weak_ds_dict,
+        model_cfg,
+        TrainingArguments(**train_args),
+        cfg.to_dict(),
+        predict_dict=weak_predict_dict,
+    )
+
+    # train strong ceil
+    print("\n\033[32m===== Training strong model =====\033[0m")
+    model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "strong")
+    train_args["run_name"] = run_name
+    train_args["output_dir"] = str(root / cfg_name / "strong")
+    strong_ds_dict = DatasetDict(
+        {
+            "train": splits["train"],
+            "val": splits["val"],
+            "test": splits["test"],
+        }
+    )
+    train(strong_ds_dict, model_cfg, TrainingArguments(**train_args), cfg.to_dict())
+
+    # load weak predictions
+    weak_preds_root = root / cfg_name / "weak" / "predictions"
+    weak_train_preds_ds = load_from_disk(str(weak_preds_root / "train"))
+    weak_val_preds_ds = load_from_disk(str(weak_preds_root / "val"))
+
+    # train w2s with logconf, get predictions
+    print("\n\033[32m===== Training w2s model =====\033[0m")
+    model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "w2s")
+    train_args["run_name"] = run_name
+    train_args["output_dir"] = str(root / cfg_name / "w2s")
+    w2s_ds_dict = DatasetDict(
+        {
+            "train": (
+                splits["train"]
+                .remove_columns("labels")
+                .add_column("labels", weak_train_preds_ds["soft_pred"])  # type: ignore
+            ),
+            "val": (
+                splits["val"]
+                .remove_columns("labels")
+                .add_column("labels", weak_val_preds_ds["soft_pred"])
+            ),  # type: ignore
+            "test": splits["test"],
+        }
+    )
+    # assert (weak_train_preds_ds["id"] == w2s_ds_dict["train"]["id"])
+    # assert (weak_val_preds_ds["id"] == w2s_ds_dict["val"]["id"])
+    w2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
+    train(
+        w2s_ds_dict,
+        model_cfg,
+        TrainingArguments(**train_args),
+        cfg.to_dict(),
+        predict_dict=w2s_predict_dict,
+        logconf_weight=cfg.logconf_weight,
+        logconf_warmup_steps=cfg.logconf_warmup_steps,
+    )
+
+    # load w2s predictions, and balanced-harden them
+    print("\n\033[32m===== Training (s+w)2s model =====\033[0m")
+    w2s_preds_root = root / cfg_name / "w2s" / "predictions"
+    w2s_train_preds_ds = load_from_disk(str(w2s_preds_root / "train")).with_format(
+        type="torch", columns=["soft_pred"]
+    )
+    w2s_val_preds_ds = load_from_disk(str(w2s_preds_root / "val")).with_format(
+        type="torch", columns=["soft_pred"]
+    )
+    prior = torch.tensor(splits["train"]["labels"]).float().mean()
+    thresh = torch.quantile(w2s_train_preds_ds["soft_pred"], 1 - prior)  # type: ignore
+    # set the label column of train to be (1 - a) * weak + a * hard_w2s
+    sw2s_train_labels = (
+        (
+            (1 - cfg.strong_weight) * torch.tensor(weak_train_preds_ds["soft_pred"])  # type: ignore
+            + cfg.strong_weight * (w2s_train_preds_ds["soft_pred"] > thresh).float()
+        )
+        .float()
+        .tolist()
+    )
+    sw2s_val_labels = (
+        (
+            (1 - cfg.strong_weight) * torch.tensor(weak_val_preds_ds["soft_pred"])  # type: ignore
+            + cfg.strong_weight * (w2s_val_preds_ds["soft_pred"] > thresh).float()
+        )
+        .float()
+        .tolist()
+    )
+
+    # train sw2s on train with logconf
+    model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "sw2s")
+    train_args["run_name"] = run_name
+    train_args["output_dir"] = str(root / cfg_name / "sw2s")
+    sw2s_ds_dict = DatasetDict(
+        {
+            "train": (
+                splits["train"]
+                .remove_columns("labels")
+                .add_column("labels", sw2s_train_labels)  # type: ignore
+            ),
+            "val": (
+                splits["val"]
+                .remove_columns("labels")
+                .add_column("labels", sw2s_val_labels)  # type: ignore
+            ),
+            "test": splits["test"],
+        }
+    )
+    # assert (w2s_train_preds_ds["id"] == sw2s_ds_dict["train"]["id"])
+    # assert (w2s_val_preds_ds["id"] == sw2s_ds_dict["val"]["id"])
+
     train(
-        splits,
+        sw2s_ds_dict,
         model_cfg,
-        train_args,
-        cfg.loss,
-        cfg.store_pre_hiddens,
-        cfg.store_post_hiddens,
+        TrainingArguments(**train_args),
+        cfg.to_dict(),
+        logconf_weight=cfg.logconf_weight,
+        logconf_warmup_steps=0,
     )
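Reader's note, not part of the commit: the (s+w)2s targets above are built by hardening the w2s model's soft predictions at a prior-matched quantile threshold and mixing them with the weak model's soft predictions. A minimal standalone sketch of that step, with illustrative tensors (`strong_weight` stands in for `cfg.strong_weight` in the diff):

# Standalone sketch of the label-mixing step; values are illustrative.
import torch

weak_soft = torch.tensor([0.2, 0.7, 0.4, 0.9])    # weak model soft predictions
w2s_soft = torch.tensor([0.1, 0.8, 0.6, 0.95])    # w2s model soft predictions
prior = torch.tensor([0.0, 1.0, 0.0, 1.0]).mean()  # fraction of positive ground-truth labels

# Harden the w2s predictions so that roughly a `prior` fraction of examples is labeled positive.
thresh = torch.quantile(w2s_soft, 1 - prior)
hard_w2s = (w2s_soft > thresh).float()

# Convex combination (1 - a) * weak + a * hard_w2s, with a = strong_weight.
strong_weight = 0.5
mixed_labels = (1 - strong_weight) * weak_soft + strong_weight * hard_w2s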


57 changes: 57 additions & 0 deletions run_eight.py
@@ -0,0 +1,57 @@
+import subprocess
+from multiprocessing import Process
+
+# Define the datasets and respective GPU ids
+datasets = [
+    "boolq",
+    "anli-r2",
+    "cosmos_qa",
+    "mc_taco",
+    "sciq",
+    "paws",
+    "twitter-sentiment",
+    "wic",
+]
+
+gpu_ids = range(len(datasets))
+
+# Define the base command
+base_command = (
+    "CUDA_VISIBLE_DEVICES={gpu_id} "
+    "python run.py "
+    "--dataset {dataset} "
+    "--weak_model_name Qwen/Qwen1.5-0.5B "
+    "--strong_model_name meta-llama/Meta-Llama-3-8B "
+    "--n_epochs 2 "
+    "--n_train 20_000 "
+    "--n_val 500 "
+    "--n_test 1000 "
+    "--n_predict 0 "
+    "--eval_every 25 "
+    "--save_every 25 "
+    "--logconf_warmup_steps 200 "
+    "--logconf_weight 0.5 "
+    "--strong_weight 0.5 "
+    "--minibatch_size 4 "
+    '--run_name "logconf_no_warmup" '
+)
+
+
+def run_command(command):
+    subprocess.run(command, shell=True, check=True)
+
+
+# List to hold processes
+processes = []
+
+# Loop over datasets and gpu_ids
+for dataset, gpu_id in zip(datasets, gpu_ids):
+    command = base_command.format(gpu_id=gpu_id, dataset=dataset)
+    print(f"Running command: {command}")  # Debug print
+    p = Process(target=run_command, args=(command,))
+    p.start()
+    processes.append(p)
+
+# Wait for all processes to complete
+for p in processes:
+    p.join()
60 changes: 0 additions & 60 deletions train_transformer_reporter.py

This file was deleted.

2 changes: 1 addition & 1 deletion underspec/loss.py
@@ -11,7 +11,7 @@ def log_confidence_loss(
     logits = logits.float()
     labels = labels.float()
 
-    coef = aux_coef * min(1.0, step / warmup_steps)
+    coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
     preds = torch.softmax(logits, dim=-1)
 
     threshold = torch.quantile(preds[:, 0], labels.mean())
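Reader's note, not part of the commit: the one-line change above guards the log-confidence warmup coefficient against warmup_steps == 0, which would otherwise divide by zero. A small sketch of the resulting schedule, with illustrative values:

# Standalone sketch of the warmup coefficient; values are illustrative.
def logconf_coef(step: int, warmup_steps: int, aux_coef: float = 0.5) -> float:
    # Linearly ramp the auxiliary log-confidence weight over `warmup_steps`,
    # falling back to the full weight when no warmup is requested.
    return aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef

assert logconf_coef(step=100, warmup_steps=200) == 0.25  # halfway through warmup
assert logconf_coef(step=500, warmup_steps=200) == 0.5   # warmup finished
assert logconf_coef(step=0, warmup_steps=0) == 0.5       # no warmup: no division by zero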
53 changes: 1 addition & 52 deletions underspec/model.py
@@ -1,9 +1,7 @@
 import warnings
-from abc import ABC
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Optional
 
-import torch
 from peft import LoraConfig, get_peft_model
 from transformers import (
     AutoModelForSequenceClassification,
@@ -65,52 +63,3 @@ def init_model_and_tokenizer(cfg: ModelConfig):
         model = get_peft_model(model, lora_cfg)
 
     return model, tokenizer
-
-
-class Predictor(ABC):
-    """
-    The strong "predictor", using the terminology of the original ELK report
-    https://docs.google.com/document/d/1WwsnJQstPq91_Yh-Ch2XRL8H_EpsnjrC1dwZXR37PC8/edit#heading=h.kkaua0hwmp1d
-    this is the model we would like to elicit latent knowledge from using the reporter
-    """
-
-    def __call__(
-        self, inputs, output_hidden_states=False
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor]]]:
-        """
-        This takes in a batch of inputs and returns the logodds of the model's predictions.
-        If output_hidden_states is True, it also returns the hidden states of the model (second)
-        Each of the `num_layers` hiddens are tensors of shape [n, hidden_size]
-        """
-        ...
-
-    def get_cfg_summary(self) -> dict[str, str | int | float]:
-        """A summary of the method that approximately uniquely identifies it.
-        It should include a name and all the important hyperparameters."""
-        ...
-
-
-class TransformerPredictor(Predictor):
-    def __init__(self, cfg: ModelConfig):
-        self.model, self.tokenizer = init_model_and_tokenizer(cfg)
-
-    def __call__(self, inputs, output_hidden_states=False):
-        # inputs are text strings
-        assert isinstance(inputs, list)
-        # ...ModelForSequenceClassification makes sure to score hiddens
-        # from the last non-padding token position
-        input_ids = self.model.tokenizer(
-            inputs, padding=True, truncation=True, return_tensors="pt"
-        )["input_ids"]
-
-        outputs = self.model(input_ids, output_hidden_states=True)
-
-        hiddens = torch.stack(
-            outputs.hidden_states
-        )  # [num_layers, n, seq_len, hidden_size]
-        seq_lens = input_ids.ne(self.tokenizer.pad_token_id).sum(dim=-1)
-        last_non_pad_idx = seq_lens - 1
-        last_hidden_states = hiddens[:, torch.arange(len(inputs)), last_non_pad_idx, :]
-
-        logits = outputs.logits
-        return logits, last_hidden_states.unbind(0) if output_hidden_states else logits