diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 7ccb156d..fa350b18 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -9,9 +9,9 @@
 from ..extraction.extraction import Extract
 from ..files import elk_reporter_dir
+from ..metrics import evaluate_preds
 from ..run import Run
 from ..training import Reporter
-from ..training.supervised import evaluate_supervised
 from ..utils import select_usable_devices
@@ -70,30 +70,25 @@ def evaluate_reporter(
         row_buf = []
         for ds_name, (val_h, val_gt, _) in val_output.items():
-            val_result = reporter.score(val_gt, val_h)
+            val_result = evaluate_preds(val_gt, reporter(val_h))
-            stats_row = pd.Series(
-                {
-                    "dataset": ds_name,
-                    "layer": layer,
-                    **val_result._asdict(),
-                }
-            )
+            stats_row = {
+                "dataset": ds_name,
+                "layer": layer,
+                **val_result.to_dict(),
+            }

             lr_dir = experiment_dir / "lr_models"
             if not self.cfg.skip_supervised and lr_dir.exists():
                 with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
                     lr_model = torch.load(f, map_location=device).eval()

-                lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
-                stats_row["lr_auroc"] = lr_auroc_res.estimate
-                stats_row["lr_auroc_lower"] = lr_auroc_res.lower
-                stats_row["lr_auroc_upper"] = lr_auroc_res.upper
-                stats_row["lr_acc"] = lr_acc
+                lr_result = evaluate_preds(val_gt, lr_model(val_h))
+                stats_row.update(lr_result.to_dict(prefix="lr_"))

             row_buf.append(stats_row)

-        return pd.DataFrame(row_buf)
+        return pd.DataFrame.from_records(row_buf)

     def evaluate(self):
         """Evaluate the reporter on all layers."""
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 2f9bda09..7d22f9ab 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -149,7 +149,7 @@ def extract_hiddens(
         )
         for layer_idx in layer_indices
     }
-    lm_preds = torch.empty(
+    lm_logits = torch.empty(
         num_variants,
         num_choices,
         device=device,
@@ -205,13 +205,13 @@
                         dim=-1
                     )
                     tokens = inputs.input_ids[..., start:end, None]
-                    lm_preds[i, j] = log_p.gather(-1, tokens).sum()
+                    lm_logits[i, j] = log_p.gather(-1, tokens).sum()
                 elif isinstance(outputs, Seq2SeqLMOutput):
                     # The cross entropy loss is averaged over tokens, so we need to
                     # multiply by the length to get the total log probability.
                    length = inputs.labels.shape[-1]
-                    lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length
+                    lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

                 hiddens = (
                     outputs.get("decoder_hidden_states") or outputs["hidden_states"]
@@ -244,7 +244,7 @@ def extract_hiddens(
             **hidden_dict,
         )
         if has_lm_preds:
-            out_record["model_preds"] = lm_preds.softmax(dim=-1)
+            out_record["model_logits"] = lm_logits

         yield out_record
@@ -319,9 +319,9 @@ def get_splits() -> SplitDict:
         ),
     }

-    # Only add model_preds if the model is an autoregressive model
+    # Only add model_logits if the model is an autoregressive model
     if is_autoregressive(model_cfg):
-        other_cols["model_preds"] = Array2D(
+        other_cols["model_logits"] = Array2D(
             shape=(num_variants, num_classes),
             dtype="float32",
         )
diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py
index b4a94617..c56eeffd 100644
--- a/elk/extraction/prompt_loading.py
+++ b/elk/extraction/prompt_loading.py
@@ -49,7 +49,7 @@ class PromptConfig(Serializable):
     datasets: list[str] = field(positional=True)
     data_dirs: list[str] = field(default_factory=list)
     label_columns: list[str] = field(default_factory=list)
-    max_examples: list[int] = field(default_factory=lambda: [750, 250])
+    max_examples: list[int] = field(default_factory=lambda: [1000, 1000])
     num_classes: int = 0
     num_shots: int = 0
     num_variants: int = -1
diff --git a/elk/metrics/__init__.py b/elk/metrics/__init__.py
new file mode 100644
index 00000000..7fb21450
--- /dev/null
+++ b/elk/metrics/__init__.py
@@ -0,0 +1,16 @@
+from .accuracy import accuracy_ci
+from .calibration import CalibrationError, CalibrationEstimate
+from .eval import EvalResult, evaluate_preds, to_one_hot
+from .roc_auc import RocAucResult, roc_auc, roc_auc_ci
+
+__all__ = [
+    "accuracy_ci",
+    "CalibrationError",
+    "CalibrationEstimate",
+    "EvalResult",
+    "evaluate_preds",
+    "roc_auc",
+    "roc_auc_ci",
+    "to_one_hot",
+    "RocAucResult",
+]
diff --git a/elk/metrics/accuracy.py b/elk/metrics/accuracy.py
new file mode 100644
index 00000000..33b94632
--- /dev/null
+++ b/elk/metrics/accuracy.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor
+
+
+@dataclass(frozen=True)
+class AccuracyResult:
+    """Accuracy point estimate and confidence interval."""
+
+    estimate: float
+    """Point estimate of the accuracy computed on this sample."""
+    lower: float
+    """Lower bound of the confidence interval."""
+    upper: float
+    """Upper bound of the confidence interval."""
+
+
+def accuracy_ci(
+    y_true: Tensor,
+    y_pred: Tensor,
+    *,
+    num_samples: int = 1000,
+    level: float = 0.95,
+    seed: int = 42,
+) -> AccuracyResult:
+    """Bootstrap confidence interval for accuracy, with optional clustering.
+
+    When the input arguments are 2D, this function performs the cluster bootstrap,
+    resampling clusters with replacement instead of individual samples. The first
+    axis is assumed to be the cluster axis.
+
+    Args:
+        y_true: Ground truth tensor of shape `(N,)` or `(N, cluster_size)`.
+        y_pred: Predicted class tensor of shape `(N,)` or `(N, cluster_size)`.
+        num_samples (int): Number of bootstrap samples to use.
+        level (float): Confidence level of the confidence interval.
+        seed (int): Random seed for reproducibility.
+
+    Returns:
+        AccuracyResult: Dataclass containing the lower and upper bounds of the
+        confidence interval, along with the point estimate.
+ """ + if torch.is_floating_point(y_pred) or torch.is_floating_point(y_true): + raise TypeError("y_true and y_pred should be integer tensors") + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true and y_pred should have the same shape; " + f"got {y_true.shape} and {y_pred.shape}" + ) + if y_true.dim() not in (1, 2): + raise ValueError("y_true and y_pred should be 1D or 2D tensors") + + # Either the number of samples (1D) or the number of clusters (2D) + N = y_true.shape[0] + device = y_true.device + + # Generate random indices for bootstrap samples (shape: [num_bootstraps, N]) + rng = torch.Generator(device=device).manual_seed(seed) + indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng) + + # Create bootstrap samples of true labels and predicted probabilities + y_true_bootstraps = y_true[indices] + y_pred_bootstraps = y_pred[indices] + + # Compute ROC AUC scores for bootstrap samples. If the inputs were 2D, the + # bootstrapped tensors are now 3D [num_bootstraps, N, cluster_size], so we + # call flatten(1) to get a 2D tensor [num_bootstraps, N * cluster_size]. + bootstrap_hits = y_true_bootstraps.flatten(1).eq(y_pred_bootstraps.flatten(1)) + bootstrap_accs = bootstrap_hits.float().mean(1) + + # Calculate the lower and upper bounds of the confidence interval. We use + # nanquantile instead of quantile because some bootstrap samples may have + # NaN values due to the fact that they have only one class. + alpha = (1 - level) / 2 + q = bootstrap_accs.new_tensor([alpha, 1 - alpha]) + lower, upper = bootstrap_accs.nanquantile(q).tolist() + + # Compute the point estimate. Call flatten to ensure that we get a single number + # computed across cluster boundaries even if the inputs were clustered. + estimate = y_true.flatten().eq(y_pred.flatten()).float().mean().item() + return AccuracyResult(estimate, lower, upper) diff --git a/elk/calibration.py b/elk/metrics/calibration.py similarity index 96% rename from elk/calibration.py rename to elk/metrics/calibration.py index 23a48b8e..705f02b7 100644 --- a/elk/calibration.py +++ b/elk/metrics/calibration.py @@ -1,12 +1,12 @@ import warnings from dataclasses import dataclass, field -from typing import NamedTuple import torch from torch import Tensor -class CalibrationEstimate(NamedTuple): +@dataclass(frozen=True) +class CalibrationEstimate: ece: float num_bins: int @@ -82,7 +82,7 @@ def compute(self, p: int = 2) -> CalibrationEstimate: # Split into (nearly) equal mass bins. They won't be exactly equal, so we # still weight the bins by their size. conf_bins = pred_probs.tensor_split(b_star) - w = torch.tensor([len(c) / n for c in conf_bins]) + w = pred_probs.new_tensor([len(c) / n for c in conf_bins]) # See the definition of ECE_sweep in Equation 8 of Roelofs et al. 
         mean_confs = torch.stack([c.mean() for c in conf_bins])
diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py
new file mode 100644
index 00000000..dcc5ce35
--- /dev/null
+++ b/elk/metrics/eval.py
@@ -0,0 +1,91 @@
+from dataclasses import asdict, dataclass
+
+import torch
+from einops import repeat
+from torch import Tensor
+
+from .accuracy import AccuracyResult, accuracy_ci
+from .calibration import CalibrationError, CalibrationEstimate
+from .roc_auc import RocAucResult, roc_auc_ci
+
+
+@dataclass(frozen=True)
+class EvalResult:
+    """The result of evaluating a classifier."""
+
+    accuracy: AccuracyResult
+    """Top 1 accuracy, implemented for both binary and multi-class classification."""
+    cal_accuracy: AccuracyResult | None
+    """Calibrated accuracy, only implemented for binary classification."""
+    calibration: CalibrationEstimate | None
+    """Expected calibration error, only implemented for binary classification."""
+    roc_auc: RocAucResult
+    """Area under the ROC curve. For multi-class classification, each class is treated
+    as a one-vs-rest binary classification problem."""
+
+    def to_dict(self, prefix: str = "") -> dict[str, float]:
+        """Convert the result to a dictionary."""
+        acc_dict = {f"{prefix}acc_{k}": v for k, v in asdict(self.accuracy).items()}
+        cal_acc_dict = (
+            {f"{prefix}cal_acc_{k}": v for k, v in asdict(self.cal_accuracy).items()}
+            if self.cal_accuracy is not None
+            else {}
+        )
+        cal_dict = (
+            {f"{prefix}ece": self.calibration.ece}
+            if self.calibration is not None
+            else {}
+        )
+        auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()}
+        return {**acc_dict, **cal_acc_dict, **cal_dict, **auroc_dict}
+
+
+def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult:
+    """
+    Evaluate the performance of a classification model.
+
+    Args:
+        y_true: Ground truth tensor of shape (N,).
+        y_logits: Predicted logits of shape (N, variants, n_classes).
+
+    Returns:
+        EvalResult: The accuracy, calibrated accuracy, ECE, and AUROC.
+    """
+    (n, v, c) = y_logits.shape
+    assert y_true.shape == (n,)
+
+    # Clustered bootstrap confidence intervals for AUROC
+    y_true = repeat(y_true, "n -> n v", v=v)
+    auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
+    acc = accuracy_ci(y_true, y_logits.argmax(dim=-1))
+
+    cal_acc = None
+    cal_err = None
+
+    if c == 2:
+        pos_probs = y_logits.softmax(-1)[..., 1]
+
+        # Calibrated accuracy
+        cal_thresh = pos_probs.float().quantile(y_true.float().mean())
+        cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
+        cal_acc = accuracy_ci(y_true, cal_preds)
+
+        cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
+        cal_err = cal.compute()
+
+    return EvalResult(acc, cal_acc, cal_err, auroc)
+
+
+def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
+    """
+    Convert a tensor of class labels to a one-hot representation.
+
+    Args:
+        labels (Tensor): A tensor of class labels (any shape is accepted).
+        n_classes (int): The total number of unique classes.
+
+    Returns:
+        Tensor: A one-hot representation tensor of shape (*labels.shape, n_classes).
+ """ + one_hot_labels = labels.new_zeros(*labels.shape, n_classes) + return one_hot_labels.scatter_(-1, labels.unsqueeze(-1).long(), 1) diff --git a/elk/metrics.py b/elk/metrics/roc_auc.py similarity index 70% rename from elk/metrics.py rename to elk/metrics/roc_auc.py index 0150f02f..1efce933 100644 --- a/elk/metrics.py +++ b/elk/metrics/roc_auc.py @@ -1,46 +1,12 @@ -from typing import NamedTuple +from dataclasses import dataclass import torch from torch import Tensor -def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: - """ - Convert a tensor of class labels to a one-hot representation. - - Args: - labels (Tensor): A tensor of class labels of shape (N,). - n_classes (int): The total number of unique classes. - - Returns: - Tensor: A one-hot representation tensor of shape (N, n_classes). - """ - one_hot_labels = labels.new_zeros(labels.size(0), n_classes) - return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1) - - -def accuracy(y_true: Tensor, y_pred: Tensor) -> float: - """ - Compute the accuracy of a classification model. - - Args: - y_true: Ground truth tensor of shape (N,). - y_pred: Predicted class tensor of shape (N,) or (N, n_classes). - - Returns: - float: Accuracy of the model. - """ - # Check if binary or multi-class classification - if len(y_pred.shape) == 1: - hard_preds = y_pred > 0.5 - else: - hard_preds = y_pred.argmax(-1) - - return hard_preds.cpu().eq(y_true.cpu()).float().mean().item() - - -class RocAucResult(NamedTuple): - """Named tuple for storing ROC AUC results.""" +@dataclass(frozen=True) +class RocAucResult: + """Dataclass for storing ROC AUC results.""" estimate: float """Point estimate of the ROC AUC computed on this sample.""" @@ -111,11 +77,15 @@ def roc_auc_ci( level: float = 0.95, seed: int = 42, ) -> RocAucResult: - """Bootstrap confidence interval for the ROC AUC. + """Bootstrap confidence interval for the ROC AUC, with optional clustering. + + When the input arguments are 2D, this function performs the cluster bootstrap, + resampling clusters with replacement instead of individual samples. The first + axis is assumed to be the cluster axis. Args: - y_true: Ground truth tensor of shape `(N,)`. - y_pred: Predicted class tensor of shape `(N,)`. + y_true: Ground truth tensor of shape `(N,)` or `(N, cluster_size)`. + y_pred: Predicted class tensor of shape `(N,)` or `(N, cluster_size)`. num_samples (int): Number of bootstrap samples to use. level (float): Confidence level of the confidence interval. seed (int): Random seed for reproducibility. @@ -129,11 +99,12 @@ def roc_auc_ci( f"y_true and y_pred should have the same shape; " f"got {y_true.shape} and {y_pred.shape}" ) - if y_true.dim() != 1: - raise ValueError("y_true and y_pred should be 1D tensors") + if y_true.dim() not in (1, 2): + raise ValueError("y_true and y_pred should be 1D or 2D tensors") - device = y_true.device + # Either the number of samples (1D) or the number of clusters (2D) N = y_true.shape[0] + device = y_true.device # Generate random indices for bootstrap samples (shape: [num_bootstraps, N]) rng = torch.Generator(device=device).manual_seed(seed) @@ -143,16 +114,19 @@ def roc_auc_ci( y_true_bootstraps = y_true[indices] y_pred_bootstraps = y_pred[indices] - # Compute ROC AUC scores for bootstrap samples - bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps) + # Compute ROC AUC scores for bootstrap samples. 
+    # bootstrapped tensors are now 3D [num_bootstraps, N, cluster_size], so we
+    # call flatten(1) to get a 2D tensor [num_bootstraps, N * cluster_size].
+    bootstrap_aucs = roc_auc(y_true_bootstraps.flatten(1), y_pred_bootstraps.flatten(1))

     # Calculate the lower and upper bounds of the confidence interval. We use
     # nanquantile instead of quantile because some bootstrap samples may have
     # NaN values due to the fact that they have only one class.
     alpha = (1 - level) / 2
-    q = y_pred.new_tensor([alpha, 1 - alpha])
+    q = bootstrap_aucs.new_tensor([alpha, 1 - alpha])
     lower, upper = bootstrap_aucs.nanquantile(q).tolist()

-    # Compute the point estimate
-    estimate = roc_auc(y_true, y_pred).item()
+    # Compute the point estimate. Call flatten to ensure that we get a single number
+    # computed across cluster boundaries even if the inputs were clustered.
+    estimate = roc_auc(y_true.flatten(), y_pred.flatten()).item()
     return RocAucResult(estimate, lower, upper)
diff --git a/elk/run.py b/elk/run.py
index bc889baa..6681d6f2 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -99,11 +99,11 @@ def prepare_data(
         val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))

         with split.formatted_as("torch", device=device):
-            has_preds = "model_preds" in split.features
-            lm_preds = split["model_preds"] if has_preds else None
+            has_preds = "model_logits" in split.features
+            lm_preds = split["model_logits"] if has_preds else None

         ds_name = get_dataset_name(ds)
-        out[ds_name] = (val_h, labels, lm_preds)
+        out[ds_name] = (val_h, labels.to(val_h.device), lm_preds)

         return out
@@ -148,6 +148,6 @@ def apply_to_layers(
             # Make sure the CSV is written even if we crash or get interrupted
             if df_buf:
                 df = pd.concat(df_buf).sort_values(by="layer")
-                df.to_csv(f, index=False)
+                df.round(4).to_csv(f, index=False)
             if self.cfg.debug:
                 save_debug_log(self.datasets, self.out_dir)
diff --git a/elk/training/reporter.py b/elk/training/reporter.py
index 5e2767f5..c10b9562 100644
--- a/elk/training/reporter.py
+++ b/elk/training/reporter.py
@@ -3,33 +3,13 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Literal, NamedTuple, Optional
+from typing import Literal, Optional

 import torch
 import torch.nn as nn
-from einops import rearrange, repeat
 from simple_parsing.helpers import Serializable
 from torch import Tensor

-from ..calibration import CalibrationError
-from ..metrics import accuracy, roc_auc_ci, to_one_hot
-
-
-class EvalResult(NamedTuple):
-    """The result of evaluating a reporter on a dataset.
-
-    The `.score()` function of a reporter returns an instance of this class,
-    which contains the loss, accuracy, calibrated accuracy, and AUROC.
-    """
-
-    auroc: float
-    auroc_lower: float
-    auroc_upper: float
-
-    acc: float
-    cal_acc: float
-    ece: float
-

 @dataclass
 class ReporterConfig(Serializable):
@@ -83,55 +63,3 @@ def fit(
         labels: Optional[Tensor] = None,
     ) -> float:
         ...
-
-    @torch.no_grad()
-    def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
-        """Score the probe on the contrast set `hiddens`.
-
-        Args:
-            labels: The labels of the contrast pair.
-            hiddens: Contrast set of shape [n, v, k, d].
-
-        Returns:
-            an instance of EvalResult containing the loss, accuracy, calibrated
-            accuracy, and AUROC of the probe on `contrast_set`.
-            Accuracy: top-1 accuracy averaged over questions and variants.
-            Calibrated accuracy: top-1 accuracy averaged over questions and
-            variants, calibrated so that x% of the predictions are `True`,
-            where x is the proprtion of examples with ground truth label `True`.
-            AUROC: averaged over the n * v * c binary questions
-            ECE: Expected Calibration Error
-        """
-        logits = self(hiddens)
-        (_, v, c) = logits.shape
-
-        # makes `num_variants` copies of each label
-        logits = rearrange(logits, "n v c -> (n v) c")
-        Y = repeat(labels, "n -> (n v)", v=v).float()
-
-        if c == 2:
-            pos_probs = logits[..., 1].flatten().sigmoid()
-            cal_err = CalibrationError().update(Y.cpu(), pos_probs.cpu()).compute().ece
-
-            # Calibrated accuracy
-            cal_thresh = pos_probs.float().quantile(labels.float().mean())
-            cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
-            cal_acc = cal_preds.flatten().eq(Y).float().mean().item()
-        else:
-            # TODO: Implement calibration error for k > 2?
-            cal_acc = 0.0
-            cal_err = 0.0
-
-        Y_one_hot = to_one_hot(Y, c).long().flatten()
-        auroc_result = roc_auc_ci(Y_one_hot, logits.flatten())
-
-        raw_preds = logits.argmax(dim=-1).long()
-        raw_acc = accuracy(Y, raw_preds.flatten())
-        return EvalResult(
-            auroc=auroc_result.estimate,
-            auroc_lower=auroc_result.lower,
-            auroc_upper=auroc_result.upper,
-            acc=float(raw_acc),
-            cal_acc=cal_acc,
-            ece=cal_err,
-        )
diff --git a/elk/training/supervised.py b/elk/training/supervised.py
index 300d6f0e..86531106 100644
--- a/elk/training/supervised.py
+++ b/elk/training/supervised.py
@@ -1,30 +1,10 @@
 import torch
 from einops import rearrange, repeat
-from torch import Tensor

-from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot
-from ..utils import assert_type
+from ..metrics import to_one_hot
 from .classifier import Classifier

-
-def evaluate_supervised(
-    lr_model: Classifier, val_h: Tensor, val_labels: Tensor
-) -> tuple[RocAucResult, float]:
-    (_, v, k, _) = val_h.shape
-
-    with torch.no_grad():
-        logits = rearrange(lr_model(val_h).squeeze(), "n v k -> (n v) k")
-        raw_preds = to_one_hot(logits.argmax(dim=-1), k).long()
-
-    labels = repeat(val_labels, "n -> (n v)", v=v)
-    labels = to_one_hot(labels, k).flatten()
-
-    lr_acc = accuracy(labels, raw_preds.flatten())
-    lr_auroc = roc_auc_ci(labels, logits.flatten())
-
-    return lr_auroc, assert_type(float, lr_acc)
-
-
 def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier:
     Xs, train_labels = [], []
diff --git a/elk/training/sweep.py b/elk/training/sweep.py
index 7f2bef2a..17250def 100644
--- a/elk/training/sweep.py
+++ b/elk/training/sweep.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import InitVar, dataclass

 from ..extraction import Extract, PromptConfig
@@ -18,6 +19,14 @@ class Sweep:

     name: str | None = None

+    # A bit of a hack to add all the command line arguments from Elicit
+    run_template: Elicit = Elicit(
+        data=Extract(
+            model="",
+            prompts=PromptConfig(datasets=[""]),
+        )
+    )
+
     def __post_init__(self, add_pooled: bool):
         if not self.datasets:
             raise ValueError("No datasets specified")
@@ -49,12 +58,9 @@ def execute(self):
             # plus signs. This means we can pool datasets together inside of a
             # single sweep.
            datasets = [ds.strip() for ds in dataset_str.split("+")]
-            Elicit(
-                data=Extract(
-                    model=model_str,
-                    prompts=PromptConfig(
-                        datasets=datasets,
-                    ),
-                ),
-                out_dir=out_dir,
-            ).execute()
+
+            run = deepcopy(self.run_template)
+            run.data.model = model_str
+            run.data.prompts.datasets = datasets
+            run.out_dir = out_dir
+            run.execute()
diff --git a/elk/training/train.py b/elk/training/train.py
index cd9dcb7c..19ddba29 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -11,9 +11,9 @@
 from simple_parsing import Serializable, field, subgroups

 from ..extraction.extraction import Extract
-from ..metrics import accuracy, roc_auc_ci, to_one_hot
+from ..metrics import evaluate_preds, to_one_hot
 from ..run import Run
-from ..training.supervised import evaluate_supervised, train_supervised
+from ..training.supervised import train_supervised
 from ..utils import select_usable_devices
 from ..utils.typing import assert_type
 from .ccs_reporter import CcsReporter, CcsReporterConfig
@@ -146,41 +146,26 @@ def train_reporter(
         row_buf = []
         for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items():
-            val_result = reporter.score(val_gt, val_h)
-            row = pd.Series(
-                {
-                    "dataset": ds_name,
-                    "layer": layer,
-                    "pseudo_auroc": pseudo_auroc,
-                    "train_loss": train_loss,
-                    **val_result._asdict(),
-                }
-            )
+            val_result = evaluate_preds(val_gt, reporter(val_h))
+            row = {
+                "dataset": ds_name,
+                "layer": layer,
+                "pseudo_auroc": pseudo_auroc,
+                "train_loss": train_loss,
+                **val_result.to_dict(),
+            }

             if val_lm_preds is not None:
-                (_, v, k, _) = val_h.shape
-
-                val_gt_rep = repeat(val_gt, "n -> (n v)", v=v)
-                val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
-                val_lm_auroc_res = roc_auc_ci(
-                    to_one_hot(val_gt_rep, k).long().flatten(), val_lm_preds.flatten()
-                )
-                row["lm_auroc"] = val_lm_auroc_res.estimate
-                row["lm_auroc_lower"] = val_lm_auroc_res.lower
-                row["lm_auroc_upper"] = val_lm_auroc_res.upper
-                row["lm_acc"] = accuracy(val_gt_rep, val_lm_preds)
+                lm_result = evaluate_preds(val_gt, val_lm_preds)
+                row.update(lm_result.to_dict(prefix="lm_"))

             if lr_model is not None:
-                lr_auroc_res, row["lr_acc"] = evaluate_supervised(
-                    lr_model, val_h, val_gt
-                )
-                row["lr_auroc"] = lr_auroc_res.estimate
-                row["lr_auroc_lower"] = lr_auroc_res.lower
-                row["lr_auroc_upper"] = lr_auroc_res.upper
+                lr_result = evaluate_preds(val_gt, lr_model(val_h))
+                row.update(lr_result.to_dict(prefix="lr_"))

             row_buf.append(row)

-        return pd.DataFrame(row_buf)
+        return pd.DataFrame.from_records(row_buf)

     def train(self):
         """Train a reporter on each layer of the network."""
diff --git a/tests/test_roc_auc.py b/tests/test_metrics.py
similarity index 70%
rename from tests/test_roc_auc.py
rename to tests/test_metrics.py
index 244bdb88..7e23aa4c 100644
--- a/tests/test_roc_auc.py
+++ b/tests/test_metrics.py
@@ -1,13 +1,16 @@
+import math
+
 import numpy as np
 import torch
 from sklearn.datasets import make_classification
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import roc_auc_score
+from torch.distributions.normal import Normal

-from elk.metrics import roc_auc
+from elk.metrics import accuracy_ci, roc_auc


-def test_roc_auc_score():
+def test_auroc_and_acc():
     # Generate 1D binary classification dataset
     X_1d, y_true_1d = make_classification(n_samples=1000, random_state=42)
@@ -51,3 +54,24 @@
     # Assert that the results from the two implementations are almost equal
     np.testing.assert_almost_equal(roc_auc_1d_torch, roc_auc_1d_sklearn)
     np.testing.assert_almost_equal(roc_auc_2d_torch, roc_auc_2d_sklearn)
+
+    ### Test accuracy_ci function ###
+    # Compute accuracy confidence interval
+    level = 0.95
+    hard_preds = y_scores_1d_torch > 0.5
+    acc_ci = accuracy_ci(y_true_1d_torch, hard_preds, level=level)
+
+    # Point estimate of the accuracy
+    acc = hard_preds.eq(y_true_1d_torch).float().mean()
+
+    # Compute the CI quantiles
+    alpha = (1 - level) / 2
+    q = acc.new_tensor([alpha, 1 - alpha])
+
+    # Normal approximation to the binomial distribution
+    stderr = (acc * (1 - acc) / len(y_true_1d_torch)) ** 0.5
+    lower, upper = Normal(acc, stderr).icdf(q).tolist()
+
+    # Assert that the results from the two implementations are close
+    assert math.isclose(acc_ci.lower, lower, rel_tol=2e-3)
+    assert math.isclose(acc_ci.upper, upper, rel_tol=2e-3)
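
Reviewer note (not part of the patch): below is a minimal usage sketch of the new elk.metrics API introduced above, assuming the same shape conventions as train.py and evaluate.py (logits of shape (n, v, c) from reporter(val_h) or lr_model(val_h), integer labels of shape (n,)). The random tensors, the "imdb"/layer-13 row keys, and all variable names are illustrative assumptions, not code from this diff.

import torch

from elk.metrics import evaluate_preds, roc_auc_ci, to_one_hot

n, v, c = 64, 4, 2                   # examples, prompt variants, classes (assumed)
y_true = torch.randint(0, c, (n,))   # integer ground-truth labels, shape (n,)
y_logits = torch.randn(n, v, c)      # stand-in for reporter(val_h) output

# EvalResult bundles accuracy, calibrated accuracy, ECE, and AUROC, with
# bootstrap confidence intervals where applicable.
result = evaluate_preds(y_true, y_logits)

# to_dict() flattens every point estimate and CI bound into one row, with an
# optional prefix -- this is what train.py/evaluate.py collect into eval.csv.
row = {"dataset": "imdb", "layer": 13, **result.to_dict(prefix="lr_")}
print(sorted(row))  # dataset, layer, lr_acc_estimate, ..., lr_auroc_upper, lr_ece

# Cluster bootstrap: 2D inputs are resampled along the first (cluster) axis, so
# all prompt variants of one example stay together in each bootstrap replicate.
labels_2d = y_true.unsqueeze(1).expand(n, v)     # shape (n, v)
auroc = roc_auc_ci(
    to_one_hot(labels_2d, c).long().flatten(1),  # shape (n, v * c)
    y_logits.flatten(1),                         # shape (n, v * c)
)
print(auroc.estimate, auroc.lower, auroc.upper)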