xent, entropy
AdamScherlis committed May 28, 2024
1 parent 685de51 commit 662b711
Showing 3 changed files with 49 additions and 7 deletions.
28 changes: 22 additions & 6 deletions w2s/loss.py
@@ -27,13 +27,25 @@ def confidence_window_loss(
    return loss / logits.shape[0]


def cross_entropy_loss(
    logits,
    labels,
):
    logits = logits.float()
    labels = labels.float()

    target = torch.stack([1.0 - labels, labels], dim=1)
    return torch.nn.functional.cross_entropy(logits, target)


def log_confidence_loss(
    logits,
    labels,
    step: int,
    warmup_steps: int = 200,
    aux_coef: float = 0.5,
    balance_batch: bool = False,
    harden: bool = True,
):
"""
This is similar to the loss in Burns et al., except that it also optionally
@@ -51,13 +63,17 @@ def log_confidence_loss(
    coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
    preds = torch.softmax(logits, dim=-1)

-    threshold = torch.quantile(preds[:, 0], prior)
-    strong_preds = torch.cat(
-        [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
-        dim=1,
-    )
+    if harden:
+        threshold = torch.quantile(preds[:, 0], prior)
+        target_preds = torch.cat(
+            [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
+            dim=1,
+        )
+    else:
+        target_preds = preds

    labels_binary = torch.stack([1.0 - labels, labels], dim=1)
-    target = labels_binary * (1 - coef) + strong_preds.detach() * coef
+    target = labels_binary * (1 - coef) + target_preds.detach() * coef
    return torch.nn.functional.cross_entropy(logits, target)
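
For reference, a minimal sketch of the two target variants this function can now mix into the labels, using toy tensors and an assumed prior of 0.5 (the real prior is computed earlier in the function, outside this hunk); the names hard_target and soft_target are illustrative only:

import torch

preds = torch.softmax(torch.randn(8, 2), dim=-1)
prior = 0.5  # assumption for this sketch; the actual value comes from code above this hunk

# harden=True (the 'logconf' behavior): threshold class-0 probabilities at the prior quantile
threshold = torch.quantile(preds[:, 0], prior)
hard_target = torch.cat(
    [(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
    dim=1,
).float()

# harden=False (what the new 'entropy' loss name selects): keep the model's own
# probabilities as the detached target instead of hardening them
soft_target = preds

print(hard_target[:3])
print(soft_target[:3])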


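The new cross_entropy_loss above relies on torch.nn.functional.cross_entropy accepting class-probability targets (torch 1.10 and later). A small self-contained check with made-up tensors:

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.2, -0.3], [0.1, 0.9]])
labels = torch.tensor([0.8, 0.2])  # soft binary labels in [0, 1]

# same target construction as cross_entropy_loss above
target = torch.stack([1.0 - labels, labels], dim=1)
loss = F.cross_entropy(logits, target)

# equivalent by hand: mean over examples of -sum_c target * log_softmax(logits)
manual = -(target * F.log_softmax(logits, dim=-1)).sum(dim=1).mean()
assert torch.allclose(loss, manual)
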
18 changes: 17 additions & 1 deletion w2s/sft.py
@@ -11,7 +11,7 @@
)

import wandb
-from w2s.loss import log_confidence_loss, confidence_window_loss
+from w2s.loss import log_confidence_loss, confidence_window_loss, cross_entropy_loss
from w2s.model import ModelConfig, init_model_and_tokenizer
from w2s.roc_auc import roc_auc
from w2s.sft_utils import (
@@ -49,6 +49,22 @@ def compute_loss(self, model, inputs, return_outputs=False):
                aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.),
                warmup_steps=self.loss_cfg.logconf_warmup_steps,
                balance_batch=self.loss_cfg.balance_batch,
                harden=True,
            )
        elif self.loss_name == 'entropy':
            loss = log_confidence_loss(
                outputs.logits,
                labels,
                self.state.global_step,
                aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.),
                warmup_steps=self.loss_cfg.logconf_warmup_steps,
                balance_batch=self.loss_cfg.balance_batch,
                harden=False,
            )
        elif self.loss_name == 'xent':
            loss = cross_entropy_loss(
                outputs.logits,
                labels,
            )
        elif self.loss_name == 'window':
            loss = confidence_window_loss(
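
A rough sketch of how the three branches above might be exercised outside the Trainer, with made-up tensors; the step and aux_coef values are illustrative, and this assumes log_confidence_loss accepts soft labels exactly as compute_loss passes them:

import torch
from w2s.loss import cross_entropy_loss, log_confidence_loss

logits = torch.randn(16, 2)
labels = torch.rand(16)  # soft weak labels in [0, 1]

loss_xent = cross_entropy_loss(logits, labels)
loss_logconf = log_confidence_loss(logits, labels, step=100, aux_coef=0.5, harden=True)
loss_entropy = log_confidence_loss(logits, labels, step=100, aux_coef=0.5, harden=False)
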
10 changes: 10 additions & 0 deletions w2s/sft_config.py
@@ -25,9 +25,19 @@ class LogConfidenceLossConfig(LossConfig):
class ConfidenceWindowLossConfig(LossConfig):
    radius: Union[float, Literal["midweak"]] = 0.15

@dataclass
class LogEntropyLossConfig(LogConfidenceLossConfig):
    pass

@dataclass
class CrossEntropyLossConfig(LossConfig):
    pass

LOSS_CONFIGS = {
    "logconf": LogConfidenceLossConfig,
    "window": ConfidenceWindowLossConfig,
    "entropy": LogEntropyLossConfig,
    "xent": CrossEntropyLossConfig,
}

@dataclass
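
The two new entries mirror the 'entropy' and 'xent' branches in w2s/sft.py; a small lookup sketch, assuming the package is importable:

from w2s.sft_config import LOSS_CONFIGS

cfg_cls = LOSS_CONFIGS["entropy"]   # LogEntropyLossConfig, same fields as LogConfidenceLossConfig
xent_cls = LOSS_CONFIGS["xent"]     # CrossEntropyLossConfig, no extra fields

# the 'logconf' and 'entropy' branches in w2s/sft.py read logconf_weight,
# logconf_warmup_steps, and balance_batch from the selected config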
