Skip to content

Commit

Permalink
more generate WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamScherlis committed May 11, 2024
1 parent 1d0ee29 commit 77fd0e6
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions w2s/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import torch
from peft import (
AutoPeftModelForSequenceClassification,
AutoPeftModelForCausalLM,
LoraConfig,
get_peft_model,
)
from simple_parsing import Serializable, field, parse
from transformers import (
AutoModelForSequenceClassification,
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
Expand Down Expand Up @@ -98,7 +100,6 @@ def weak_processor(examples):
if task == "generate":
ctx_out = weak_tokenizer(examples["ctx"], truncation=True)
trg_out = weak_tokenizer(examples["target"], truncation=True, add_special_tokens=False)
breakpoint()
return dict(
input_ids=lolcat(ctx_out["input_ids"], trg_out["input_ids"]),
attention_mask=lolcat(ctx_out["attention_mask"], trg_out["attention_mask"]),
Expand Down Expand Up @@ -143,6 +144,7 @@ def compute_metrics(eval_pred):
tf32=True, # Use Tensor Cores even for fp32 matmuls
warmup_steps=100,
weight_decay=0.01,
report_to="none",
)

# Gather weak labels
Expand All @@ -155,18 +157,21 @@ def compute_metrics(eval_pred):
should_train = True
weak_path = root / "floor/best-ckpt"
if not weak_path.exists():
weak_model = AutoModelForSequenceClassification.from_pretrained(
autoclass = AutoModelForCausalLM if task == "generate" else AutoModelForSequenceClassification
weak_model = autoclass.from_pretrained(
cfg.weak_name, torch_dtype="auto"
)
# HuggingFace init for the head is too large
weak_model.score.weight.data *= 0.01
if task == "classify":
# HuggingFace init for the head is too large
weak_model.score.weight.data *= 0.01

weak_model = get_peft_model(weak_model, lora_cfg)
else:
print("Loading weak model from:", weak_path)
should_train = False

weak_model = AutoPeftModelForSequenceClassification.from_pretrained(
autoclass = AutoPeftModelForCausalLM if task == "generate" else AutoPeftModelForSequenceClassification
weak_model = autoclass.from_pretrained(
weak_path, torch_dtype="auto"
)

Expand Down

0 comments on commit 77fd0e6

Please sign in to comment.