Cluster bootstrap for metrics; refactor metric computations into evaluate_preds #197

Merged · 13 commits · Apr 19, 2023
Fix normalization of LM logits
norabelrose committed Apr 19, 2023
commit d625f7bb354256716cb4f62015cf3454e6d06c69
10 changes: 5 additions & 5 deletions elk/metrics/eval.py
@@ -40,7 +40,7 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
         return {**acc_dict, **cal_acc_dict, **cal_dict, **auroc_dict}
 
 
-def evaluate_preds(y_true: Tensor, y_pred: Tensor) -> EvalResult:
+def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult:
     """
     Evaluate the performance of a classification model.
 
@@ -51,19 +51,19 @@ def evaluate_preds(y_true: Tensor, y_pred: Tensor) -> EvalResult:
     Returns:
         dict: A dictionary containing the accuracy, AUROC, and ECE.
     """
-    (n, v, c) = y_pred.shape
+    (n, v, c) = y_logits.shape
     assert y_true.shape == (n,)
 
     # Clustered bootstrap confidence intervals for AUROC
     y_true = repeat(y_true, "n -> n v", v=v)
-    auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_pred.flatten(1))
-    acc = accuracy_ci(y_true, y_pred.argmax(dim=-1))
+    auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
+    acc = accuracy_ci(y_true, y_logits.argmax(dim=-1))
 
     cal_acc = None
     cal_err = None
 
     if c == 2:
-        pos_probs = y_pred[..., 1].sigmoid()
+        pos_probs = y_logits.softmax(-1)[..., 1]
 
         # Calibrated accuracy
         cal_thresh = pos_probs.float().quantile(y_true.float().mean())
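The substantive change is the last hunk: for binary tasks, the positive-class probability is now computed by a softmax over both class logits rather than a sigmoid of the positive logit alone. A sigmoid of a single raw logit ignores the negative-class logit, so it is a valid normalization only when that logit happens to be zero. Below is a minimal sketch of the discrepancy; the tensors are hypothetical and simplified to shape (n, c) rather than the (n, v, c) that `evaluate_preds` expects, and none of this code is from the PR itself:

```python
import torch

torch.manual_seed(0)
y_logits = torch.randn(4, 2)  # hypothetical two-class logits, shape (n, c)

# Old behavior: sigmoid of the positive-class logit, ignoring class 0.
old_probs = y_logits[..., 1].sigmoid()

# New behavior: softmax over the class dimension, then take class 1.
new_probs = y_logits.softmax(-1)[..., 1]

# For two classes, softmax equals sigmoid of the logit *difference*:
#   softmax(z)[1] = exp(z1) / (exp(z0) + exp(z1)) = sigmoid(z1 - z0)
diff_probs = (y_logits[..., 1] - y_logits[..., 0]).sigmoid()
assert torch.allclose(new_probs, diff_probs)

# old_probs matches new_probs only where the class-0 logit is zero.
print(old_probs)
print(new_probs)
```

Because `softmax(z)[1] == sigmoid(z1 - z0)`, the two transforms agree only when the class-0 logit is zero. Otherwise, calibration error and calibrated accuracy computed from `sigmoid(z1)` are measured against improperly normalized probabilities, which is presumably what the commit title, "Fix normalization of LM logits", refers to.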