Commit
balanced batch
AlexTMallen committed May 22, 2024
1 parent 9df4eac commit 97329e0
Showing 13 changed files with 136 additions and 53 deletions.
21 changes: 15 additions & 6 deletions run.py
@@ -7,11 +7,11 @@
TrainingArguments,
)

from underspec.ds_registry import load_and_process_dataset
from underspec.model import ModelConfig
from underspec.sft import train
from underspec.sft_config import SFTConfig
from underspec.utils import get_config_foldername
from w2s.ds_registry import load_and_process_dataset
from w2s.model import ModelConfig
from w2s.sft import train
from w2s.sft_config import SFTConfig
from w2s.utils import get_config_foldername


def run_train(cfg: SFTConfig):
@@ -75,6 +75,7 @@ def get_model_and_run_name(model_name, current_name):
TrainingArguments(**train_args),
cfg.to_dict(),
predict_dict=weak_predict_dict,
balance_batch=cfg.balance_batch,
)

# train strong ceil
@@ -90,7 +91,13 @@ def get_model_and_run_name(model_name, current_name):
"test": splits["test"],
}
)
train(strong_ds_dict, model_cfg, TrainingArguments(**train_args), cfg.to_dict())
train(
strong_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
balance_batch=cfg.balance_batch,
)

# load weak predictions
weak_preds_root = root / cfg_name / "weak" / "predictions"
@@ -129,6 +136,7 @@ def get_model_and_run_name(model_name, current_name):
predict_dict=w2s_predict_dict,
logconf_weight=cfg.logconf_weight,
logconf_warmup_steps=cfg.logconf_warmup_steps,
balance_batch=cfg.balance_batch,
)

# load w2s predictions, and balanced-harden them
@@ -190,6 +198,7 @@ def get_model_and_run_name(model_name, current_name):
cfg.to_dict(),
logconf_weight=cfg.logconf_weight,
logconf_warmup_steps=cfg.logconf_warmup_steps,
balance_batch=cfg.balance_batch,
)


5 changes: 3 additions & 2 deletions run_eight.py
@@ -4,7 +4,7 @@
# Define the datasets and respective GPU ids
configs = [
("boolq", 2),
("anli-r2", 16),
("anli-r2", 8),
("cosmos_qa", 4),
("mc_taco", 4),
("sciq", 4),
@@ -31,12 +31,13 @@
"--save_every 100 "
"--save_total_limit 1 "
"--logconf_warmup_steps 80 "
"--balance_batch "
"--logconf_weight 0.5 "
"--strong_weight 0.5 "
"--minibatch_size {minibatch_size} "
"--weak_lr 5e-4 "
"--strong_lr 8e-5 "
'--run_name "3rd_term_warmup" '
'--run_name "stable_balanced_batch" '
)


24 changes: 0 additions & 24 deletions underspec/loss.py

This file was deleted.

40 changes: 29 additions & 11 deletions viz.ipynb

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.
70 changes: 70 additions & 0 deletions w2s/loss.py
@@ -0,0 +1,70 @@
import torch


def log_confidence_loss(
logits,
labels,
step: int,
warmup_steps: int = 200,
aux_coef: float = 0.5,
balance_batch: bool = False,
):
"""
This is similar to the loss in Burns et al., except that it also optionally
balances the labels by mean-subtracting in log-odds space.
"""
logits = logits.float()
labels = labels.float()
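    # Optionally balance the batch: map the soft labels to log-odds, subtract the batch mean,
    # and squash back through a sigmoid; the prior is then 0.5 rather than the empirical label mean.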
if balance_batch:
logodds_labels = torch.log(labels + 1e-7) - torch.log(1 - labels + 1e-7)
labels = torch.sigmoid(logodds_labels - logodds_labels.mean())
prior = 0.5
else:
prior = labels.mean()

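    # Ramp the auxiliary self-confidence coefficient linearly from 0 to aux_coef over warmup_steps
    # (applied at full strength immediately if warmup_steps <= 0).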
coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
preds = torch.softmax(logits, dim=-1)

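    # Harden the model's own predictions using a threshold chosen so the fraction of examples
    # assigned to the positive class roughly matches the label prior.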
threshold = torch.quantile(preds[:, 0], prior)
strong_preds = torch.cat(
[(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
dim=1,
)
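    # Final target: interpolate between the soft labels and the model's hardened predictions,
    # weighted by the warmed-up coefficient.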
labels_binary = torch.stack([1.0 - labels, labels], dim=1)
target = labels_binary * (1 - coef) + strong_preds.detach() * coef
return torch.nn.functional.cross_entropy(logits, target)


def log_confidence_loss2(
logits,
labels,
step: int,
warmup_steps: int = 200,
aux_coef: float = 0.5,
balance_batch: bool = False,
):
"""
This one uses a batch-independent threshold of 0.5, and then finally optionally balances
the batch by mean-subtracting the log-odds of the target.
"""
logits = logits.float()
labels = labels.float()

coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
preds = torch.softmax(logits, dim=-1)

threshold = 0.5
strong_preds = torch.cat(
[(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
dim=1,
)
labels_binary = torch.stack([1.0 - labels, labels], dim=1)
target = labels_binary * (1 - coef) + strong_preds.detach() * coef

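    # Optionally balance the mixed target by centering its log-odds over the batch (per class column).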
if balance_batch:
logodds_target = torch.log(target) - torch.log1p(-target)
target = torch.sigmoid(
logodds_target - logodds_target.mean(dim=0, keepdim=True)
)

return torch.nn.functional.cross_entropy(logits, target)
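
For reference, a minimal sketch of how this loss could be exercised on dummy data (the shapes, values, and step settings below are illustrative, not taken from the repository):

import torch

from w2s.loss import log_confidence_loss

# Dummy batch: 8 examples, 2-way classification.
logits = torch.randn(8, 2, requires_grad=True)
labels = torch.rand(8)  # soft labels in [0, 1], e.g. weak-model probabilities

loss = log_confidence_loss(
    logits,
    labels,
    step=50,             # current global step
    warmup_steps=200,    # auxiliary coefficient is still ramping up here
    aux_coef=0.5,
    balance_batch=True,  # re-center the labels in log-odds space before mixing
)
loss.backward()
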
File renamed without changes.
File renamed without changes.
26 changes: 17 additions & 9 deletions underspec/sft.py → w2s/sft.py
@@ -11,10 +11,10 @@
)

import wandb
from underspec.loss import log_confidence_loss
from underspec.model import ModelConfig, init_model_and_tokenizer
from underspec.roc_auc import roc_auc
from underspec.sft_utils import (
from w2s.loss import log_confidence_loss
from w2s.model import ModelConfig, init_model_and_tokenizer
from w2s.roc_auc import roc_auc
from w2s.sft_utils import (
clear_mem,
get_gpu_mem_used,
move_best_ckpt,
@@ -23,11 +23,17 @@

class CustomLossTrainer(Trainer):
def __init__(
self, logconf_weight: float, logconf_warmup_steps: int, *args, **kwargs
self,
logconf_weight: float,
logconf_warmup_steps: int,
balance_batch: bool,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.logconf_weight = logconf_weight
self.logconf_warmup_steps = logconf_warmup_steps
self.balance_batch = balance_batch

def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels").float()
@@ -40,6 +46,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
self.state.global_step,
aux_coef=self.logconf_weight,
warmup_steps=self.logconf_warmup_steps,
balance_batch=self.balance_batch,
)

return (loss, outputs) if return_outputs else loss
@@ -52,19 +59,19 @@ def train(
cfg: dict,
logconf_weight: float = 0.0,
logconf_warmup_steps: int = 200,
balance_batch: bool = False,
predict_dict: Union[DatasetDict, dict, None] = None,
):
"""
ds_dict: DatasetDict with splits for train, val, test, and (optionally) predict,
with columns "txt" and "labels"
model_cfg: ModelConfig with the model name and whether to enable LoRA
train_args: TrainingArguments with the training hyperparameters
logconf_weight: the weight for the log confidence loss
store_pre_hiddens: whether to store the hiddens (all layers,
final token position, on train set) before training
store_post_hiddens: whether to store the hiddens after training
cfg: a dictionary containing all the relevant details for reproducibility.
This will be updated with your train_args and model_cfg before saving.
logconf_weight: the weight for the log confidence loss
logconf_warmup_steps: the number of steps to linearly increase the logconf_weight
balance_batch: whether to balance the batch with the log confidence loss
This function trains a model on ds_dict["train"], uses ds_dict["val"] for early stopping,
and evaluates on ds_dict["test"].
@@ -100,6 +107,7 @@ def compute_metrics(eval_pred):
trainer = CustomLossTrainer(
logconf_weight=logconf_weight,
logconf_warmup_steps=logconf_warmup_steps,
balance_batch=balance_batch,
args=train_args,
compute_metrics=compute_metrics,
data_collator=DataCollatorWithPadding(tokenizer),
1 change: 1 addition & 0 deletions underspec/sft_config.py → w2s/sft_config.py
@@ -31,6 +31,7 @@ class SFTConfig(Serializable): # TODO: what is this for??
save_total_limit: Optional[int] = None
logconf_weight: float = 0.5
logconf_warmup_steps: int = 200
balance_batch: bool = False
strong_weight: float = 0.5
weight_decay: float = 0.1
weak_lr: float = 5e-4
2 changes: 1 addition & 1 deletion underspec/sft_utils.py → w2s/sft_utils.py
@@ -7,7 +7,7 @@
from tqdm import tqdm
from transformers import PretrainedConfig, Trainer

from underspec.utils import assert_type
from w2s.utils import assert_type


@torch.no_grad()
File renamed without changes.
