Minibatch size cfg option
norabelrose committed May 12, 2024
1 parent f4149d6 commit 40394bf
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions w2s/train.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 
 import torch
+from datasets import Value
 from peft import (
     AutoPeftModelForCausalLM,
     AutoPeftModelForSequenceClassification,
@@ -37,6 +38,9 @@ class TrainConfig(Serializable):
     contamination: float = field(default=0.0)
     """What fraction of data points to remove as outliers."""
 
+    minibatch_size: int = 8
+    """Size of the minibatches to use during training."""
+
     outlier_k: int = field(default=5)
     """Number of neighbors to consider when removing outliers."""
 
@@ -150,18 +154,25 @@ def compute_metrics(eval_pred):
     test = splits["test"].select_columns(cols)
     train = splits["train"].select_columns(cols)
 
-    weak_test = test.map(weak_processor, batched=True)
-    weak_train = train.map(weak_processor, batched=True)
+    weak_test = test.map(weak_processor, batched=True).cast_column(
+        "labels", Value("int64")
+    )
+    weak_train = train.map(weak_processor, batched=True).cast_column(
+        "labels", Value("int64")
+    )
 
     root = Path("results") / cfg.dataset
     training_args = TrainingArguments(
         str(root / "floor"),
         adam_beta2=0.95,
+        gradient_accumulation_steps=8 // cfg.minibatch_size,
         evaluation_strategy="epoch",
         label_names=["labels"],
         load_best_model_at_end=True,
         logging_steps=50,
         metric_for_best_model="auroc",
+        per_device_train_batch_size=cfg.minibatch_size,
+        per_device_eval_batch_size=cfg.minibatch_size,
         run_name=cfg.dataset + "/floor" + cfg.run_name,
         save_strategy="epoch",
         save_total_limit=1,
@@ -189,6 +200,7 @@ def compute_metrics(eval_pred):
     if task == "classify":
         # HuggingFace init for the head is too large
         weak_model.score.weight.data *= 0.01
+        weak_model.config.problem_type = "single_label_classification"
 
         weak_model = get_peft_model(weak_model, lora_cfg)
     else:
@@ -233,11 +245,15 @@ def compute_metrics(eval_pred):
     def strong_processor(examples):
         return strong_tokenizer(examples["txt"], truncation=True)
 
-    strong_train = train.map(strong_processor, batched=True).rename_column(
-        "hard_label", "labels"
+    strong_train = (
+        train.map(strong_processor, batched=True)
+        .rename_column("hard_label", "labels")
+        .cast_column("labels", Value("int64"))
     )
-    ceil_test = test.map(strong_processor, batched=True).rename_column(
-        "hard_label", "labels"
+    ceil_test = (
+        test.map(strong_processor, batched=True)
+        .rename_column("hard_label", "labels")
+        .cast_column("labels", Value("int64"))
     )
 
     strong_ckpt = root / "ceil" / "best-ckpt"
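With the new option, gradient_accumulation_steps is derived as 8 // cfg.minibatch_size, so the effective batch size (per-device minibatch times accumulation steps) stays pinned at 8 while the per-device memory footprint shrinks. A quick sketch of the arithmetic (illustrative only, not part of the commit):

    # Effective batch = minibatch_size * (8 // minibatch_size)
    for mb in (1, 2, 4, 8):
        accum = 8 // mb
        print(f"minibatch_size={mb}: accum_steps={accum}, effective batch={mb * accum}")

Note the division only works out when minibatch_size divides 8: a value like 3 gives 8 // 3 = 2 accumulation steps and an effective batch of 6.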
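The cast_column("labels", Value("int64")) calls, together with the explicit problem_type, look like a guard against HuggingFace's loss inference, which keys off the label dtype when problem_type is unset, so float labels can be routed to the wrong loss. A self-contained sketch of the cast itself, using a made-up toy dataset:

    from datasets import Dataset, Value

    # Labels produced as floats are stored as float64 by default.
    ds = Dataset.from_dict({"labels": [0.0, 1.0, 1.0]})
    print(ds.features["labels"].dtype)  # float64

    # Casting to int64 yields the integer class ids that
    # single-label classification (cross-entropy) expects.
    ds = ds.cast_column("labels", Value("int64"))
    print(ds.features["labels"].dtype)  # int64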
