From 54f5ed3b6970a507b9372170b41c754e7c437c45 Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Mon, 17 Apr 2023 09:27:12 -0700 Subject: [PATCH] Fix accuracy computation in `Reporter` (#195) * acc now computes top-1 acc * fix auroc call --- elk/training/reporter.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 39a1deda..5e2767f5 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -94,7 +94,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: Returns: an instance of EvalResult containing the loss, accuracy, calibrated - accuracy, and AUROC of the probe on `hiddens`. + accuracy, and AUROC of the probe on `contrast_set`. + Accuracy: top-1 accuracy averaged over questions and variants. + Calibrated accuracy: top-1 accuracy averaged over questions and + variants, calibrated so that x% of the predictions are `True`, + where x is the proprtion of examples with ground truth label `True`. + AUROC: averaged over the n * v * c binary questions + ECE: Expected Calibration Error """ logits = self(hiddens) (_, v, c) = logits.shape @@ -116,11 +122,11 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: cal_acc = 0.0 cal_err = 0.0 - raw_preds = to_one_hot(logits.argmax(dim=-1), c).long() - Y = to_one_hot(Y, c).long().flatten() + Y_one_hot = to_one_hot(Y, c).long().flatten() + auroc_result = roc_auc_ci(Y_one_hot, logits.flatten()) + raw_preds = logits.argmax(dim=-1).long() raw_acc = accuracy(Y, raw_preds.flatten()) - auroc_result = roc_auc_ci(Y, logits.flatten()) return EvalResult( auroc=auroc_result.estimate, auroc_lower=auroc_result.lower,