From f423fa02438255a7db020b873a8b13768a245ad6 Mon Sep 17 00:00:00 2001
From: Adam Scherlis
Date: Thu, 23 May 2024 21:03:33 +0000
Subject: [PATCH] window loss and cfg refactor

---
 .gitignore        |  3 +-
 run.py            | 71 ++---------------------------------------------
 run_eight.py      |  4 +--
 w2s/loss.py       | 20 +++++++++++++
 w2s/sft.py        | 50 +++++++++++++++++++--------------
 w2s/sft_config.py | 29 +++++++++++++++----
 6 files changed, 80 insertions(+), 97 deletions(-)

diff --git a/.gitignore b/.gitignore
index 55ad787..656bae2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
-results/
+results*/
 wandb/
+plots/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/run.py b/run.py
index 13482ab..8b5cdb9 100644
--- a/run.py
+++ b/run.py
@@ -74,8 +74,8 @@ def get_model_and_run_name(model_name, current_name):
         model_cfg,
         TrainingArguments(**train_args),
         cfg.to_dict(),
+        transfer=False,
         predict_dict=weak_predict_dict,
-        balance_batch=cfg.balance_batch,
     )
 
     # train strong ceil
@@ -96,7 +96,7 @@ def get_model_and_run_name(model_name, current_name):
         model_cfg,
         TrainingArguments(**train_args),
         cfg.to_dict(),
-        balance_batch=cfg.balance_batch,
+        transfer=False,
     )
 
     # load weak predictions
@@ -133,72 +133,7 @@ def get_model_and_run_name(model_name, current_name):
         model_cfg,
         TrainingArguments(**train_args),
         cfg.to_dict(),
-        predict_dict=w2s_predict_dict,
-        logconf_weight=cfg.logconf_weight,
-        logconf_warmup_steps=cfg.logconf_warmup_steps,
-        balance_batch=cfg.balance_batch,
-    )
-
-    # load w2s predictions, and balanced-harden them
-    print("\n\033[32m===== Training (s+w)2s model =====\033[0m")
-    w2s_preds_root = root / cfg_name / "w2s" / "predictions"
-    w2s_train_preds_ds = load_from_disk(str(w2s_preds_root / "train")).with_format(
-        type="torch", columns=["soft_pred"]
-    )
-    w2s_val_preds_ds = load_from_disk(str(w2s_preds_root / "val")).with_format(
-        type="torch", columns=["soft_pred"]
-    )
-    prior = torch.tensor(splits["train"]["labels"]).float().mean()
-    thresh = torch.quantile(w2s_train_preds_ds["soft_pred"], 1 - prior)  # type: ignore
-    # set the label column of train to be (1 - a) * weak + a * hard_w2s
-    sw2s_train_labels = (
-        (
-            (1 - cfg.strong_weight) * torch.tensor(weak_train_preds_ds["soft_pred"])  # type: ignore
-            + cfg.strong_weight * (w2s_train_preds_ds["soft_pred"] > thresh).float()
-        )
-        .float()
-        .tolist()
-    )
-    sw2s_val_labels = (
-        (
-            (1 - cfg.strong_weight) * torch.tensor(weak_val_preds_ds["soft_pred"])  # type: ignore
-            + cfg.strong_weight * (w2s_val_preds_ds["soft_pred"] > thresh).float()
-        )
-        .float()
-        .tolist()
-    )
-
-    # train sw2s
-    model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "sw2s")
-    train_args["run_name"] = run_name
-    train_args["output_dir"] = str(root / cfg_name / "sw2s")
-    train_args["learning_rate"] = cfg.strong_lr
-    sw2s_ds_dict = DatasetDict(
-        {
-            "train": (
-                splits["train"]
-                .remove_columns("labels")
-                .add_column("labels", sw2s_train_labels)  # type: ignore
-            ),
-            "val": (
-                splits["val"]
-                .remove_columns("labels")
-                .add_column("labels", sw2s_val_labels)  # type: ignore
-            ),
-            "test": splits["test"],
-        }
-    )
-    # assert (w2s_train_preds_ds["id"] == sw2s_ds_dict["train"]["id"])
-    # assert (w2s_val_preds_ds["id"] == sw2s_ds_dict["val"]["id"])
-
-    train(
-        sw2s_ds_dict,
-        model_cfg,
-        TrainingArguments(**train_args),
-        cfg.to_dict(),
-        logconf_weight=cfg.logconf_weight,
-        logconf_warmup_steps=cfg.logconf_warmup_steps,
-        balance_batch=cfg.balance_batch,
+        transfer=True,
     )
 
diff --git a/run_eight.py b/run_eight.py
index 932de20..cc7fd25 100644
--- a/run_eight.py
+++ b/run_eight.py
@@ -30,14 +30,14 @@
     "--eval_every 100 "
     "--save_every 100 "
     "--save_total_limit 1 "
+    "--loss logconf "
     "--logconf_warmup_steps 80 "
     "--balance_batch "
     "--logconf_weight 0.5 "
-    "--strong_weight 0.5 "
     "--minibatch_size {minibatch_size} "
     "--weak_lr 5e-4 "
     "--strong_lr 8e-5 "
-    '--run_name "stable_balanced_batch" '
+    '--run_name "basic_w2s" '
 )
diff --git a/w2s/loss.py b/w2s/loss.py
index e36a263..13fefff 100644
--- a/w2s/loss.py
+++ b/w2s/loss.py
@@ -1,6 +1,26 @@
 import torch
 
 
+def confidence_window_loss(
+    logits,
+    labels,
+    radius: float = 0.15,
+):
+    """
+    Use cross-entropy loss only for the examples where the model is uncertain.
+    """
+    logits = logits.float()
+    labels = labels.float()
+
+    preds = torch.softmax(logits, dim=-1)
+
+    uncertain = (preds.max(dim=-1).values < 0.5 + radius)
+
+    target = torch.stack([1.0 - labels, labels], dim=1)
+
+    return torch.nn.functional.cross_entropy(logits[uncertain], target[uncertain])
+
+
 def log_confidence_loss(
     logits,
     labels,
diff --git a/w2s/sft.py b/w2s/sft.py
index 1cb95f7..8e3e279 100644
--- a/w2s/sft.py
+++ b/w2s/sft.py
@@ -11,7 +11,7 @@
 )
 
 import wandb
-from w2s.loss import log_confidence_loss
+from w2s.loss import log_confidence_loss, confidence_window_loss
 from w2s.model import ModelConfig, init_model_and_tokenizer
 from w2s.roc_auc import roc_auc
 from w2s.sft_utils import (
@@ -19,35 +19,45 @@
     get_gpu_mem_used,
     move_best_ckpt,
 )
+from w2s.sft_config import LossConfig
 
 
 class CustomLossTrainer(Trainer):
     def __init__(
         self,
-        logconf_weight: float,
-        logconf_warmup_steps: int,
-        balance_batch: bool,
+        loss_name: str,
+        loss_cfg: LossConfig,
+        transfer: bool,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        self.logconf_weight = logconf_weight
-        self.logconf_warmup_steps = logconf_warmup_steps
-        self.balance_batch = balance_batch
+        self.loss_name = loss_name
+        self.loss_cfg = loss_cfg
+        self.transfer = transfer
 
     def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs.pop("labels").float()
 
         outputs = model(**inputs)
-        loss = log_confidence_loss(
-            outputs.logits,
-            labels,
-            self.state.global_step,
-            aux_coef=self.logconf_weight,
-            warmup_steps=self.logconf_warmup_steps,
-            balance_batch=self.balance_batch,
-        )
+        if self.loss_name == 'logconf':
+            loss = log_confidence_loss(
+                outputs.logits,
+                labels,
+                self.state.global_step,
+                aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.),
+                warmup_steps=self.loss_cfg.logconf_warmup_steps,
+                balance_batch=self.loss_cfg.balance_batch,
+            )
+        elif self.loss_name == 'window':
+            loss = confidence_window_loss(
+                outputs.logits,
+                labels,
+                radius=self.loss_cfg.radius,
+            )
+        else:
+            raise ValueError(f"Unknown loss function: {self.loss_name}")
 
         return (loss, outputs) if return_outputs else loss
 
@@ -57,9 +67,7 @@ def train(
     model_cfg: ModelConfig,
     train_args: TrainingArguments,
     cfg: dict,
-    logconf_weight: float = 0.0,
-    logconf_warmup_steps: int = 200,
-    balance_batch: bool = False,
+    transfer: bool,
     predict_dict: Union[DatasetDict, dict, None] = None,
 ):
     """
@@ -105,9 +113,9 @@ def compute_metrics(eval_pred):
     )
 
     trainer = CustomLossTrainer(
-        logconf_weight=logconf_weight,
-        logconf_warmup_steps=logconf_warmup_steps,
-        balance_batch=balance_batch,
+        loss_name=cfg.loss_name,
+        loss_cfg=cfg.loss,
+        transfer=transfer,
         args=train_args,
         compute_metrics=compute_metrics,
         data_collator=DataCollatorWithPadding(tokenizer),
diff --git a/w2s/sft_config.py b/w2s/sft_config.py
index 15ff205..6a9b345 100644
--- a/w2s/sft_config.py
+++ b/w2s/sft_config.py
@@ -1,9 +1,28 @@
 from dataclasses import dataclass
 from typing import Literal, Optional, Union
 
-from simple_parsing import Serializable, field
+from simple_parsing import Serializable, field, subgroups
 
 
+@dataclass
+class LossConfig(Serializable):
+    pass
+
+@dataclass
+class LogConfidenceLossConfig(LossConfig):
+    logconf_weight: float = 0.5
+    logconf_warmup_steps: int = 200
+    balance_batch: bool = False
+
+@dataclass
+class ConfidenceWindowLossConfig(LossConfig):
+    radius: float = 0.15
+
+LOSS_CONFIGS = {
+    "logconf": LogConfidenceLossConfig,
+    "window": ConfidenceWindowLossConfig,
+}
+
 @dataclass
 class SFTConfig(Serializable):  # TODO: what is this for??
     # name of the model to train
@@ -29,10 +48,7 @@ class SFTConfig(Serializable):  # TODO: what is this for??
     eval_every: int = 100  # steps
     save_every: int = 100  # steps
     save_total_limit: Optional[int] = None
-    logconf_weight: float = 0.5
-    logconf_warmup_steps: int = 200
-    balance_batch: bool = False
-    strong_weight: float = 0.5
+    loss: LossConfig = subgroups(LOSS_CONFIGS, default="logconf")
     weight_decay: float = 0.1
     weak_lr: float = 5e-4
     strong_lr: float = 8e-5
@@ -40,6 +56,7 @@ class SFTConfig(Serializable):  # TODO: what is this for??
 
     metric_for_best_model: str = "val_auroc"
     greater_is_better: bool = field(init=False)
+    loss_name: str = field(init=False)
 
     def __post_init__(self):
         if "loss" in self.metric_for_best_model:
@@ -52,6 +69,8 @@ def __post_init__(self):
         else:
             raise ValueError(f"Unknown metric {self.metric_for_best_model}")
 
+        self.loss_name = {LOSS_CONFIGS[k]:k for k in LOSS_CONFIGS}[type(self.loss)]
+
     def to_dict(self):
         irrelevant_fields = ["results_folder", "run_name", "minibatch_size"]
         return {k: v for k, v in vars(self).items() if k not in irrelevant_fields}
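
Note, a minimal usage sketch (not part of the patch) of the new
confidence_window_loss, exercising only w2s/loss.py as added above. The
tensor values are invented for illustration; with the default radius of
0.15, only rows whose max softmax probability falls below 0.65 contribute
to the mean cross-entropy.

    import torch

    from w2s.loss import confidence_window_loss

    # Three examples, two classes; labels are soft probabilities of class 1.
    logits = torch.tensor(
        [
            [2.0, -2.0],  # max prob ~0.98: confident, excluded from the loss
            [0.1, -0.1],  # max prob ~0.55: inside the window, included
            [-0.2, 0.3],  # max prob ~0.62: inside the window, included
        ]
    )
    labels = torch.tensor([0.0, 1.0, 1.0])

    # Mean soft-label cross-entropy over the two uncertain rows only (~0.64).
    loss = confidence_window_loss(logits, labels, radius=0.15)
    print(loss.item())

One caveat worth flagging: if every example in a batch is confident, the
masked cross_entropy averages over an empty selection and yields NaN rather
than 0, so a guard on uncertain.any() may be worth adding.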
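A second sketch (also not part of the patch) of the refactored config
plumbing: loss is now a simple_parsing subgroup chosen by name (run_eight.py
selects it with --loss logconf), and SFTConfig.__post_init__ recovers
loss_name by reverse lookup in LOSS_CONFIGS, which CustomLossTrainer then
dispatches on.

    from w2s.sft_config import LOSS_CONFIGS, ConfidenceWindowLossConfig

    # The same reverse lookup as in SFTConfig.__post_init__: map the config
    # instance's type back to its registry key ("logconf" or "window").
    loss_cfg = ConfidenceWindowLossConfig(radius=0.2)
    loss_name = {LOSS_CONFIGS[k]: k for k in LOSS_CONFIGS}[type(loss_cfg)]
    assert loss_name == "window"

By analogy with the run_eight.py flags, selecting the new loss from the CLI
would presumably be --loss window plus a radius flag, though how
simple_parsing resolves the subgroup's fields as flat flags is assumed here,
not exercised by this patch. One thing reviewers may want to check: train()
annotates cfg as a plain dict (run.py passes cfg.to_dict()), so the attribute
accesses cfg.loss_name and cfg.loss when constructing CustomLossTrainer look
like they would raise AttributeError at runtime; cfg["loss_name"] and
cfg["loss"] may be what is intended.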