Log ensembled metrics (#215)
* Log ensembled metrics

* Fixing pyright version
norabelrose authored Apr 28, 2023
1 parent 025de02 commit 44c229c
Showing 5 changed files with 83 additions and 43 deletions.
39 changes: 25 additions & 14 deletions elk/evaluation/evaluate.py
@@ -45,21 +45,32 @@ def apply_to_layer(
         for ds_name, (val_h, val_gt, _) in val_output.items():
             meta = {"dataset": ds_name, "layer": layer}
 
-            val_result = evaluate_preds(val_gt, reporter(val_h))
-            row_bufs["eval"].append({**meta, **val_result.to_dict()})
+            val_credences = reporter(val_h)
+            for mode in ("none", "partial", "full"):
+                row_bufs["eval"].append(
+                    {
+                        **meta,
+                        "ensembling": mode,
+                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                    }
+                )
 
-            lr_dir = experiment_dir / "lr_models"
-            if not self.skip_supervised and lr_dir.exists():
-                with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
-                    lr_models = torch.load(f, map_location=device)
-                    if not isinstance(lr_models, list):  # backward compatibility
-                        lr_models = [lr_models]
+                lr_dir = experiment_dir / "lr_models"
+                if not self.skip_supervised and lr_dir.exists():
+                    with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
+                        lr_models = torch.load(f, map_location=device)
+                        if not isinstance(lr_models, list):  # backward compatibility
+                            lr_models = [lr_models]
 
-                for i, model in enumerate(lr_models):
-                    model.eval()
-                    lr_result = evaluate_preds(val_gt, model(val_h))
-                    row_bufs["lr_eval"].append(
-                        {"inlp_iter": i, **meta, **lr_result.to_dict()}
-                    )
+                    for i, model in enumerate(lr_models):
+                        model.eval()
+                        row_bufs["lr_eval"].append(
+                            {
+                                "ensembling": mode,
+                                "inlp_iter": i,
+                                **meta,
+                                **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
+                            }
+                        )
 
         return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
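Note: for each dataset/layer pair, the rewritten loop above emits one "eval" row per ensembling mode instead of a single row. A minimal sketch of the resulting record shape (the dataset name, layer index, and metric key are illustrative stand-ins for the real `evaluate_preds(...).to_dict()` output):

```python
# Sketch only: one row per ensembling mode for each (dataset, layer) pair.
row_bufs = {"eval": []}
meta = {"dataset": "imdb", "layer": 12}  # illustrative values

for mode in ("none", "partial", "full"):
    # The real code spreads evaluate_preds(val_gt, val_credences, mode).to_dict() here.
    row_bufs["eval"].append({**meta, "ensembling": mode, "auroc_estimate": 0.91})

for row in row_bufs["eval"]:
    print(row)
# {'dataset': 'imdb', 'layer': 12, 'ensembling': 'none', 'auroc_estimate': 0.91}
# ...then the same shape for 'partial' and 'full'
```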
36 changes: 27 additions & 9 deletions elk/metrics/eval.py
@@ -1,4 +1,5 @@
 from dataclasses import asdict, dataclass
+from typing import Literal
 
 import torch
 from einops import repeat
@@ -37,33 +38,50 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
             else {}
         )
         auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()}
-        return {**acc_dict, **cal_acc_dict, **cal_dict, **auroc_dict}
+        return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}
 
 
-def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult:
+def evaluate_preds(
+    y_true: Tensor,
+    y_logits: Tensor,
+    ensembling: Literal["none", "partial", "full"] = "none",
+) -> EvalResult:
     """
     Evaluate the performance of a classification model.
     Args:
         y_true: Ground truth tensor of shape (N,).
-        y_pred: Predicted class tensor of shape (N, variants, n_classes).
+        y_logits: Predicted class tensor of shape (N, variants, n_classes).
     Returns:
         dict: A dictionary containing the accuracy, AUROC, and ECE.
     """
     (n, v, c) = y_logits.shape
     assert y_true.shape == (n,)
 
-    # Clustered bootstrap confidence intervals for AUROC
-    y_true = repeat(y_true, "n -> n v", v=v)
-    auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
-    acc = accuracy_ci(y_true, y_logits.argmax(dim=-1))
+    if ensembling == "full":
+        y_logits = y_logits.mean(dim=1)
+    else:
+        y_true = repeat(y_true, "n -> n v", v=v)
+
+    y_pred = y_logits.argmax(dim=-1)
+    if ensembling == "none":
+        auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
+    elif ensembling in ("partial", "full"):
+        # Pool together the negative and positive class logits
+        if c == 2:
+            auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0])
+        else:
+            auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits)
+    else:
+        raise ValueError(f"Unknown mode: {ensembling}")
+
+    acc = accuracy_ci(y_true, y_pred)
 
     cal_acc = None
     cal_err = None
 
     if c == 2:
-        pos_probs = y_logits.softmax(-1)[..., 1]
+        pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])
 
         # Calibrated accuracy
         cal_thresh = pos_probs.float().quantile(y_true.float().mean())
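To make the three modes concrete: "full" averages logits across prompt variants before scoring; "none" and "partial" keep variants separate, repeating the labels per variant; and for binary tasks the pooled score is the difference of the two class logits, whose sigmoid equals the softmax positive-class probability (which is why the `pos_probs` change is behavior-preserving). A self-contained sketch with random placeholder logits:

```python
import torch

n, v, c = 8, 4, 2  # examples, prompt variants, classes
y_logits = torch.randn(n, v, c)

# "full" ensembling: average over the variant dimension before scoring.
full_logits = y_logits.mean(dim=1)  # shape (n, c)

# "none"/"partial": variants stay separate; labels get repeated per variant.
# For c == 2, pooling the two class logits yields one score per prediction:
pooled = y_logits[..., 1] - y_logits[..., 0]  # shape (n, v)

# sigmoid(l1 - l0) equals softmax([l0, l1])[..., 1], so the new pos_probs
# expression matches the old softmax-based one.
assert torch.allclose(torch.sigmoid(pooled), y_logits.softmax(-1)[..., 1])
```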
2 changes: 1 addition & 1 deletion elk/run.py
@@ -173,7 +173,7 @@ def apply_to_layers(
         finally:
             # Make sure the CSVs are written even if we crash or get interrupted
             for name, dfs in df_buffers.items():
-                df = pd.concat(dfs).sort_values(by="layer")
+                df = pd.concat(dfs).sort_values(by=["layer", "ensembling"])
                 df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
             if self.debug:
                 save_debug_log(self.datasets, self.out_dir)
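The extra sort key just keeps each layer's three mode rows adjacent in the written CSVs; a small pandas illustration with dummy rows:

```python
import pandas as pd

# Dummy rows in arbitrary order; the real frames hold the evaluation metrics.
df = pd.DataFrame(
    {
        "layer": [2, 1, 2, 1],
        "ensembling": ["none", "full", "full", "none"],
    }
)
# Rows come out ordered by layer, then alphabetically by mode within a layer.
print(df.sort_values(by=["layer", "ensembling"]).to_string(index=False))
```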
47 changes: 29 additions & 18 deletions elk/training/train.py
@@ -124,24 +124,35 @@ def apply_to_layer(
         for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
             meta = {"dataset": ds_name, "layer": layer}
 
-            val_result = evaluate_preds(val_gt, reporter(val_h))
-            row_bufs["eval"].append(
-                {
-                    **meta,
-                    "pseudo_auroc": pseudo_auroc,
-                    "train_loss": train_loss,
-                    **val_result.to_dict(),
-                }
-            )
-
-            if val_lm_preds is not None:
-                lm_result = evaluate_preds(val_gt, val_lm_preds)
-                row_bufs["lm_eval"].append({**meta, **lm_result.to_dict()})
-
-            for i, model in enumerate(lr_models):
-                lr_result = evaluate_preds(val_gt, model(val_h))
-                row_bufs["lr_eval"].append(
-                    {"inlp_iter": i, **meta, **lr_result.to_dict()}
-                )
+            val_credences = reporter(val_h)
+            for mode in ("none", "partial", "full"):
+                row_bufs["eval"].append(
+                    {
+                        **meta,
+                        "ensembling": mode,
+                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                        "pseudo_auroc": pseudo_auroc,
+                        "train_loss": train_loss,
+                    }
+                )
+
+                if val_lm_preds is not None:
+                    row_bufs["lm_eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
+                        }
+                    )
+
+                for i, model in enumerate(lr_models):
+                    row_bufs["lr_eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            "inlp_iter": i,
+                            **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
+                        }
+                    )
 
         return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
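Downstream, every results table now carries an `ensembling` column, so comparing modes across layers is a one-line pivot. A hedged sketch (the CSV path and the `auroc_estimate` column name are assumptions, not confirmed by this diff):

```python
import pandas as pd

# Path and metric column are illustrative; substitute the actual out_dir
# and whichever metric key EvalResult.to_dict emits.
df = pd.read_csv("out_dir/eval.csv")
print(df.pivot_table(index="layer", columns="ensembling", values="auroc_estimate"))
```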
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -41,7 +41,7 @@ dev = [
     "hypothesis",
     "pre-commit",
     "pytest",
-    "pyright",
+    "pyright==1.1.304",
     "scikit-learn",
 ]
