Skip to content

Commit

Permalink
Remove chess if-statements
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed May 13, 2024
1 parent c81be50 commit caaa608
Showing 1 changed file with 8 additions and 55 deletions.
63 changes: 8 additions & 55 deletions w2s/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
import torch
from datasets import Value
from peft import (
AutoPeftModelForCausalLM,
AutoPeftModelForSequenceClassification,
LoraConfig,
get_peft_model,
)
from simple_parsing import Serializable, field, parse
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Expand Down Expand Up @@ -83,16 +81,6 @@ def move_best_ckpt(trainer: Trainer):
print(f"Best model (auroc {perf:.3f}) saved at: {dest}")


def lolcat(lol1, lol2):
# list-of-list concatenation along the second dimension
assert len(lol1) == len(lol2)
return [l1 + l2 for l1, l2 in zip(lol1, lol2)]


def lolconst(lol, const):
return [[const for _ in l_] for l_ in lol]


def train(cfg: TrainConfig):
lora_cfg = LoraConfig(target_modules=LORA_MODULES)

Expand All @@ -110,47 +98,19 @@ def train(cfg: TrainConfig):
cfg.dataset, split_sizes=dict(train=20_000, test=1_000)
)

if "txt" in splits["train"].column_names:
task = "classify"
elif "ctx" in splits["train"].column_names:
task = "generate"
else:
raise ValueError(
f"Unrecognized dataset columns: {splits['train'].column_names}"
)

def weak_processor(examples):
if task == "generate":
ctx_out = weak_tokenizer(examples["ctx"], truncation=True)
trg_out = weak_tokenizer(
examples["target"], truncation=True, add_special_tokens=False
)
return dict(
input_ids=lolcat(ctx_out["input_ids"], trg_out["input_ids"]),
attention_mask=lolcat(
ctx_out["attention_mask"], trg_out["attention_mask"]
),
labels=lolcat(
lolconst(ctx_out["input_ids"], -100), trg_out["input_ids"]
),
)

out = weak_tokenizer(examples["txt"], truncation=True)
out["labels"] = examples["hard_label"]
return out

def compute_metrics(eval_pred):
predictions, labels = map(torch.from_numpy, eval_pred)
if task == "generate":
breakpoint()
print(eval_pred)
raise NotImplementedError("Generation metrics not implemented")
return dict(
accuracy=predictions.argmax(dim=1).eq(labels).float().mean(),
auroc=roc_auc(labels, predictions[:, 1]),
)

cols = ["ctx", "target"] if task == "generate" else ["hard_label", "txt"]
cols = ["hard_label", "txt"]
test = splits["test"].select_columns(cols)
train = splits["train"].select_columns(cols)

Expand Down Expand Up @@ -191,28 +151,21 @@ def compute_metrics(eval_pred):
should_train = True
weak_path = root / "floor/best-ckpt"
if not weak_path.exists():
autoclass = (
AutoModelForCausalLM
if task == "generate"
else AutoModelForSequenceClassification
weak_model = AutoModelForSequenceClassification.from_pretrained(
cfg.weak_name, torch_dtype="auto"
)
weak_model = autoclass.from_pretrained(cfg.weak_name, torch_dtype="auto")
if task == "classify":
# HuggingFace init for the head is too large
weak_model.score.weight.data *= 0.01
weak_model.config.problem_type = "single_label_classification"

# HuggingFace init for the head is too large
weak_model.score.weight.data *= 0.01
weak_model.config.problem_type = "single_label_classification"
weak_model = get_peft_model(weak_model, lora_cfg)
else:
print("Loading weak model from:", weak_path)
should_train = False

autoclass = (
AutoPeftModelForCausalLM
if task == "generate"
else AutoPeftModelForSequenceClassification
weak_model = AutoPeftModelForSequenceClassification.from_pretrained(
weak_path, torch_dtype="auto"
)
weak_model = autoclass.from_pretrained(weak_path, torch_dtype="auto")

weak_model.config.pad_token_id = weak_tokenizer.pad_token_id
trainer = Trainer(
Expand Down

0 comments on commit caaa608

Please sign in to comment.