diff --git a/w2s/loss.py b/w2s/loss.py
index d778d71..49f2cea 100644
--- a/w2s/loss.py
+++ b/w2s/loss.py
@@ -27,6 +27,17 @@ def confidence_window_loss(
     return loss / logits.shape[0]
 
 
+def cross_entropy_loss(
+    logits,
+    labels,
+):
+    logits = logits.float()
+    labels = labels.float()
+
+    target = torch.stack([1.0 - labels, labels], dim=1)
+    return torch.nn.functional.cross_entropy(logits, target)
+
+
 def log_confidence_loss(
     logits,
     labels,
@@ -34,6 +45,7 @@ def log_confidence_loss(
     warmup_steps: int = 200,
     aux_coef: float = 0.5,
     balance_batch: bool = False,
+    harden: bool = True,
 ):
     """
     This is similar to the loss in Burns et al., except that it also optionally
@@ -51,13 +63,17 @@ def log_confidence_loss(
     coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
     preds = torch.softmax(logits, dim=-1)
 
-    threshold = torch.quantile(preds[:, 0], prior)
-    strong_preds = torch.cat(
-        [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
-        dim=1,
-    )
+    if harden:
+        threshold = torch.quantile(preds[:, 0], prior)
+        target_preds = torch.cat(
+            [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
+            dim=1,
+        )
+    else:
+        target_preds = preds
+
     labels_binary = torch.stack([1.0 - labels, labels], dim=1)
-    target = labels_binary * (1 - coef) + strong_preds.detach() * coef
+    target = labels_binary * (1 - coef) + target_preds.detach() * coef
     return torch.nn.functional.cross_entropy(logits, target)
 
 
diff --git a/w2s/sft.py b/w2s/sft.py
index 18fca96..640e339 100644
--- a/w2s/sft.py
+++ b/w2s/sft.py
@@ -11,7 +11,7 @@
 )
 
 import wandb
-from w2s.loss import log_confidence_loss, confidence_window_loss
+from w2s.loss import log_confidence_loss, confidence_window_loss, cross_entropy_loss
 from w2s.model import ModelConfig, init_model_and_tokenizer
 from w2s.roc_auc import roc_auc
 from w2s.sft_utils import (
@@ -49,6 +49,22 @@ def compute_loss(self, model, inputs, return_outputs=False):
                 aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.),
                 warmup_steps=self.loss_cfg.logconf_warmup_steps,
                 balance_batch=self.loss_cfg.balance_batch,
+                harden=True,
+            )
+        elif self.loss_name == 'entropy':
+            loss = log_confidence_loss(
+                outputs.logits,
+                labels,
+                self.state.global_step,
+                aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.),
+                warmup_steps=self.loss_cfg.logconf_warmup_steps,
+                balance_batch=self.loss_cfg.balance_batch,
+                harden=False,
+            )
+        elif self.loss_name == 'xent':
+            loss = cross_entropy_loss(
+                outputs.logits,
+                labels,
             )
         elif self.loss_name == 'window':
             loss = confidence_window_loss(
diff --git a/w2s/sft_config.py b/w2s/sft_config.py
index 95aff86..77e447c 100644
--- a/w2s/sft_config.py
+++ b/w2s/sft_config.py
@@ -25,9 +25,19 @@ class LogConfidenceLossConfig(LossConfig):
 class ConfidenceWindowLossConfig(LossConfig):
     radius: Union[float, Literal["midweak"]] = 0.15
 
+@dataclass
+class LogEntropyLossConfig(LogConfidenceLossConfig):
+    pass
+
+@dataclass
+class CrossEntropyLossConfig(LossConfig):
+    pass
+
 LOSS_CONFIGS = {
     "logconf": LogConfidenceLossConfig,
     "window": ConfidenceWindowLossConfig,
+    "entropy": LogEntropyLossConfig,
+    "xent": CrossEntropyLossConfig,
 }
 
 @dataclass
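
Note (not part of the patch): a minimal sketch of how the two new loss options behave, using only the functions added above. The tensor shapes, batch size, and parameter values are illustrative assumptions, not taken from the repository.

# Illustrative only: exercises cross_entropy_loss ("xent") and the soft
# harden=False variant of log_confidence_loss ("entropy") from the patch.
import torch

from w2s.loss import cross_entropy_loss, log_confidence_loss

logits = torch.randn(8, 2)   # strong-model logits for a batch of 8 examples (assumed shape)
labels = torch.rand(8)       # weak soft labels in [0, 1] (assumed shape)

# "xent": plain cross-entropy against the (possibly soft) weak labels.
xent = cross_entropy_loss(logits, labels)

# "entropy": same auxiliary-confidence loss as "logconf", but with harden=False
# the auxiliary target is the strong model's soft softmax output instead of
# predictions hardened at a quantile threshold.
entropy = log_confidence_loss(
    logits,
    labels,
    100,                     # global step, as passed positionally in sft.py
    warmup_steps=200,
    aux_coef=0.5,
    balance_batch=False,
    harden=False,
)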