remove sigmoid from ccs reporter class + cleanup
lauritowal committed Aug 3, 2023
1 parent 99646e9 commit aeb8b09
Showing 2 changed files with 8 additions and 11 deletions.
17 changes: 7 additions & 10 deletions elk/training/ccs_reporter.py
@@ -99,20 +99,14 @@ def __init__(
         hidden_size = cfg.hidden_size or 4 * in_features // 3

         self.norm = None
-
-        probe_layers: list[nn.Module] = [
+        self.probe: nn.Sequential(
             nn.Linear(
                 in_features,
                 1 if cfg.num_layers < 2 else hidden_size,
                 bias=cfg.bias,
                 device=device,
-            )
-        ]
-
-        if self.config.norm == "burns":
-            probe_layers.append(nn.Sigmoid())
-
-        self.probe = nn.Sequential(*probe_layers)
+            ),
+        )

         if cfg.pre_ln:
             self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False))

Check failure on line 102 in elk/training/ccs_reporter.py (GitHub Actions / run-tests (3.11, macos-latest)): Illegal type annotation: call expression not allowed (reportGeneralTypeIssues)
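Net effect of this hunk: the probe becomes a plain nn.Sequential wrapping a single linear layer, and the nn.Sigmoid() that used to be appended for the "burns" norm is gone, so the probe always emits raw scores. (Per the CI annotation above, the committed self.probe: nn.Sequential(...) uses a colon where an = assignment was evidently intended, which pyright rejects as a call expression in a type annotation.) A minimal sketch of the resulting behavior, with an assumed feature size and none of the repo's config objects:

    import torch
    import torch.nn as nn

    in_features = 512  # assumed; the repo derives this from the model's hidden size

    # After this commit the probe is just a linear map to a single score...
    probe = nn.Sequential(nn.Linear(in_features, 1, bias=True))

    x = torch.randn(8, in_features)
    raw_scores = probe(x).squeeze(-1)  # unbounded logits, shape (8,)

    # ...and any caller that needs probabilities applies the sigmoid itself,
    # instead of having it baked into the module.
    probs = torch.sigmoid(raw_scores)  # values in (0, 1)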
@@ -174,8 +168,11 @@ def forward(self, x: Tensor) -> Tensor:
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
         if self.config.norm == "leace":
             return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
-        else:
+
+        elif self.config.norm == "burns":
             return raw_scores
+        else:
+            raise ValueError(f"Unknown normalization {self.config.norm}.")

     def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
         """Return the loss of the reporter on the contrast pair (x0, x1).
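For reference, the "leace" branch's raw_scores.mul(self.scale).add(self.bias) is a learned affine (Platt-style) calibration of the raw scores. A toy illustration, with hypothetical fitted values standing in for the reporter's scale and bias parameters:

    import torch

    raw_scores = torch.tensor([-1.2, 0.3, 2.5])
    scale = torch.tensor(0.8)  # hypothetical fitted value
    bias = torch.tensor(-0.1)  # hypothetical fitted value

    calibrated = raw_scores.mul(scale).add(bias)
    print(calibrated)  # tensor([-1.0600,  0.1400,  1.9000])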
2 changes: 1 addition & 1 deletion elk/training/train.py
@@ -83,7 +83,7 @@ def apply_to_layer(
         reporter = CcsReporter(self.net, d, device=device, num_variants=v)
         train_loss = reporter.fit(first_train_h)

-        if self.net.norm == "leace":
+        if not self.net.norm == "burns":
             (_, v, k, _) = first_train_h.shape
             reporter.platt_scale(
                 to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
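The flipped condition (not self.net.norm == "burns", i.e. self.net.norm != "burns") means Platt scaling now runs for every normalization except "burns", rather than only for "leace". The labels passed to platt_scale are the ground truth repeated across the v variants, one-hot encoded over k classes, and flattened. A sketch of that reshaping; einops.repeat is a real dependency here, but to_one_hot below is a stand-in reimplementation of the repo's helper, assumed for illustration:

    import torch
    from einops import repeat

    def to_one_hot(labels: torch.Tensor, k: int) -> torch.Tensor:
        # Stand-in for the repo's helper: one-hot encode integer labels.
        return torch.nn.functional.one_hot(labels.long(), k).float()

    train_gt = torch.tensor([0, 1, 1])  # n = 3 ground-truth labels
    v, k = 2, 2                         # variants per example, number of classes

    labels = to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten()
    print(labels.shape)  # torch.Size([12]), i.e. n * v * k entries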
