diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index 472417f5..7a55a858 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from concept_erasure import LeaceFitter
 from torch import Tensor
+from typing_extensions import override

 from ..parsing import parse_loss
 from ..utils.typing import assert_type
@@ -88,6 +89,7 @@ def __init__(
         num_variants: int = 1,
     ):
         super().__init__()
+        self.config = cfg
         self.in_features = in_features
         self.num_variants = num_variants

@@ -128,6 +130,15 @@ def __init__(
             )
         )

+    @override
+    def parameters(self, recurse=True):
+        parameters = super(CcsReporter, self).parameters(recurse=recurse)
+        for param in parameters:
+            # exclude the platt scaling parameters
+            # kind of a hack for now; we should probably find a cleaner way
+            if param is not self.scale and param is not self.bias:
+                yield param
+
     def reset_parameters(self):
         """Reset the parameters of the probe.

@@ -164,15 +175,9 @@ def reset_parameters(self):
     def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
         assert self.norm is not None, "Must call fit() before forward()"
-
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
-        if self.config.norm == "leace":
-            return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
-
-        elif self.config.norm == "burns":
-            return raw_scores
-        else:
-            raise ValueError(f"Unknown normalization {self.config.norm}.")
+        platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
+        return platt_scaled_scores

     def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
         """Return the loss of the reporter on the contrast pair (x0, x1).
@@ -248,6 +253,7 @@ def fit(self, hiddens: Tensor) -> float:
             raise RuntimeError("Got NaN/infinite loss during training")

         self.load_state_dict(best_state)
+        return best_loss

     def train_loop_adam(self, x_neg: Tensor, x_pos: Tensor) -> float:
diff --git a/elk/training/train.py b/elk/training/train.py
index a7f0ef07..fb882240 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -82,13 +82,8 @@ def apply_to_layer(
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-
-            if not self.net.norm == "burns":
-                (_, v, k, _) = first_train_h.shape
-                reporter.platt_scale(
-                    to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
-                    rearrange(first_train_h, "n v k d -> (n v k) d"),
-                )
+            labels = repeat(to_one_hot(train_gt, k), "n k -> n v k", v=v)
+            reporter.platt_scale(labels, first_train_h)

         elif isinstance(self.net, EigenFitterConfig):
             fitter = EigenFitter(
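
Note on the `parameters()` override: the diff filters `self.scale` and `self.bias` out of the iterator so that the optimizer used by the CCS training loop (e.g. `train_loop_adam`) only ever sees the probe's parameters, while the Platt parameters are fit separately by `platt_scale()`. Below is a minimal, self-contained sketch of the same pattern; the `ToyReporter` class and its shapes are hypothetical stand-ins, not code from the elk codebase:

```python
import torch
import torch.nn as nn


class ToyReporter(nn.Module):
    """Toy stand-in for CcsReporter, showing the parameter-filtering trick."""

    def __init__(self, in_features: int):
        super().__init__()
        self.probe = nn.Linear(in_features, 1)
        # Platt-scaling parameters: calibrated after training, so they must be
        # invisible to the optimizer that trains the probe.
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def parameters(self, recurse: bool = True):
        # Yield everything except the Platt parameters, as in the diff above.
        for param in super().parameters(recurse=recurse):
            if param is not self.scale and param is not self.bias:
                yield param


reporter = ToyReporter(8)
opt = torch.optim.Adam(reporter.parameters())
# Only the probe's weight and bias were registered with the optimizer:
assert len(opt.param_groups[0]["params"]) == 2
```

The override only affects `parameters()`: `state_dict()` reads the registered parameters directly, so the `load_state_dict(best_state)` call in `fit()` still saves and restores `scale` and `bias` along with the probe weights.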