Skip to content

Commit

Permalink
cfg dict bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamScherlis committed Jun 4, 2024
1 parent 059b316 commit 6adbe82
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions w2s/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def compute_metrics(eval_pred):
predictions, labels = map(torch.from_numpy, eval_pred)
return compute_metrics_torch(predictions, labels)

probe_required = transfer and (cfg.probe_relabel or cfg.probe_filter)
probe_required = transfer and (cfg['probe_relabel'] or cfg['probe_filter'])

if save_activations or probe_required:
if acts_dir.exists():
Expand All @@ -156,19 +156,19 @@ def compute_metrics(eval_pred):
with open(save_dir / f"{name}_probe_metrics.json", "w") as f:
json.dump({"agree": agree_metrics, "gt": gt_metrics}, f, indent=2)
if name in ["train", "val"]:
if cfg.probe_filter:
good_indices = probe.filter(acts, torch.tensor(ds["labels"]), cfg.contamination)
if cfg['probe_filter']:
good_indices = probe.filter(acts, torch.tensor(ds["labels"]), cfg['contamination'])
sizes = {
"before": len(ds),
"after": len(good_indices),
"removed": len(ds) - len(good_indices),
"contamination": int(cfg.contamination * len(ds)),
"contamination": int(cfg['contamination'] * len(ds)),
}
with open(save_dir / f"{name}_filter_sizes.json", "w") as f:
json.dump(sizes, f, indent=2)
ds = ds.select(good_indices)
ds_dict[name] = ds
if cfg.probe_relabel:
if cfg['probe_relabel']:
ds = ds.remove_columns("labels").add_column("labels", preds.numpy())
ds_dict[name] = ds

Expand Down

0 comments on commit 6adbe82

Please sign in to comment.