window loss and cfg refactor
AdamScherlis committed May 23, 2024
1 parent 97329e0 commit f423fa0
Showing 6 changed files with 80 additions and 97 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,5 +1,6 @@
-results/
+results*/
+wandb/
 plots/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
71 changes: 3 additions & 68 deletions run.py
@@ -74,8 +74,8 @@ def get_model_and_run_name(model_name, current_name):
     model_cfg,
     TrainingArguments(**train_args),
     cfg.to_dict(),
+    transfer=False,
     predict_dict=weak_predict_dict,
-    balance_batch=cfg.balance_batch,
 )
 
 # train strong ceil
@@ -96,7 +96,7 @@ def get_model_and_run_name(model_name, current_name):
     model_cfg,
     TrainingArguments(**train_args),
     cfg.to_dict(),
-    balance_batch=cfg.balance_batch,
+    transfer=False,
 )
 
 # load weak predictions
@@ -133,72 +133,7 @@ def get_model_and_run_name(model_name, current_name):
     model_cfg,
     TrainingArguments(**train_args),
     cfg.to_dict(),
     predict_dict=w2s_predict_dict,
-    logconf_weight=cfg.logconf_weight,
-    logconf_warmup_steps=cfg.logconf_warmup_steps,
-    balance_batch=cfg.balance_batch,
-)
-
-# load w2s predictions, and balanced-harden them
-print("\n\033[32m===== Training (s+w)2s model =====\033[0m")
-w2s_preds_root = root / cfg_name / "w2s" / "predictions"
-w2s_train_preds_ds = load_from_disk(str(w2s_preds_root / "train")).with_format(
-    type="torch", columns=["soft_pred"]
-)
-w2s_val_preds_ds = load_from_disk(str(w2s_preds_root / "val")).with_format(
-    type="torch", columns=["soft_pred"]
-)
-prior = torch.tensor(splits["train"]["labels"]).float().mean()
-thresh = torch.quantile(w2s_train_preds_ds["soft_pred"], 1 - prior)  # type: ignore
-# set the label column of train to be (1 - a) * weak + a * hard_w2s
-sw2s_train_labels = (
-    (
-        (1 - cfg.strong_weight) * torch.tensor(weak_train_preds_ds["soft_pred"])  # type: ignore
-        + cfg.strong_weight * (w2s_train_preds_ds["soft_pred"] > thresh).float()
-    )
-    .float()
-    .tolist()
-)
-sw2s_val_labels = (
-    (
-        (1 - cfg.strong_weight) * torch.tensor(weak_val_preds_ds["soft_pred"])  # type: ignore
-        + cfg.strong_weight * (w2s_val_preds_ds["soft_pred"] > thresh).float()
-    )
-    .float()
-    .tolist()
-)
-
-# train sw2s
-model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "sw2s")
-train_args["run_name"] = run_name
-train_args["output_dir"] = str(root / cfg_name / "sw2s")
-train_args["learning_rate"] = cfg.strong_lr
-sw2s_ds_dict = DatasetDict(
-    {
-        "train": (
-            splits["train"]
-            .remove_columns("labels")
-            .add_column("labels", sw2s_train_labels)  # type: ignore
-        ),
-        "val": (
-            splits["val"]
-            .remove_columns("labels")
-            .add_column("labels", sw2s_val_labels)  # type: ignore
-        ),
-        "test": splits["test"],
-    }
-)
-# assert (w2s_train_preds_ds["id"] == sw2s_ds_dict["train"]["id"])
-# assert (w2s_val_preds_ds["id"] == sw2s_ds_dict["val"]["id"])
-
-train(
-    sw2s_ds_dict,
-    model_cfg,
-    TrainingArguments(**train_args),
-    cfg.to_dict(),
-    logconf_weight=cfg.logconf_weight,
-    logconf_warmup_steps=cfg.logconf_warmup_steps,
-    balance_batch=cfg.balance_batch,
+    transfer=True,
 )
 
4 changes: 2 additions & 2 deletions run_eight.py
@@ -30,14 +30,14 @@
     "--eval_every 100 "
     "--save_every 100 "
     "--save_total_limit 1 "
+    "--loss logconf "
     "--logconf_warmup_steps 80 "
     "--balance_batch "
     "--logconf_weight 0.5 "
     "--strong_weight 0.5 "
     "--minibatch_size {minibatch_size} "
     "--weak_lr 5e-4 "
     "--strong_lr 8e-5 "
-    '--run_name "stable_balanced_batch" '
+    '--run_name "basic_w2s" '
 )


20 changes: 20 additions & 0 deletions w2s/loss.py
@@ -1,6 +1,26 @@
 import torch
 
 
+def confidence_window_loss(
+    logits,
+    labels,
+    radius: float = 0.15,
+):
+    """
+    Use cross-entropy loss only for the examples where the model is uncertain.
+    """
+    logits = logits.float()
+    labels = labels.float()
+
+    preds = torch.softmax(logits, dim=-1)
+
+    uncertain = (preds.max(dim=-1).values < 0.5 + radius)
+
+    target = torch.stack([1.0 - labels, labels], dim=1)
+
+    return torch.nn.functional.cross_entropy(logits[uncertain], target[uncertain])
+
+
 def log_confidence_loss(
     logits,
     labels,
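For intuition, here is a minimal usage sketch (not part of the commit) of the new confidence_window_loss with its default radius of 0.15. With two classes the max softmax probability is always at least 0.5, so only examples whose max probability falls below 0.65 contribute to the loss; the tensors below are illustrative.

import torch

from w2s.loss import confidence_window_loss

# Two binary examples: the first is near the decision boundary
# (max prob ~0.52), the second is very confident (max prob ~1.0).
logits = torch.tensor([[0.1, 0.2], [4.0, -4.0]])
labels = torch.tensor([1.0, 0.0])  # soft labels in [0, 1]

# Only the first example falls inside the confidence window, so the
# result is the cross-entropy of that single example; the confident
# one is masked out entirely.
loss = confidence_window_loss(logits, labels, radius=0.15)
print(loss.item())

One caveat worth noting: if every example in a batch is confident, the masked tensors are empty and the mean-reduced cross-entropy evaluates to NaN, so callers would need to guard against that case.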
50 changes: 29 additions & 21 deletions w2s/sft.py
@@ -11,43 +11,53 @@
 )
 
 import wandb
-from w2s.loss import log_confidence_loss
+from w2s.loss import log_confidence_loss, confidence_window_loss
 from w2s.model import ModelConfig, init_model_and_tokenizer
 from w2s.roc_auc import roc_auc
 from w2s.sft_utils import (
     clear_mem,
     get_gpu_mem_used,
     move_best_ckpt,
 )
+from w2s.sft_config import LossConfig
 
 
 class CustomLossTrainer(Trainer):
     def __init__(
         self,
-        logconf_weight: float,
-        logconf_warmup_steps: int,
-        balance_batch: bool,
+        loss_name: str,
+        loss_cfg: LossConfig,
+        transfer: bool,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        self.logconf_weight = logconf_weight
-        self.logconf_warmup_steps = logconf_warmup_steps
-        self.balance_batch = balance_batch
+        self.loss_name = loss_name
+        self.loss_cfg = loss_cfg
+        self.transfer = transfer
 
     def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs.pop("labels").float()
 
         outputs = model(**inputs)
 
-        loss = log_confidence_loss(
-            outputs.logits,
-            labels,
-            self.state.global_step,
-            aux_coef=self.logconf_weight,
-            warmup_steps=self.logconf_warmup_steps,
-            balance_batch=self.balance_batch,
-        )
+        if self.loss_name == 'logconf':
+            loss = log_confidence_loss(
+                outputs.logits,
+                labels,
+                self.state.global_step,
+                aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.),
+                warmup_steps=self.loss_cfg.logconf_warmup_steps,
+                balance_batch=self.loss_cfg.balance_batch,
+            )
+        elif self.loss_name == 'window':
+            loss = confidence_window_loss(
+                outputs.logits,
+                labels,
+                radius=self.loss_cfg.radius,
+            )
+        else:
+            raise ValueError(f"Unknown loss function: {self.loss_name}")
 
         return (loss, outputs) if return_outputs else loss
@@ -57,9 +67,7 @@ def train(
     model_cfg: ModelConfig,
     train_args: TrainingArguments,
     cfg: dict,
-    logconf_weight: float = 0.0,
-    logconf_warmup_steps: int = 200,
-    balance_batch: bool = False,
+    transfer: bool,
     predict_dict: Union[DatasetDict, dict, None] = None,
 ):
     """
@@ -105,9 +113,9 @@ def compute_metrics(eval_pred):
     )
 
     trainer = CustomLossTrainer(
-        logconf_weight=logconf_weight,
-        logconf_warmup_steps=logconf_warmup_steps,
-        balance_batch=balance_batch,
+        loss_name=cfg.loss_name,
+        loss_cfg=cfg.loss,
+        transfer=transfer,
         args=train_args,
         compute_metrics=compute_metrics,
         data_collator=DataCollatorWithPadding(tokenizer),
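To make the new wiring concrete, here is a hypothetical construction of the refactored trainer; the model, training arguments, and datasets are placeholders standing in for whatever run.py actually builds, and only the class and config names come from this diff.

from transformers import TrainingArguments

from w2s.sft import CustomLossTrainer
from w2s.sft_config import ConfidenceWindowLossConfig

trainer = CustomLossTrainer(
    loss_name="window",                               # dispatches to confidence_window_loss
    loss_cfg=ConfidenceWindowLossConfig(radius=0.15),
    transfer=True,                                    # only consulted by the 'logconf' branch
    model=model,                                      # placeholder: an initialized model
    args=TrainingArguments(output_dir="out"),         # placeholder arguments
    train_dataset=train_ds,                           # placeholder datasets
    eval_dataset=val_ds,
)

Note that transfer is accepted regardless of the loss: the 'logconf' branch uses it to zero out the auxiliary confidence term when not doing weak-to-strong transfer, while the 'window' branch ignores it.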
29 changes: 24 additions & 5 deletions w2s/sft_config.py
@@ -1,9 +1,28 @@
 from dataclasses import dataclass
 from typing import Literal, Optional, Union
 
-from simple_parsing import Serializable, field
+from simple_parsing import Serializable, field, subgroups
 
 
+@dataclass
+class LossConfig(Serializable):
+    pass
+
+@dataclass
+class LogConfidenceLossConfig(LossConfig):
+    logconf_weight: float = 0.5
+    logconf_warmup_steps: int = 200
+    balance_batch: bool = False
+
+@dataclass
+class ConfidenceWindowLossConfig(LossConfig):
+    radius: float = 0.15
+
+LOSS_CONFIGS = {
+    "logconf": LogConfidenceLossConfig,
+    "window": ConfidenceWindowLossConfig,
+}
+
 @dataclass
 class SFTConfig(Serializable):  # TODO: what is this for??
     # name of the model to train
@@ -29,17 +48,15 @@ class SFTConfig(Serializable):  # TODO: what is this for??
     eval_every: int = 100  # steps
     save_every: int = 100  # steps
     save_total_limit: Optional[int] = None
-    logconf_weight: float = 0.5
-    logconf_warmup_steps: int = 200
-    balance_batch: bool = False
     strong_weight: float = 0.5
+    loss: LossConfig = subgroups(LOSS_CONFIGS, default="logconf")
     weight_decay: float = 0.1
     weak_lr: float = 5e-4
     strong_lr: float = 8e-5
     load_best_model_at_end: bool = True
     metric_for_best_model: str = "val_auroc"
 
     greater_is_better: bool = field(init=False)
+    loss_name: str = field(init=False)
 
     def __post_init__(self):
         if "loss" in self.metric_for_best_model:
@@ -52,6 +69,8 @@ def __post_init__(self):
         else:
             raise ValueError(f"Unknown metric {self.metric_for_best_model}")
 
+        self.loss_name = {LOSS_CONFIGS[k]: k for k in LOSS_CONFIGS}[type(self.loss)]
+
     def to_dict(self):
         irrelevant_fields = ["results_folder", "run_name", "minibatch_size"]
         return {k: v for k, v in vars(self).items() if k not in irrelevant_fields}

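Since the loss is now a simple_parsing subgroup, the variant is chosen by name on the command line (run_eight.py above passes --loss logconf) and its fields are set with their own flags. A sketch of the assumed usage; depending on simple_parsing's prefixing rules, nested fields may instead need the dotted form such as --loss.radius:

# keep the default log-confidence loss but change its weight
python run.py --loss logconf --logconf_weight 0.25

# select the new confidence-window loss and widen the window
python run.py --loss window --radius 0.25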