Commit

misc improvements and fixes
AdamScherlis committed Jun 11, 2024
1 parent 5334013 commit 1a81f85
Showing 9 changed files with 298 additions and 30 deletions.
5 changes: 5 additions & 0 deletions run.py
@@ -1,4 +1,5 @@
from pathlib import Path
import os

import torch
from datasets import DatasetDict, load_from_disk
@@ -15,6 +16,7 @@


def run_train(cfg: SFTConfig):
print(f"Loading and processing dataset {cfg.dataset}")
splits = load_and_process_dataset(
cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict
)
@@ -105,6 +107,7 @@ def get_model_and_run_name(model_name, current_name):

# load weak predictions
weak_preds_root = shared_root / cfg_name / "weak" / "predictions"
print(f"Loading weak predictions from {weak_preds_root}")
weak_train_preds_ds = load_from_disk(str(weak_preds_root / "train"))
weak_val_preds_ds = load_from_disk(str(weak_preds_root / "val"))

@@ -150,6 +153,7 @@ def get_model_and_run_name(model_name, current_name):

# load prev predictions
prev_preds_root = root / cfg_name / prev / "predictions"
print(f"Loading {prev} predictions from {prev_preds_root}")
prev_train_preds_ds = load_from_disk(str(prev_preds_root / "train"))
prev_val_preds_ds = load_from_disk(str(prev_preds_root / "val"))

@@ -190,4 +194,5 @@ def get_model_and_run_name(model_name, current_name):
prev = f"s2s-{s2s_iter}"

if __name__ == "__main__":
os.environ["HF_DATASETS_CACHE"] = "/home/adam/hf_cache"
run_train(parse(SFTConfig))
112 changes: 112 additions & 0 deletions run_random.py
@@ -0,0 +1,112 @@
from pathlib import Path
import os

import torch
from datasets import DatasetDict, load_from_disk
from simple_parsing import parse
from transformers import (
TrainingArguments,
)
import numpy as np

from w2s.ds_registry import load_and_process_dataset
from w2s.model import ModelConfig
from w2s.sft import train
from w2s.sft_config import SFTConfig
from w2s.utils import get_config_foldername


def run_train(cfg: SFTConfig):
print(f"Loading and processing dataset {cfg.dataset}")
splits = load_and_process_dataset(
cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict
)

cols = ["txt"]
splits = splits.select_columns(cols)
for split in splits:
ps = np.random.random(len(splits[split]))
splits[split] = splits[split].add_column("labels", ps.astype(np.float32))

print(
f"Example:\n\n{splits['train'][0]['txt']}\n\nLabel: {splits['train'][0]['labels']}"
)

root = Path(cfg.results_folder) / cfg.run_name
shared_root = root
cfg_name = cfg.dataset
train_args: dict = dict(
num_train_epochs=cfg.n_epochs,
adam_beta2=0.95,
gradient_accumulation_steps=cfg.batch_size // cfg.minibatch_size,
eval_strategy="steps",
label_names=["labels"],
load_best_model_at_end=cfg.load_best_model_at_end,
logging_steps=25,
metric_for_best_model=cfg.metric_for_best_model,
greater_is_better=cfg.greater_is_better,
per_device_train_batch_size=cfg.minibatch_size,
per_device_eval_batch_size=cfg.minibatch_size,
save_strategy="steps",
save_total_limit=cfg.save_total_limit,
tf32=True, # Use Tensor Cores even for fp32 matmuls
warmup_steps=cfg.n_warmup_steps,
weight_decay=cfg.weight_decay,
lr_scheduler_type=cfg.lr_schedule,
eval_steps=cfg.eval_every,
save_steps=cfg.save_every,
)

def get_model_and_run_name(model_name, current_name):
model_last = model_name.split("/")[-1]
model_cfg = ModelConfig(name=model_name, enable_lora=not cfg.disable_lora)
run_name = f"{current_name}-{cfg.run_name}-{cfg.dataset}-{model_last}"
return model_cfg, run_name

# train weak floor, get predictions
print("\n\033[32m===== Training weak model =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.weak_model_name, "weak")
train_args["run_name"] = run_name
train_args["output_dir"] = str(shared_root / cfg_name / "weak")
train_args["learning_rate"] = cfg.weak_lr
weak_ds_dict = DatasetDict(
{
"train": splits["train"],
"val": splits["val"],
"test": splits["test"],
}
)
weak_predict_dict = {"train": splits["train"], "val": splits["val"]}
train(
weak_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=False,
)

# train strong ceil
print("\n\033[32m===== Training strong model =====\033[0m")
model_cfg, run_name = get_model_and_run_name(cfg.strong_model_name, "strong")
train_args["run_name"] = run_name
train_args["output_dir"] = str(shared_root / cfg_name / "strong")
train_args["learning_rate"] = cfg.strong_lr
strong_ds_dict = DatasetDict(
{
"train": splits["train"],
"val": splits["val"],
"test": splits["test"],
}
)
train(
strong_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=False,
)


if __name__ == "__main__":
os.environ["HF_DATASETS_CACHE"] = "/home/adam/hf_cache"
run_train(parse(SFTConfig))
39 changes: 39 additions & 0 deletions w2s/gpu_pool.py
@@ -0,0 +1,39 @@
from multiprocessing import Manager, Process
import subprocess

# Function that runs the job on a GPU
def run_on_gpu(gpu: int, job: str):
print(f"Starting on GPU {gpu}: {job}")
command = f"CUDA_VISIBLE_DEVICES={gpu} {job}"
subprocess.run(command, shell=True, check=True)
print(f"Finished on GPU {gpu}: {job}")

# Worker function that gets jobs and runs them on a specific GPU
def worker(gpu, jobs, lock):
while True:
with lock:
if not jobs:
print(f"GPU {gpu} has no more jobs.")
return # No more jobs to process
job = jobs.pop(0)

run_on_gpu(gpu, job)

def gpu_map(gpus, jobs):
# Create a shared job list and a lock
manager = Manager()
jobs = manager.list(jobs)
lock = manager.Lock()

# Create and start worker processes, each assigned to a specific GPU
processes = []
for gpu in gpus:
p = Process(target=worker, args=(gpu, jobs, lock))
processes.append(p)
p.start()

# Wait for all worker processes to finish
for p in processes:
p.join()

print("All jobs finished.")
21 changes: 20 additions & 1 deletion w2s/loss.py
@@ -1,4 +1,5 @@
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Union
from w2s.sft_utils import literal
@@ -30,11 +31,16 @@ class LogEntropyLossConfig(LogConfidenceLossConfig):
class CrossEntropyLossConfig(LossConfig):
pass

@dataclass
class KLDivergenceLossConfig(LossConfig):
pass

LOSS_CONFIGS = {
"logconf": LogConfidenceLossConfig,
"window": ConfidenceWindowLossConfig,
"entropy": LogEntropyLossConfig,
"xent": CrossEntropyLossConfig,
"kl": KLDivergenceLossConfig,
}


@@ -75,6 +81,19 @@ def cross_entropy_loss(
return torch.nn.functional.cross_entropy(logits, target)


def kl_divergence_loss(
logits,
labels,
):
logits = logits.float()
labels = labels.float()

target = torch.stack([1.0 - labels, labels], dim=1)
log_preds = torch.log_softmax(logits, dim=-1)

return F.kl_div(log_preds, target, reduction="batchmean")


def log_confidence_loss(
logits,
labels,
@@ -95,7 +114,7 @@ def log_confidence_loss(
labels = torch.sigmoid(logodds_labels - logodds_labels.mean())
prior = 0.5
else:
prior = labels.mean()
prior = labels.mean() if labels.shape[0] > 1 else 0.5

coef = aux_coef * min(1.0, step / warmup_steps) if warmup_steps > 0 else aux_coef
preds = torch.softmax(logits, dim=-1)
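For reference, a minimal sketch of calling the new kl_divergence_loss directly, assuming the float labels are soft probabilities of the positive class (the tensor values are made up):

import torch
from w2s.loss import kl_divergence_loss

logits = torch.tensor([[2.0, -1.0], [0.3, 0.5]])  # raw 2-class logits, shape (batch, 2)
labels = torch.tensor([0.1, 0.7])                 # soft labels in [0, 1]
loss = kl_divergence_loss(logits, labels)         # KL(target || softmax(logits)), batch-mean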
24 changes: 18 additions & 6 deletions w2s/model.py
@@ -55,20 +55,25 @@ def forward(self, hiddens):
).to(self.output_dtype)


def init_model_and_tokenizer(cfg: ModelConfig):
def init_tokenizer(cfg: ModelConfig) -> AutoTokenizer:
tokenizer = AutoTokenizer.from_pretrained(cfg.name)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

return tokenizer


def init_model(tokenizer, cfg: ModelConfig):
model = AutoModelForSequenceClassification.from_pretrained(
cfg.name, torch_dtype="auto", device_map={"": "cuda"}
cfg.name, torch_dtype="auto", device_map={"": "cuda"},
# force_download=True,
)

if cfg.lora_modules is None and cfg.enable_lora:
cfg.lora_modules = MODEL_REGISTRY.get(cfg.name, {}).get(
"lora_modules", DEFAULT_LORA_MODULES
)

tokenizer = AutoTokenizer.from_pretrained(cfg.name)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

model.config.pad_token_id = tokenizer.pad_token_id # type: ignore
model.score.weight.data *= 0.01
model.config.problem_type = "single_label_classification"
@@ -97,6 +102,13 @@ def init_model_and_tokenizer(cfg: ModelConfig):
if p.requires_grad:
p.data = p.data.float()

return model


def init_model_and_tokenizer(cfg: ModelConfig):
tokenizer = init_tokenizer(cfg)
model = init_model(tokenizer, cfg)

return model, tokenizer


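For reference, a minimal sketch of using the split tokenizer/model initializers (not part of the commit; the model name is a placeholder):

from w2s.model import ModelConfig, init_tokenizer, init_model, init_model_and_tokenizer

cfg = ModelConfig(name="Qwen/Qwen1.5-0.5B", enable_lora=True)  # placeholder model name
tokenizer = init_tokenizer(cfg)     # tokenizer alone, e.g. for preprocessing without a GPU
model = init_model(tokenizer, cfg)  # model setup uses the tokenizer's pad_token_id
# init_model_and_tokenizer(cfg) still returns (model, tokenizer) as before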
8 changes: 5 additions & 3 deletions w2s/probe.py
@@ -17,16 +17,16 @@ def to_dict(self):

@dataclass
class KnnProbeConfig(ProbeConfig):
k: int = 20
k: int = 50

@dataclass
class LogisticProbeConfig(ProbeConfig):
l2p: float = 1e-3

@dataclass
class TopoProbeConfig(ProbeConfig):
k_cc: int = 20
k_zeta: int = 20
k_cc: int = 100
k_zeta: int = 50
modified: bool = False


@@ -81,10 +81,12 @@ def __init__(self, config: LogisticProbeConfig):
self.l2p = config.l2p

def fit(self, acts, labels):
acts = acts.to(torch.float32)
self.clf = Classifier(acts.shape[1], num_classes=1, device=acts.device)
self.clf.fit(acts, labels, l2_penalty=self.l2p)

def predict(self, acts):
acts = acts.to(torch.float32)
preds = torch.sigmoid(self.clf(acts))
return preds

(diffs for the remaining 3 changed files not shown)
