sft.py
import json
from pathlib import Path
from typing import Union

import torch
from datasets import DatasetDict
from transformers import (
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
import wandb

from w2s.loss import log_confidence_loss, confidence_window_loss
from w2s.model import ModelConfig, init_model_and_tokenizer
from w2s.roc_auc import roc_auc
from w2s.sft_utils import (
    clear_mem,
    get_gpu_mem_used,
    move_best_ckpt,
)
from w2s.sft_config import LossConfig


class CustomLossTrainer(Trainer):
    """Trainer that replaces the default loss with a custom one:
    log-confidence ("logconf") or confidence-window ("window")."""

    def __init__(
        self,
        loss_name: str,
        loss_cfg: LossConfig,
        transfer: bool,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.loss_name = loss_name
        self.loss_cfg = loss_cfg
        self.transfer = transfer

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels").float()
        outputs = model(**inputs)

        if self.loss_name == "logconf":
            # the auxiliary confidence term is enabled only on transfer runs
            # (aux_coef is 0 otherwise)
            loss = log_confidence_loss(
                outputs.logits,
                labels,
                self.state.global_step,
                aux_coef=(self.loss_cfg.logconf_weight if self.transfer else 0.0),
                warmup_steps=self.loss_cfg.logconf_warmup_steps,
                balance_batch=self.loss_cfg.balance_batch,
            )
        elif self.loss_name == "window":
            loss = confidence_window_loss(
                outputs.logits,
                labels,
                radius=self.loss_cfg.radius,
            )
        else:
            raise ValueError(f"Unknown loss function: {self.loss_name}")

        return (loss, outputs) if return_outputs else loss
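

# ---------------------------------------------------------------------------
# Hedged sketch, for illustration only: the real log_confidence_loss lives in
# w2s/loss.py and is imported above; its exact hardening and balancing rules
# are not reproduced here. This sketch only shows the general shape such a
# loss can take, consistent with the call in compute_loss: cross-entropy
# against a mixture of the provided (possibly soft) labels and the model's own
# hardened predictions, with the mixing weight ramped linearly from 0 to
# aux_coef over warmup_steps.
def _log_confidence_loss_sketch(logits, labels, step, aux_coef, warmup_steps):
    import torch.nn.functional as F

    # linear warmup of the auxiliary coefficient
    coef = aux_coef * min(1.0, step / max(1, warmup_steps))
    # the model's own hard predictions for the positive class (no gradient)
    self_labels = (logits.detach().softmax(-1)[:, 1] > 0.5).float()
    # mix the provided labels with the self-labels, build 2-class soft targets
    target = (1 - coef) * labels + coef * self_labels
    soft_targets = torch.stack([1 - target, target], dim=1)
    return F.cross_entropy(logits, soft_targets)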


def train(
    ds_dict: DatasetDict,
    model_cfg: ModelConfig,
    train_args: TrainingArguments,
    cfg: dict,
    transfer: bool,
    predict_dict: Union[DatasetDict, dict, None] = None,
):
    """
    ds_dict: DatasetDict with splits for train, val, and test,
        with columns "txt" and "labels"
    model_cfg: ModelConfig with the model name and whether to enable LoRA
    train_args: TrainingArguments with the training hyperparameters
    cfg: a config carrying all the relevant details for reproducibility; it must
        expose the loss name (cfg.loss_name) and a LossConfig (cfg.loss) whose
        fields include logconf_weight (the weight of the auxiliary log-confidence
        term), logconf_warmup_steps (the number of steps over which that weight
        is linearly increased), balance_batch (whether to class-balance the batch
        in the log-confidence loss), and radius (for the confidence-window loss).
        It is updated with train_args and model_cfg before saving.
    transfer: whether this is a transfer run; only then is the auxiliary
        log-confidence term enabled
    predict_dict: an optional mapping from names to datasets to predict on
        after training

    This function trains a model on ds_dict["train"], uses ds_dict["val"] for
    early stopping, and evaluates on ds_dict["test"]. It also optionally gathers
    predictions on each dataset in predict_dict and saves them.
    """
    save_dir = Path(train_args.output_dir)
    results_path = save_dir / "results.json"
    if results_path.exists():
        print(
            f"Results already exist at {results_path}. Skipping training and evaluation."
        )
        return

    clear_mem()
    print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use before training")

    model, tokenizer = init_model_and_tokenizer(model_cfg)

    def process(examples):
        return tokenizer(examples["txt"], truncation=True)

    ds_dict = ds_dict.map(process, batched=True)

    def compute_metrics(eval_pred):
        predictions, labels = map(torch.from_numpy, eval_pred)
        # labels may be soft probabilities, so harden them at 0.5 for metrics
        hard_labels = (labels > 0.5).long()
        return dict(
            accuracy=predictions.argmax(dim=1).eq(hard_labels).float().mean(),
            auroc=roc_auc(hard_labels, predictions[:, 1]),
        )

    trainer = CustomLossTrainer(
        loss_name=cfg.loss_name,
        loss_cfg=cfg.loss,
        transfer=transfer,
        args=train_args,
        compute_metrics=compute_metrics,
        data_collator=DataCollatorWithPadding(tokenizer),
        eval_dataset={k: ds_dict[k] for k in ["val", "test"]},
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds_dict["train"],
    )

    # train
    trainer.train()

    # evaluate on test dataset
    eval_results = trainer.evaluate(ds_dict["test"])  # type: ignore
    move_best_ckpt(trainer)

    # save results
    with open(results_path, "w") as f:
        json.dump(eval_results, f, indent=2)

    # save config
    with open(save_dir / "config.json", "w") as f:
        cfg["model"] = model_cfg.to_dict()
        cfg["train_args"] = train_args.to_dict()
        # record the loss weight actually used; it lives on cfg.loss rather
        # than as a local variable
        cfg["logconf_weight"] = cfg.loss.logconf_weight
        json.dump(cfg, f, indent=2)
    wandb.config.update(cfg)

    # save predictions
    if predict_dict is not None:
        for name, predict_ds in predict_dict.items():
            predict_ds = predict_ds.map(process, batched=True)
            print("Gathering predictions for", name)
            pred_logits = torch.from_numpy(trainer.predict(predict_ds).predictions)
            # convert logits to positive-class probabilities
            preds = pred_logits.softmax(-1)[:, 1].cpu().float().numpy()
            pred_ds = predict_ds.add_column("soft_pred", preds)
            pred_ds.save_to_disk(str(save_dir / "predictions" / name))

    wandb.finish()
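

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the module): one plausible way to call
# train(). Everything below is a hypothetical stand-in -- the model name, the
# ModelConfig field names, the DotDict helper, and the tiny dataset -- since
# the repo's own entry point builds these objects from its config classes in
# w2s/model.py and w2s/sft_config.py.
if __name__ == "__main__":
    from datasets import Dataset

    class DotDict(dict):
        """Minimal dict with attribute access, matching how `cfg` is used above."""
        __getattr__ = dict.__getitem__
        __setattr__ = dict.__setitem__

    rows = {"txt": ["an example input"], "labels": [1.0]}
    ds_dict = DatasetDict(
        train=Dataset.from_dict(rows),
        val=Dataset.from_dict(rows),
        test=Dataset.from_dict(rows),
    )
    # assumed field names for ModelConfig and the loss config
    model_cfg = ModelConfig(name="gpt2", enable_lora=False)
    loss_cfg = DotDict(
        logconf_weight=0.5, logconf_warmup_steps=200, balance_batch=False, radius=0.15
    )
    cfg = DotDict(loss_name="logconf", loss=loss_cfg)
    train_args = TrainingArguments(output_dir="results/example", num_train_epochs=1)

    wandb.init(project="w2s-example", mode="disabled")
    train(ds_dict, model_cfg, train_args, cfg, transfer=True)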