improve default hparams
AlexTMallen committed May 17, 2024
1 parent f5991ea commit c444ba3
Showing 19 changed files with 886 additions and 622 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -15,4 +15,4 @@ repos:
     rev: 'v0.0.262'
     hooks:
       - id: ruff
-        args: [--fix, --exit-non-zero-on-fix]
+        args: [--fix, --exit-non-zero-on-fix, --line-length=100]
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -15,6 +15,9 @@ dependencies = [
     "peft",
     "scipy",
     "simple-parsing",
+    "fire ~= 0.4",
+    "pynvml ~= 11.5",
+    "scikit-learn ~= 1.3.2",
     # 4.0 introduced the breaking change of using return_dict=True by default
     "transformers>=4.0.0",
     "wandb",
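Aside on the transformers pin: the comment above refers to forward passes returning ModelOutput dataclasses rather than plain tuples by default. A minimal sketch of the difference, using bert-base-uncased purely as a stand-in checkpoint:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
inputs = tokenizer("the pin matters", return_tensors="pt")
outputs = model(**inputs)  # with transformers >= 4.0, a ModelOutput by default
logits = outputs.logits    # named attribute access; outputs[0] also still works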
67 changes: 67 additions & 0 deletions run.py
@@ -0,0 +1,67 @@
from pathlib import Path

from simple_parsing import parse
from transformers import TrainingArguments

from underspec.ds_registry import load_and_process_dataset
from underspec.model import ModelConfig
from underspec.sft import train
from underspec.sft_config import SFTConfig
from underspec.utils import get_config_foldername


def run_train(cfg: SFTConfig):
splits = load_and_process_dataset(
cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict
)

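    # keep only the text and binary label; rename "hard_label" to "labels",
    # the column name transformers.Trainer uses for targets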
cols = ["hard_label", "txt"]
splits = splits.select_columns(cols).rename_column("hard_label", "labels")
print(
f"Example:\n\n{splits['train'][0]['txt']}\n\nLabel: {splits['train'][0]['labels']}"
)

root = Path(cfg.results_folder) / cfg.run_name
cfg_name = get_config_foldername(vars(cfg))
model_last = cfg.model_name.split("/")[-1]
train_args = TrainingArguments(
output_dir=str(root / cfg_name),
num_train_epochs=cfg.n_epochs,
adam_beta2=0.95,
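        # accumulate gradients so that cfg.batch_size is the effective batch
        # size, stepped through in chunks of cfg.minibatch_size per device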
gradient_accumulation_steps=cfg.batch_size // cfg.minibatch_size,
evaluation_strategy="steps",
label_names=["labels"],
load_best_model_at_end=True,
logging_steps=25,
metric_for_best_model="eval_loss",
greater_is_better=False,
per_device_train_batch_size=cfg.minibatch_size,
per_device_eval_batch_size=cfg.minibatch_size,
run_name=f"{cfg.run_name}-{cfg.dataset}-{model_last}",
save_strategy="steps",
save_total_limit=cfg.save_total_limit,
tf32=True, # Use Tensor Cores even for fp32 matmuls
warmup_steps=cfg.n_warmup_steps,
weight_decay=cfg.weight_decay,
learning_rate=cfg.lr,
lr_scheduler_type=cfg.lr_schedule,
eval_steps=cfg.eval_every,
save_steps=cfg.save_every,
)

model_cfg = ModelConfig(name=cfg.model_name, enable_lora=not cfg.disable_lora)
train(
splits,
model_cfg,
train_args,
cfg.loss,
cfg.store_pre_hiddens,
cfg.store_post_hiddens,
cfg.to_dict(),
)


if __name__ == "__main__":
run_train(parse(SFTConfig))
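For reference, simple_parsing.parse builds the SFTConfig from command-line flags named after its dataclass fields, so a hypothetical invocation (the dataset and model names below are placeholders, not repository defaults) would look like:

python run.py --dataset sciq --model_name gpt2 --run_name my-run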
(remaining 16 changed files not shown)
