final changes for repro runs
AdamScherlis committed Jun 4, 2024
1 parent 57d1c6e commit 059b316
Showing 7 changed files with 35 additions and 138 deletions.
2 changes: 2 additions & 0 deletions run.py
@@ -140,6 +140,7 @@ def get_model_and_run_name(model_name, current_name):
     transfer=True,
     predict_dict=w2s_predict_dict,
     save_activations=True,
+    acts_dir=shared_root / cfg_name / "w2s" / "activations",
 )

 prev = "w2s"
@@ -183,6 +184,7 @@ def get_model_and_run_name(model_name, current_name):
     cfg.to_dict(),
     transfer=True,
     predict_dict=s2s_predict_dict,
+    acts_dir=root / cfg_name / f"s2s-{s2s_iter}" / "activations",
 )

 prev = f"s2s-{s2s_iter}"
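Both train calls now pass acts_dir explicitly instead of letting train() derive it from its output directory: the w2s stage caches activations under shared_root, presumably so repeated runs can reuse them, while each s2s iteration caches under its own run root. A minimal sketch of the resulting layout (the roots and config name below are illustrative, not from the repo):

    from pathlib import Path

    shared_root = Path("results/shared")   # illustrative shared cache root
    root = Path("results/this_run")        # illustrative per-run root
    cfg_name = "example_cfg"               # hypothetical config name

    w2s_acts = shared_root / cfg_name / "w2s" / "activations"   # reusable across runs
    s2s_acts = root / cfg_name / "s2s-0" / "activations"        # per-iteration cache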
105 changes: 0 additions & 105 deletions run_batch.py

This file was deleted.

36 changes: 13 additions & 23 deletions run_eight.py
@@ -1,6 +1,6 @@
 import subprocess
 from multiprocessing import Process
-from sys import argv
+import sys

 # Define the datasets and respective GPU ids
 # list of tuples with dataset name and minibatch size
@@ -22,22 +22,8 @@
"CUDA_VISIBLE_DEVICES={gpu_id} "
"python run.py "
"--dataset {dataset} "
"--weak_model_name Qwen/Qwen1.5-0.5B "
"--strong_model_name meta-llama/Meta-Llama-3-8B "
"--n_epochs 3 "
"--n_train 10_000 "
"--n_val 1000 "
"--n_test 5_000 "
"--n_predict 0 "
"--eval_every 100 "
"--save_every 100 "
"--save_total_limit 1 "
"--loss window "
"--radius .3 "
"--minibatch_size {minibatch_size} "
"--weak_lr 5e-4 "
"--strong_lr 8e-5 "
'--run_name "bigwindow" '
"{argv} "
)


@@ -46,14 +32,15 @@ def run_command(command):


 if __name__ == "__main__":
-    # get GPU ID arguments
-    if len(argv) > 1:
-        included_gpu_ids = list(map(int, argv[1:]))
-        assert all(
-            gpu_id in gpu_ids for gpu_id in included_gpu_ids
-        ), f"Invalid GPU IDs: {included_gpu_ids}"
+    if "--gpus" in sys.argv:
+        included_gpu_ids = [int(gpu_id) for gpu_id in sys.argv[sys.argv.index("--gpus") + 1].split(",")]
+        other_argv = sys.argv[1:sys.argv.index("--gpus")] + sys.argv[sys.argv.index("--gpus") + 2:]
     else:
         included_gpu_ids = gpu_ids
+        other_argv = sys.argv[1:]

+    argv = " ".join(other_argv)
+
     # List to hold processes
     processes = []
@@ -63,7 +50,10 @@ def run_command(command):
         if gpu_id not in included_gpu_ids:
             continue
         command = base_command.format(
-            gpu_id=gpu_id, dataset=dataset, minibatch_size=minibatch_size
+            gpu_id=gpu_id,
+            dataset=dataset,
+            minibatch_size=minibatch_size,
+            argv=argv,
         )
         print(f"Running command: {command}")  # Debug print
         p = Process(target=run_command, args=(command,))
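The hard-coded run.py flags are gone: run_eight.py now consumes only its own --gpus option and forwards everything else verbatim to each per-GPU run.py command via {argv}. A small runnable sketch of the new argument handling (the simulated command line is hypothetical):

    import sys

    # simulate: python run_eight.py --gpus 0,2 --run_name "bigwindow"
    sys.argv = ["run_eight.py", "--gpus", "0,2", "--run_name", "bigwindow"]

    i = sys.argv.index("--gpus")
    included_gpu_ids = [int(g) for g in sys.argv[i + 1].split(",")]  # [0, 2]
    other_argv = sys.argv[1:i] + sys.argv[i + 2:]                    # ['--run_name', 'bigwindow']
    argv = " ".join(other_argv)   # appended to each "python run.py ..." command
    print(included_gpu_ids, argv) # [0, 2] --run_name bigwindow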
2 changes: 1 addition & 1 deletion w2s/loss.py
@@ -15,7 +15,7 @@ def to_dict(self):
 @dataclass
 class LogConfidenceLossConfig(LossConfig):
     logconf_weight: float = 0.5
-    logconf_warmup_steps: int = 200
+    logconf_warmup_steps: int = 100
     balance_batch: bool = False

 @dataclass
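The default warmup for the log-confidence term drops from 200 to 100 steps. As a hedged sketch of what such a warmup typically does (the actual schedule is implemented elsewhere in w2s/loss.py and may differ, e.g. it could be stepwise rather than linear):

    def logconf_weight_at(step: int, logconf_weight: float = 0.5, warmup_steps: int = 100) -> float:
        # ramp the auxiliary loss weight linearly from 0 to its full value
        return logconf_weight * min(1.0, step / max(warmup_steps, 1))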
2 changes: 1 addition & 1 deletion w2s/probe.py
@@ -21,7 +21,7 @@ class KnnProbeConfig(ProbeConfig):

 @dataclass
 class LogisticProbeConfig(ProbeConfig):
-    l2p: float = 0.01
+    l2p: float = 1e-3
 
 @dataclass
 class TopoProbeConfig(ProbeConfig):
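The logistic probe's default l2p shrinks from 0.01 to 1e-3, i.e. 10x weaker regularization. Assuming l2p is an L2 penalty coefficient on the probe weights (suggested by the name; the probe training code is elsewhere in the repo), the sketch below shows where it would enter the loss:

    import torch
    import torch.nn.functional as F

    def logistic_probe_loss(w, x, y, l2p=1e-3):
        # x: (n, d) activations, y: (n,) float labels in {0, 1}, w: (d,) probe weights
        logits = x @ w
        return F.binary_cross_entropy_with_logits(logits, y) + l2p * w.square().sum()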
15 changes: 9 additions & 6 deletions w2s/sft.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Union
+from typing import Union, Optional

 import torch
 from datasets import DatasetDict
@@ -89,6 +89,7 @@ def train(
     predict_dict: Union[DatasetDict, dict, None] = None,
     save_activations: bool = False,
     use_probe: bool = False,
+    acts_dir: Optional[Path] = None,
 ):
     """
     ds_dict: DatasetDict with splits for train, val, test, and (optionally) predict,
@@ -107,7 +108,6 @@
"""
save_dir = Path(train_args.output_dir)
results_path = save_dir / "results.json"
acts_dir = save_dir / "activations"

clear_mem()
print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use before training")
@@ -131,7 +131,9 @@ def compute_metrics(eval_pred):
         predictions, labels = map(torch.from_numpy, eval_pred)
         return compute_metrics_torch(predictions, labels)

-    if save_activations:
+    probe_required = transfer and (cfg.probe_relabel or cfg.probe_filter)
+
+    if save_activations or probe_required:
         if acts_dir.exists():
             print("Activations already exist at", acts_dir)
         else:
@@ -141,7 +143,7 @@ def compute_metrics(eval_pred):
             acts = gather_hiddens(model, ds)
             torch.save(acts, acts_dir / f"{name}.pt")

-    if transfer and (cfg.probe_relabel or cfg.probe_filter):
+    if probe_required:
         print("Training probe")
         acts = torch.load(acts_dir / f"train.pt", map_location=model.device)
         probe = PROBES[cfg["probe_name"]](cfg["probe"])
@@ -151,7 +153,7 @@ def compute_metrics(eval_pred):
         preds = probe.predict(acts)
         agree_metrics = compute_metrics_torch(preds, torch.tensor(ds["labels"]))
         gt_metrics = compute_metrics_torch(preds, torch.tensor(ds["gt_labels"]))
-        with open(acts_dir / f"{name}_probe_metrics.json", "w") as f:
+        with open(save_dir / f"{name}_probe_metrics.json", "w") as f:
             json.dump({"agree": agree_metrics, "gt": gt_metrics}, f, indent=2)
         if name in ["train", "val"]:
             if cfg.probe_filter:
@@ -162,7 +164,7 @@ def compute_metrics(eval_pred):
"removed": len(ds) - len(good_indices),
"contamination": int(cfg.contamination * len(ds)),
}
with open(acts_dir / f"{name}_filter_sizes.json", "w") as f:
with open(save_dir / f"{name}_filter_sizes.json", "w") as f:
json.dump(sizes, f, indent=2)
ds = ds.select(good_indices)
ds_dict[name] = ds
@@ -211,6 +213,7 @@ def compute_metrics(eval_pred):
cfg["train_args"] = train_args.to_dict()
cfg["transfer"] = transfer
cfg["loss"] = cfg["loss"].to_dict()
cfg["probe"] = cfg["probe"].to_dict()
json.dump(cfg, f, indent=2)
wandb.config.update(cfg)

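Two behavioral changes here: activations are gathered whenever a probe will need them, not only when save_activations is set, and the probe-metrics and filter-size JSONs move from acts_dir to save_dir, leaving acts_dir as a pure activation cache that can live outside the run directory. A condensed sketch of the resulting control flow (simplified from the diff; the helper names are hypothetical stand-ins, not repo functions):

    probe_required = transfer and (cfg.probe_relabel or cfg.probe_filter)

    if save_activations or probe_required:
        if not acts_dir.exists():                     # reuse an existing cache
            gather_and_save_activations(acts_dir)     # hypothetical stand-in for the loop above

    if probe_required:
        train_and_apply_probe(acts_dir)               # reads activations from the cache
        # probe metrics now land in save_dir, next to results.json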
11 changes: 9 additions & 2 deletions w2s/sft_utils.py
@@ -17,9 +17,16 @@
 # literal = lambda *args: StrEnum("option", args)

 # Python 3.10 version:
-def literal(s: str):
-    return type(f'LiteralString_{s}', (LiteralString,), {"value": s})
+def ident_escape_char(c: str) -> str:
+    if c.isalnum() or c == "_":
+        return c
+    return f"_{ord(c)}_"
+
+def ident_escape(s: str) -> str:
+    return "".join(ident_escape_char(c) for c in s)
+
+def literal(s: str):
+    return type('LiteralString_' + ident_escape(s), (LiteralString,), {"value": s})

 class LiteralString():
     value = ""
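ident_escape exists because s can contain characters that are illegal in a Python class name; each such character is replaced by its decimal code point wrapped in underscores, so type() always receives a valid identifier. A worked example using the functions above:

    ident_escape("foo_bar")    # -> "foo_bar"                  (alphanumerics and _ pass through)
    ident_escape("foo-bar")    # -> "foo_45_bar"                (ord("-") == 45)
    literal("v1.2").__name__   # -> "LiteralString_v1_46_2"     (ord(".") == 46)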
