forked from EleutherAI/elk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
reporter.py
137 lines (109 loc) · 4.21 KB
/
reporter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""An ELK reporter network."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, NamedTuple, Optional
import torch
import torch.nn as nn
from einops import rearrange, repeat
from simple_parsing.helpers import Serializable
from torch import Tensor
from ..calibration import CalibrationError
from ..metrics import accuracy, roc_auc_ci, to_one_hot
class EvalResult(NamedTuple):
    """The result of evaluating a reporter on a dataset.

    The `.score()` function of a reporter returns an instance of this class,
    which contains the AUROC (with the lower and upper bounds of its
    confidence interval), accuracy, calibrated accuracy, and expected
    calibration error (ECE).
    """

    # Point estimate of the area under the ROC curve.
    auroc: float
    # Lower bound of the confidence interval around `auroc`.
    auroc_lower: float
    # Upper bound of the confidence interval around `auroc`.
    auroc_upper: float
    # Raw top-1 accuracy.
    acc: float
    # Accuracy after thresholding so that the predicted positive rate matches
    # the ground-truth positive rate.
    cal_acc: float
    # Expected calibration error.
    ece: float
@dataclass
class ReporterConfig(Serializable):
    """Base configuration shared by reporters.

    Args:
        seed: The random seed to use. Defaults to 42.
    """

    seed: int = 42
@dataclass
class OptimConfig(Serializable):
    """Optimization hyperparameters used when training a reporter.

    Args:
        lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`.
            Defaults to 1e-2.
        num_epochs: The number of epochs to train for. Defaults to 1000.
        num_tries: The number of times to try training the reporter. Defaults to 10.
        optimizer: The optimizer to use, either `"adam"` or `"lbfgs"`.
            Defaults to "lbfgs".
        weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01.
    """

    lr: float = 1e-2
    num_epochs: int = 1000
    num_tries: int = 10
    optimizer: Literal["adam", "lbfgs"] = "lbfgs"
    weight_decay: float = 0.01
class Reporter(nn.Module, ABC):
"""An ELK reporter network."""
def reset_parameters(self):
"""Reset the parameters of the probe."""
# TODO: These methods will do something fancier in the future
@classmethod
def load(cls, path: Path | str):
"""Load a reporter from a file."""
return torch.load(path)
def save(self, path: Path | str):
# TODO: Save separate JSON and PT files for the reporter.
torch.save(self, path)
@abstractmethod
def fit(
self,
hiddens: Tensor,
labels: Optional[Tensor] = None,
) -> float:
...
@torch.no_grad()
def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
"""Score the probe on the contrast set `hiddens`.
Args:
labels: The labels of the contrast pair.
hiddens: Contrast set of shape [n, v, k, d].
Returns:
an instance of EvalResult containing the loss, accuracy, calibrated
accuracy, and AUROC of the probe on `contrast_set`.
Accuracy: top-1 accuracy averaged over questions and variants.
Calibrated accuracy: top-1 accuracy averaged over questions and
variants, calibrated so that x% of the predictions are `True`,
where x is the proprtion of examples with ground truth label `True`.
AUROC: averaged over the n * v * c binary questions
ECE: Expected Calibration Error
"""
logits = self(hiddens)
(_, v, c) = logits.shape
# makes `num_variants` copies of each label
logits = rearrange(logits, "n v c -> (n v) c")
Y = repeat(labels, "n -> (n v)", v=v).float()
if c == 2:
pos_probs = logits[..., 1].flatten().sigmoid()
cal_err = CalibrationError().update(Y.cpu(), pos_probs.cpu()).compute().ece
# Calibrated accuracy
cal_thresh = pos_probs.float().quantile(labels.float().mean())
cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
cal_acc = cal_preds.flatten().eq(Y).float().mean().item()
else:
# TODO: Implement calibration error for k > 2?
cal_acc = 0.0
cal_err = 0.0
Y_one_hot = to_one_hot(Y, c).long().flatten()
auroc_result = roc_auc_ci(Y_one_hot, logits.flatten())
raw_preds = logits.argmax(dim=-1).long()
raw_acc = accuracy(Y, raw_preds.flatten())
return EvalResult(
auroc=auroc_result.estimate,
auroc_lower=auroc_result.lower,
auroc_upper=auroc_result.upper,
acc=float(raw_acc),
cal_acc=cal_acc,
ece=cal_err,
)