Skip to content

Commit

Permalink
fix forward by checking for leace
Browse files Browse the repository at this point in the history
  • Loading branch information
lauritowal committed Jul 28, 2023
1 parent 43be2d1 commit 001deb8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@ def forward(self, x: Tensor) -> Tensor:
assert self.norm is not None, "Must call fit() before forward()"

raw_scores = self.probe(self.norm(x)).squeeze(-1)
return raw_scores
breakpoint()
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
if self.config.norm == "leace":
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
else:
return raw_scores

def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
"""Return the loss of the reporter on the contrast pair (x0, x1).
Expand Down

0 comments on commit 001deb8

Please sign in to comment.