Commit c769e15: ready for runs
AdamScherlis committed Jun 11, 2024
1 parent e9db957
Showing 8 changed files with 927 additions and 51 deletions.
809 changes: 809 additions & 0 deletions notebooks/wandb.ipynb

Large diffs are not rendered by default.

20 changes: 12 additions & 8 deletions run.py
@@ -21,13 +21,17 @@ def run_train(cfg: SFTConfig):
cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict
)

train_halves = splits["train"].train_test_split(test_size=0.5, seed=seed)
splits["weak_train"] = train_halves["train"]
splits["strong_train"] = train_halves["test"]

cols = ["hard_label", "txt"]
splits = splits.select_columns(cols).rename_column("hard_label", "labels")
for split in splits:
splits[split] = splits[split].add_column("gt_labels", splits[split]["labels"])

print(
f"Example:\n\n{splits['train'][0]['txt']}\n\nLabel: {splits['train'][0]['labels']}"
f"Example:\n\n{splits['strong_train'][0]['txt']}\n\nLabel: {splits['strong_train'][0]['labels']}"
)

root = Path(cfg.results_folder) / cfg.run_name
@@ -69,12 +73,12 @@ def get_model_and_run_name(model_name, current_name):
train_args["learning_rate"] = cfg.weak_lr
weak_ds_dict = DatasetDict(
{
"train": splits["train"],
"train": splits["weak_train"],
"val": splits["val"],
"test": splits["test"],
}
)
weak_predict_dict = {"train": splits["train"], "val": splits["val"]}
weak_predict_dict = {"train": splits["strong_train"], "val": splits["val"]}
train(
weak_ds_dict,
model_cfg,
@@ -92,7 +96,7 @@ def get_model_and_run_name(model_name, current_name):
train_args["learning_rate"] = cfg.strong_lr
strong_ds_dict = DatasetDict(
{
"train": splits["train"],
"train": splits["strong_train"],
"val": splits["val"],
"test": splits["test"],
}
@@ -120,7 +124,7 @@ def get_model_and_run_name(model_name, current_name):
w2s_ds_dict = DatasetDict(
{
"train": (
splits["train"]
splits["strong_train"]
.remove_columns("labels")
.add_column("labels", weak_train_preds_ds["soft_pred"]) # type: ignore
),
@@ -134,7 +138,7 @@ def get_model_and_run_name(model_name, current_name):
)
# assert (weak_train_preds_ds["id"] == w2s_ds_dict["train"]["id"])
# assert (weak_val_preds_ds["id"] == w2s_ds_dict["val"]["id"])
w2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
w2s_predict_dict = {"train": splits["strong_train"], "val": splits["val"]}
train(
w2s_ds_dict,
model_cfg,
@@ -166,7 +170,7 @@ def get_model_and_run_name(model_name, current_name):
s2s_ds_dict = DatasetDict(
{
"train": (
splits["train"]
splits["strong_train"]
.remove_columns("labels")
.add_column("labels", prev_train_preds_ds["soft_pred"]) # type: ignore
),
@@ -180,7 +184,7 @@ def get_model_and_run_name(model_name, current_name):
)
# assert (prev_train_preds_ds["id"] == s2s_ds_dict["train"]["id"])
# assert (prev_val_preds_ds["id"] == s2s_ds_dict["val"]["id"])
s2s_predict_dict = {"train": splits["train"], "val": splits["val"]}
s2s_predict_dict = {"train": splits["strong_train"], "val": splits["val"]}
train(
s2s_ds_dict,
model_cfg,
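Taken together, the run.py changes route disjoint halves of the original train split to the two models: the weak model is fit on weak_train, its soft predictions over strong_train become the labels for the weak-to-strong (and strong-to-strong) stages, and the val/test splits are untouched. A minimal sketch of that data flow, assuming Hugging Face datasets objects; the soft_pred values below are made up for illustration:

    from datasets import Dataset

    # Stand-in for the registry's labeled train split.
    train = Dataset.from_dict({
        "txt": [f"example {i}" for i in range(8)],
        "labels": [i % 2 for i in range(8)],
    })

    # Disjoint halves, as in the diff above.
    halves = train.train_test_split(test_size=0.5, seed=0)
    weak_train, strong_train = halves["train"], halves["test"]

    # Hypothetical weak-model soft predictions on the strong half.
    soft_pred = [0.2, 0.9, 0.4, 0.7]

    # Weak-to-strong training set: strong_train relabeled with the weak predictions.
    w2s_train = strong_train.remove_columns("labels").add_column("labels", soft_pred)
    print(w2s_train[0])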
9 changes: 6 additions & 3 deletions run_eight.py
@@ -5,7 +5,7 @@
# Define the datasets and respective GPU ids
# list of tuples with dataset name and minibatch size
configs = [
("boolq", 2),
("boolq", 1),
("anli-r2", 8),
("cosmos_qa", 4),
("mc_taco", 4),
@@ -33,8 +33,9 @@ def run_command(command):

if __name__ == "__main__":
if "--gpus" in sys.argv:
included_gpu_ids = [int(gpu_id) for gpu_id in argv[argv.index("--gpus") + 1].split(",")]
other_argv = sys.argv[1:argv.index("--gpus")] + sys.argv[argv.index("--gpus") + 2:]
i = sys.argv.index("--gpus")
included_gpu_ids = [int(gpu_id) for gpu_id in sys.argv[i + 1].split(",")]
other_argv = sys.argv[1:i] + sys.argv[i + 2:]

else:
included_gpu_ids = gpu_ids
@@ -55,6 +56,8 @@ def run_command(command):
minibatch_size=minibatch_size,
argv=argv,
)
squished = command.replace(" ", "_").replace("/", "_").replace("=", "_")
command += f" | tee logs/{squished}.log"
print(f"Running command: {command}") # Debug print
p = Process(target=run_command, args=(command,))
p.start()
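The run_eight.py edits are independent of each other: the --gpus flag is now parsed through a single saved index into sys.argv (the old lines referenced a bare argv that was never defined), and each launched command is piped through tee into a log file named after the command with spaces, slashes, and equals signs squashed to underscores. A small self-contained sketch of both pieces; the sample command string and helper name are illustrative:

    def split_gpu_args(args: list[str], default_gpus: list[int]):
        # Pull an optional "--gpus 0,2,4" pair out of the argument list, keeping the rest.
        if "--gpus" in args:
            i = args.index("--gpus")
            gpus = [int(g) for g in args[i + 1].split(",")]
            rest = args[:i] + args[i + 2:]
            return gpus, rest
        return default_gpus, args

    # Log-name squashing, as in the diff above.
    command = "python run.py --dataset boolq --minibatch_size 1"
    squished = command.replace(" ", "_").replace("/", "_").replace("=", "_")
    print(f"{command} | tee logs/{squished}.log")

    print(split_gpu_args(["--loss", "xent", "--gpus", "0,2"], default_gpus=list(range(8))))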
76 changes: 76 additions & 0 deletions run_monday.py
@@ -0,0 +1,76 @@
from w2s.gpu_pool import gpu_map
from w2s.ds_registry import VALID_DATASETS
from datetime import datetime
import sys

def add_log(command, ds, task):
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
rname = task.split("--run_name ")[1]
log_file = f"logs/{timestamp}_{ds}_{rname}"
if len(log_file) > 200:
log_file = log_file[:200]
command += f" > {log_file}.log 2>&1"

return command

date = "0610"

tasks = [
f"--loss xent --run_name repro_{date}_xent",
f"--s2s_iters 2 --run_name repro_{date}_strong2strong",
f"--probe_relabel --probe knn --run_name repro_{date}_probe_knn",
f"--probe_relabel --probe logreg --run_name repro_{date}_probe_logreg",
f"--probe_filter --probe knn --run_name repro_{date}_filter_knn",
f"--probe_filter --probe logreg --run_name repro_{date}_filter_logreg",
f"--probe_filter --probe topo --run_name repro_{date}_filter_topo",
f"--loss window --radius midweak --run_name repro_{date}_window_mid",
f"--loss entropy --run_name repro_{date}_entropy",
]

jobs = []

configs = [
# ("boolq", 1),
# ("anli-r2", 8),
# ("cosmos_qa", 4),
# ("mc_taco", 4),
# ("sciq", 4),
# ("paws", 16),
# ("twitter-sentiment", 8),
# ("wic", 8),
]

for ds in VALID_DATASETS:
if ds not in [c[0] for c in configs]:
configs.append((ds, 1))


for task in tasks:
for ds, minibatch in configs:
jobs.append(
add_log(
f"python run.py --dataset {ds} --minibatch_size {minibatch} {task}",
ds,
task
)
)

halfpoint = len(jobs) // 2

# CHANGE ME on shared-ord
jobs = jobs[:halfpoint]

if __name__ == "__main__":
# usage: python run_monday.py 1,2,3,4,5,6,7
if len(sys.argv) == 1:
# default to all GPUs
gpus = range(8)
else:
gpus = [int(gpu) for gpu in sys.argv[1].split(",")]

for job in jobs:
print(job)
print()
print(f"Running on GPUs: {gpus}")

gpu_map(gpus, jobs)
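For reference, each entry of jobs is a complete shell command with stdout and stderr redirected to a timestamped log file; one of the generated jobs for, say, boolq under the xent task would look roughly like the line below (the timestamp is illustrative, and the exact dataset order depends on VALID_DATASETS):

    python run.py --dataset boolq --minibatch_size 1 --loss xent --run_name repro_0610_xent > logs/2024-06-10_12-00-00_boolq_repro_0610_xent.log 2>&1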
13 changes: 11 additions & 2 deletions w2s/gpu_pool.py
@@ -1,12 +1,21 @@
from multiprocessing import Manager, Process
import subprocess
from datetime import datetime

# Function that runs the job on a GPU
def run_on_gpu(gpu: int, job: str):
print(f"Starting on GPU {gpu}: {job}")
print("at time:", datetime.now())
command = f"CUDA_VISIBLE_DEVICES={gpu} {job}"
subprocess.run(command, shell=True, check=True)
print(f"Finished on GPU {gpu}: {job}")
try:
subprocess.run(command, shell=True, check=True)
except Exception as e:
print(f"[WARN] Error on GPU {gpu}: {job}")
print(e)
else:
print(f"Finished on GPU {gpu}: {job}")
finally:
print("at time:", datetime.now())

# Worker function that gets jobs and runs them on a specific GPU
def worker(gpu, jobs, lock):
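The worker/gpu_map side of gpu_pool.py is only partially visible in this diff. As a rough mental model (a sketch built on the assumption that gpu_map starts one worker process per GPU, each popping jobs off a shared list under a lock; it is not the repository's exact implementation):

    from multiprocessing import Manager, Process
    import subprocess

    def run_on_gpu(gpu: int, job: str) -> None:
        # Pin the job to one GPU; a failed job is logged instead of killing the worker.
        try:
            subprocess.run(f"CUDA_VISIBLE_DEVICES={gpu} {job}", shell=True, check=True)
        except subprocess.CalledProcessError as err:
            print(f"[WARN] job failed on GPU {gpu}: {err}")

    def worker(gpu: int, jobs, lock) -> None:
        while True:
            with lock:
                if len(jobs) == 0:
                    return
                job = jobs.pop(0)
            run_on_gpu(gpu, job)

    def gpu_map(gpus, job_list) -> None:
        manager = Manager()
        jobs = manager.list(job_list)   # shared, mutable job queue
        lock = manager.Lock()
        procs = [Process(target=worker, args=(gpu, jobs, lock)) for gpu in gpus]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

    if __name__ == "__main__":
        gpu_map([0, 1], ["echo job-a", "echo job-b", "echo job-c"])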
43 changes: 6 additions & 37 deletions w2s/loss.py
@@ -102,6 +102,8 @@ def log_confidence_loss(
aux_coef: float = 0.5,
balance_batch: bool = False,
harden: bool = True,
buffer: list = None,
buffer_size: int = 32,
):
"""
This is similar to the loss in Burns et al., except that it also optionally
@@ -118,9 +120,11 @@

coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
preds = torch.softmax(logits, dim=-1)
buffer += list(preds[:, 0].detach())
del buffer[:-buffer_size]

if harden:
threshold = torch.quantile(preds[:, 0], prior)
threshold = torch.quantile(torch.stack(buffer), prior)
target_preds = torch.cat(
[(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
dim=1,
@@ -130,39 +134,4 @@

labels_binary = torch.stack([1.0 - labels, labels], dim=1)
target = labels_binary * (1 - coef) + target_preds.detach() * coef
return torch.nn.functional.cross_entropy(logits, target)


def log_confidence_loss2(
logits,
labels,
step: int,
warmup_steps: int = 200,
aux_coef: float = 0.5,
balance_batch: bool = False,
):
"""
This one uses a batch-independent threshold of 0.5, and then finally optionally balances
the batch by mean-subtracting the log-odds of the target.
"""
logits = logits.float()
labels = labels.float()

coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
preds = torch.softmax(logits, dim=-1)

threshold = 0.5
strong_preds = torch.cat(
[(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
dim=1,
)
labels_binary = torch.stack([1.0 - labels, labels], dim=1)
target = labels_binary * (1 - coef) + strong_preds.detach() * coef

if balance_batch:
logodds_target = torch.log(target) - torch.log1p(-target)
target = torch.sigmoid(
logodds_target - logodds_target.mean(dim=0, keepdim=True)
)

return torch.nn.functional.cross_entropy(logits, target)
return torch.nn.functional.cross_entropy(logits, target)
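The new buffer argument turns the hardening threshold from a per-batch quantile into one taken over a rolling window of the most recent predictions, so it stays meaningful even when the per-step batch is tiny (boolq's minibatch size is dropped to 1 in this same commit). A stripped-down sketch of just the thresholding step, with illustrative sizes:

    import torch

    buffer: list = []     # persists across calls; owned by the trainer
    buffer_size = 32
    prior = 0.5           # target fraction assigned to class 0

    def harden(logits: torch.Tensor) -> torch.Tensor:
        preds = torch.softmax(logits.float(), dim=-1)
        # Keep only the newest `buffer_size` class-0 probabilities; truncating in place
        # keeps the shared list bounded for every caller.
        buffer.extend(preds[:, 0].detach())
        del buffer[:-buffer_size]
        threshold = torch.quantile(torch.stack(buffer), prior)
        hard = torch.stack([preds[:, 0] >= threshold, preds[:, 0] < threshold], dim=1)
        return hard.float()

    # Even with a batch of a single example, the threshold reflects up to 32 recent predictions.
    for _ in range(5):
        print(harden(torch.randn(1, 2)))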
2 changes: 1 addition & 1 deletion w2s/probe.py
@@ -25,7 +25,7 @@ class LogisticProbeConfig(ProbeConfig):

@dataclass
class TopoProbeConfig(ProbeConfig):
k_cc: int = 100
k_cc: int = 50
k_zeta: int = 50
modified: bool = False

6 changes: 6 additions & 0 deletions w2s/sft.py
@@ -46,6 +46,10 @@ def __init__(
self.loss_name = loss_name
self.loss_cfg = loss_cfg
self.transfer = transfer
if loss_name == "logconf":
self.buffer = []
# self.buffer_size = kwargs["buffer_size"]


def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels").float()
@@ -61,6 +65,8 @@ def compute_loss(self, model, inputs, return_outputs=False):
warmup_steps=self.loss_cfg.logconf_warmup_steps,
balance_batch=self.loss_cfg.balance_batch,
harden=True,
buffer=self.buffer,
# buffer_size=self.buffer_size,
)
elif self.loss_name == 'entropy':
loss = log_confidence_loss(
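Because the buffer is a plain Python list created once in __init__, the loss mutates the very same object on every training step, so the rolling window carries over between batches with no extra plumbing. A minimal illustration of that pattern (the class and helper names here are illustrative, not the repository's API):

    def update_buffer(preds: list, buffer: list, buffer_size: int) -> None:
        # Mutate the caller's list in place; reassigning `buffer` would silently drop the link.
        buffer.extend(preds)
        del buffer[:-buffer_size]

    class TinyTrainer:
        def __init__(self) -> None:
            self.buffer: list = []       # shared state, like the logconf buffer above

        def compute_loss(self, preds: list) -> None:
            update_buffer(preds, buffer=self.buffer, buffer_size=4)

    t = TinyTrainer()
    for step in range(4):
        t.compute_loss([0.1 * step, 0.2 * step])
    print(t.buffer)                      # only the last four values remain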
