Skip to content

Commit

Permalink
Fresh model for w2s
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed May 10, 2024
1 parent 77f9fce commit a8e3e8d
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions w2s/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,13 @@ def compute_metrics(eval_pred):
adam_beta2=0.95,
evaluation_strategy="epoch",
label_names=["labels"],
learning_rate=2e-5,
load_best_model_at_end=True,
logging_steps=50,
metric_for_best_model="auroc",
per_device_train_batch_size=16,
save_strategy="epoch",
save_total_limit=1,
tf32=True, # Use Tensor Cores even for fp32 matmuls
warmup_steps=100,
weight_decay=0.01,
)

Expand All @@ -124,6 +123,9 @@ def compute_metrics(eval_pred):
weak_model = AutoModelForSequenceClassification.from_pretrained(
cfg.weak_name, torch_dtype="auto"
)
# HuggingFace's default init for the classification head (`score`) is too
# large, so scale its weights down by 100x
weak_model.score.weight.data *= 0.01

weak_model = get_peft_model(weak_model, lora_cfg)
else:
print("Loading weak model from:", weak_path)
Expand Down Expand Up @@ -169,16 +171,20 @@ def compute_metrics(eval_pred):
strong_model = AutoModelForSequenceClassification.from_pretrained(
STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
)
# HuggingFace's default init for the classification head (`score`) is too
# large, so scale its weights down by 100x
strong_model.score.weight.data *= 0.01

strong_model.config.pad_token_id = (
strong_tokenizer.pad_token_id
) = strong_tokenizer.eos_token_id

# Tokenize the "txt" field of `examples` with `strong_tokenizer`, with
# truncation enabled. Passed to `Dataset.map(..., batched=True)` below, so
# `examples` is presumably a batch of rows rather than a single row.
def strong_processor(examples):
return strong_tokenizer(examples["txt"], truncation=True)

strong_train = splits["train"].map(strong_processor, batched=True)
ceil_test = splits["train"].map(strong_processor, batched=True)
ceil_test = splits["test"].rename_column("hard_label", "labels")
strong_train = train.map(strong_processor, batched=True)
ceil_test = test.map(strong_processor, batched=True).rename_column(
"hard_label", "labels"
)

training_args.output_dir = str(root / "ceil")
trainer = Trainer(
Expand All @@ -193,6 +199,13 @@ def strong_processor(examples):
trainer.train()
move_best_ckpt(trainer)

# Initialize a fresh model for the w2s experiment
strong_model = AutoModelForSequenceClassification.from_pretrained(
STRONG_NAME, torch_dtype="auto", device_map={"": "cuda"}
)
# HuggingFace's default init for the classification head (`score`) is too
# large, so scale its weights down by 100x
strong_model.score.weight.data *= 0.01

# Make sure that we use CrossEntropyLoss
strong_model.config.problem_type = "single_label_classification"

Expand Down

0 comments on commit a8e3e8d

Please sign in to comment.