Fix accuracy computation in Reporter (EleutherAI#195)
* acc now computes top-1 acc

* fix auroc call
AlexTMallen committed Apr 17, 2023
1 parent 633cda0 commit 54f5ed3
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions elk/training/reporter.py
@@ -94,7 +94,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
  Returns:
  an instance of EvalResult containing the loss, accuracy, calibrated
- accuracy, and AUROC of the probe on `hiddens`.
+ accuracy, and AUROC of the probe on `contrast_set`.
+ Accuracy: top-1 accuracy averaged over questions and variants.
+ Calibrated accuracy: top-1 accuracy averaged over questions and
+ variants, calibrated so that x% of the predictions are `True`,
+ where x is the proportion of examples with ground truth label `True`.
+ AUROC: averaged over the n * v * c binary questions
+ ECE: Expected Calibration Error
  """
  logits = self(hiddens)
  (_, v, c) = logits.shape
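
The calibrated-accuracy definition in the docstring above can be illustrated with a small rank-based sketch. This is not the code in `reporter.py` (the diff does not show how calibration is implemented); `calibrated_predictions` and `binary_accuracy` are hypothetical helpers, and the scores and labels are made-up values chosen only to show the effect.

```python
def calibrated_predictions(scores, labels):
    """Predict True for the k highest-scoring examples, where k is the
    number of ground-truth True labels (hypothetical helper)."""
    n_true = sum(1 for y in labels if y)
    # Indices ordered by score, highest first; ties keep input order.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    preds = [False] * len(scores)
    for i in order[:n_true]:
        preds[i] = True
    return preds


def binary_accuracy(labels, preds):
    """Fraction of predictions that match the labels."""
    return sum(1 for y, p in zip(labels, preds) if y == p) / len(labels)


# Made-up scores: every score is positive, so thresholding at zero
# predicts True everywhere.
scores = [2.0, 1.5, 0.9, 0.3]
labels = [True, True, False, False]

raw_preds = [s > 0 for s in scores]
cal_preds = calibrated_predictions(scores, labels)

print(binary_accuracy(labels, raw_preds))  # uncalibrated
print(binary_accuracy(labels, cal_preds))  # calibrated
```

In this toy case the uncalibrated threshold labels every example `True`, while the calibrated predictions mark only as many examples `True` as the labels do, which is the property the docstring describes.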
@@ -116,11 +122,11 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
  cal_acc = 0.0
  cal_err = 0.0

- raw_preds = to_one_hot(logits.argmax(dim=-1), c).long()
- Y = to_one_hot(Y, c).long().flatten()
+ Y_one_hot = to_one_hot(Y, c).long().flatten()
+ auroc_result = roc_auc_ci(Y_one_hot, logits.flatten())

+ raw_preds = logits.argmax(dim=-1).long()
  raw_acc = accuracy(Y, raw_preds.flatten())
- auroc_result = roc_auc_ci(Y, logits.flatten())
  return EvalResult(
  auroc=auroc_result.estimate,
  auroc_lower=auroc_result.lower,
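
The second hunk stops comparing flattened one-hot predictions against flattened one-hot labels and compares class indices instead. A minimal sketch of why that matters, assuming `accuracy` returns the fraction of element-wise matches (its definition is not shown in this diff). The helpers below are plain-Python stand-ins, not the functions used in `reporter.py`, and the logits are invented for illustration.

```python
def to_one_hot(indices, n_classes):
    """Stand-in one-hot encoder: one row of 0/1 values per class index."""
    return [[1 if k == i else 0 for k in range(n_classes)] for i in indices]


def flatten(rows):
    return [x for row in rows for x in row]


def elementwise_accuracy(y_true, y_pred):
    """Assumed behaviour of `accuracy`: fraction of matching entries."""
    return sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)


def argmax(row):
    return max(range(len(row)), key=lambda k: row[k])


c = 3  # number of classes
logits = [
    [2.0, 0.1, 0.3],  # predicts class 0
    [0.2, 1.5, 0.1],  # predicts class 1
    [0.9, 0.2, 0.4],  # predicts class 0
    [0.1, 0.3, 2.2],  # predicts class 2
]
labels = [0, 2, 1, 2]
preds = [argmax(row) for row in logits]

# Old approach: compare flattened one-hot vectors entry by entry.
one_hot_acc = elementwise_accuracy(
    flatten(to_one_hot(labels, c)), flatten(to_one_hot(preds, c))
)

# New approach: compare class indices directly (top-1 accuracy).
top1_acc = elementwise_accuracy(labels, preds)

print(one_hot_acc, top1_acc)
```

Under this assumption, a wrong prediction with `c = 3` still agrees with the label on one of the three one-hot entries, so the one-hot comparison reports 8/12 while top-1 accuracy is 2/4. With `c = 2` the two quantities coincide, so the discrepancy only appears for more than two classes.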