Skip to content

Commit

Permalink
probe only on transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamScherlis committed Jun 1, 2024
1 parent 1325975 commit a3051ce
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion w2s/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def __init__(self, config: TopoProbeConfig):
self.modified = config.modified

def fit(self, acts, labels):
self.labels = labels
self.acts = acts
self.labels = labels

def predict(self, acts):
return topolabel(self.acts, self.labels, acts, k_cc=self.k_cc, k_zeta=self.k_zeta)
Expand Down
2 changes: 1 addition & 1 deletion w2s/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def compute_metrics(eval_pred):
acts = gather_hiddens(model, ds)
torch.save(acts, acts_dir / f"{name}.pt")

if cfg.probe_relabel or cfg.probe_filter:
if transfer and (cfg.probe_relabel or cfg.probe_filter):
print("Training probe")
acts = torch.load(acts_dir / f"train.pt", map_location=model.device)
probe = PROBES[cfg["probe_name"]](cfg["probe"])
Expand Down

0 comments on commit a3051ce

Please sign in to comment.