Skip to content

Commit

Permalink
use hard labels for auroc and acc
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed May 20, 2024
1 parent c444ba3 commit e57e4f4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 2 additions & 0 deletions underspec/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor:
)
if y_true.dim() not in (1, 2):
raise ValueError("y_true and y_pred should be 1D or 2D tensors")
if not ((y_true == 1) | (y_true == 0)).all():
raise ValueError("y_true should contain only 0s and 1s")

# Sort y_pred in descending order and get indices
indices = y_pred.argsort(descending=True, dim=-1)
Expand Down
5 changes: 3 additions & 2 deletions underspec/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,10 @@ def process(examples):

def compute_metrics(eval_pred):
predictions, labels = map(torch.from_numpy, eval_pred)
hard_labels = (labels > 0.5).long()
return dict(
accuracy=predictions.argmax(dim=1).eq(labels).float().mean(),
auroc=roc_auc(labels, predictions[:, 1]),
accuracy=predictions.argmax(dim=1).eq(hard_labels).float().mean(),
auroc=roc_auc(hard_labels, predictions[:, 1]),
)

trainer = CustomLossTrainer(
Expand Down

0 comments on commit e57e4f4

Please sign in to comment.