Skip to content

Commit

Permalink
run config window4
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamScherlis committed May 24, 2024
1 parent 6250719 commit e424073
Showing 1 changed file with 30 additions and 20 deletions.
50 changes: 30 additions & 20 deletions run_eight.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import subprocess
from multiprocessing import Process
from sys import argv

# Define the datasets and respective GPU ids
configs = [
Expand All @@ -22,42 +23,51 @@
"--dataset {dataset} "
"--weak_model_name Qwen/Qwen1.5-0.5B "
"--strong_model_name meta-llama/Meta-Llama-3-8B "
"--n_epochs 2 "
"--n_epochs 3 "
"--n_train 10_000 "
"--n_val 1000 "
"--n_test 5_000 "
"--n_predict 0 "
"--eval_every 100 "
"--save_every 100 "
"--save_total_limit 1 "
"--loss logconf "
"--logconf_warmup_steps 80 "
"--balance_batch "
"--logconf_weight 0.5 "
"--loss window "
"--minibatch_size {minibatch_size} "
"--weak_lr 5e-4 "
"--strong_lr 8e-5 "
'--run_name "basic_w2s" '
'--run_name "window4" '
)


def run_command(command):
    """Run a shell command line and block until it finishes.

    Args:
        command: The full command line as a single string. It is handed to
            the shell as-is (``shell=True``), so it must already be fully
            formatted by the caller.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (because of ``check=True``).
    """
    subprocess.run(command, shell=True, check=True)


# List to hold processes
processes = []
if __name__ == "__main__":
    # Optional command-line arguments select which GPU ids to launch on;
    # with no arguments, every id in gpu_ids is used.
    if len(argv) > 1:
        # int() raises ValueError on a non-integer argument.
        included_gpu_ids = list(map(int, argv[1:]))
        # Validate with an explicit raise rather than assert: asserts are
        # stripped under `python -O`, and an unknown id would then match no
        # entry in gpu_ids and silently launch nothing.
        if not all(gpu_id in gpu_ids for gpu_id in included_gpu_ids):
            raise ValueError(f"Invalid GPU IDs: {included_gpu_ids}")
    else:
        included_gpu_ids = gpu_ids

# Loop over datasets and gpu_ids
for (dataset, minibatch_size), gpu_id in zip(configs, gpu_ids):
command = base_command.format(
gpu_id=gpu_id, dataset=dataset, minibatch_size=minibatch_size
)
print(f"Running command: {command}") # Debug print
p = Process(target=run_command, args=(command,))
p.start()
processes.append(p)
# List to hold processes
processes = []

# Wait for all processes to complete
for p in processes:
p.join()
# Pair each (dataset, minibatch_size) config with a GPU id and start one
# process per pair, skipping GPUs that were not selected
for (dataset, minibatch_size), gpu_id in zip(configs, gpu_ids):
    if gpu_id in included_gpu_ids:
        format_fields = {
            "gpu_id": gpu_id,
            "dataset": dataset,
            "minibatch_size": minibatch_size,
        }
        launch_command = base_command.format(**format_fields)
        print(f"Running command: {launch_command}")  # Debug print
        worker = Process(target=run_command, args=(launch_command,))
        worker.start()
        processes.append(worker)

# Block until every started process has finished
for worker in processes:
    worker.join()

0 comments on commit e424073

Please sign in to comment.