Skip to content

Commit

Permalink
remove code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
lauritowal committed Jul 28, 2023
1 parent c397fc2 commit 02c2a9c
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,25 +100,16 @@ def __init__(

self.norm = None

if self.config.norm == "burns":
self.probe = nn.Sequential(
nn.Linear(
in_features,
1 if cfg.num_layers < 2 else hidden_size,
bias=cfg.bias,
device=device,
),
nn.Sigmoid(),
)
else:
self.probe = nn.Sequential(
nn.Linear(
in_features,
1 if cfg.num_layers < 2 else hidden_size,
bias=cfg.bias,
device=device,
)
)
self.probe = nn.Sequential(
nn.Linear(
in_features,
1 if cfg.num_layers < 2 else hidden_size,
bias=cfg.bias,
device=device,
),
*(nn.Sigmoid() if self.config.norm == "burns" else ()),
)

if cfg.pre_ln:
self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False))

Expand Down

0 comments on commit 02c2a9c

Please sign in to comment.