Merge branch 'main' of github.com:EleutherAI/w2s into main
norabelrose committed May 11, 2024
2 parents f57cb4c + 77fd0e6 commit 3e0f50d
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions w2s/train.py
@@ -3,12 +3,14 @@
 
 import torch
 from peft import (
+    AutoPeftModelForCausalLM,
     AutoPeftModelForSequenceClassification,
     LoraConfig,
     get_peft_model,
 )
 from simple_parsing import Serializable, field, parse
 from transformers import (
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DataCollatorWithPadding,
@@ -108,7 +110,6 @@ def weak_processor(examples):
         trg_out = weak_tokenizer(
             examples["target"], truncation=True, add_special_tokens=False
         )
-        breakpoint()
         return dict(
             input_ids=lolcat(ctx_out["input_ids"], trg_out["input_ids"]),
             attention_mask=lolcat(
@@ -155,6 +156,7 @@ def compute_metrics(eval_pred):
         tf32=True,  # Use Tensor Cores even for fp32 matmuls
         warmup_steps=100,
         weight_decay=0.01,
+        report_to="none",
     )
 
     # Gather weak labels
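The one-line addition in the hunk above, report_to="none", turns off the Trainer's third-party logging integrations (Weights & Biases, TensorBoard, and similar). A minimal sketch of the same setting follows; only warmup_steps, weight_decay, and report_to come from this diff, and output_dir is a hypothetical placeholder:

from transformers import TrainingArguments

# Illustrative values; only warmup_steps, weight_decay, and report_to
# are taken from the commit above. output_dir is a hypothetical path.
training_args = TrainingArguments(
    output_dir="./weak-ckpt",
    warmup_steps=100,
    weight_decay=0.01,
    report_to="none",  # disable W&B / TensorBoard / etc. callbacks
)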
@@ -167,20 +169,27 @@ def compute_metrics(eval_pred):
     should_train = True
     weak_path = root / "floor/best-ckpt"
     if not weak_path.exists():
-        weak_model = AutoModelForSequenceClassification.from_pretrained(
-            cfg.weak_name, torch_dtype="auto"
+        autoclass = (
+            AutoModelForCausalLM
+            if task == "generate"
+            else AutoModelForSequenceClassification
         )
-        # HuggingFace init for the head is too large
-        weak_model.score.weight.data *= 0.01
+        weak_model = autoclass.from_pretrained(cfg.weak_name, torch_dtype="auto")
+        if task == "classify":
+            # HuggingFace init for the head is too large
+            weak_model.score.weight.data *= 0.01
 
         weak_model = get_peft_model(weak_model, lora_cfg)
     else:
         print("Loading weak model from:", weak_path)
         should_train = False
 
-        weak_model = AutoPeftModelForSequenceClassification.from_pretrained(
-            weak_path, torch_dtype="auto"
+        autoclass = (
+            AutoPeftModelForCausalLM
+            if task == "generate"
+            else AutoPeftModelForSequenceClassification
         )
+        weak_model = autoclass.from_pretrained(weak_path, torch_dtype="auto")
 
     weak_model.config.pad_token_id = weak_tokenizer.pad_token_id
     trainer = Trainer(
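Taken together, the change makes the weak-model loading path task-aware: generation tasks use the causal-LM auto classes, classification keeps the sequence-classification ones, and the classification-head down-scaling now only runs on the classify branch. Below is a minimal, self-contained sketch of that selection logic; the helper name load_weak_model and its parameters (model_name, checkpoint_path) are illustrative stand-ins for the script's TrainConfig fields, not part of the repository:

from pathlib import Path

from peft import (
    AutoPeftModelForCausalLM,
    AutoPeftModelForSequenceClassification,
    LoraConfig,
    get_peft_model,
)
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification


def load_weak_model(task: str, model_name: str, checkpoint_path: Path, lora_cfg: LoraConfig):
    # Fresh model: pick the base auto class by task, then wrap it with LoRA.
    if not checkpoint_path.exists():
        autoclass = (
            AutoModelForCausalLM
            if task == "generate"
            else AutoModelForSequenceClassification
        )
        model = autoclass.from_pretrained(model_name, torch_dtype="auto")
        if task == "classify":
            # HuggingFace init for the classification head is too large
            model.score.weight.data *= 0.01
        return get_peft_model(model, lora_cfg)

    # Existing checkpoint: load it through the matching PEFT auto class.
    autoclass = (
        AutoPeftModelForCausalLM
        if task == "generate"
        else AutoPeftModelForSequenceClassification
    )
    return autoclass.from_pretrained(checkpoint_path, torch_dtype="auto")

In the commit itself this logic is inlined in w2s/train.py rather than factored into a helper.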
