# train.py (forked from moon-hotel/TransformerClassification)

import os
import time

import torch
import torch.nn as nn

from config import Config
from data_helper import LoadSentenceClassificationDataset, my_tokenizer
from ClassificationModel import ClassificationModel


class CustomSchedule(nn.Module):
    """Warmup learning-rate schedule from "Attention Is All You Need":
    lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5).
    The rate grows linearly for the first `warmup_steps` steps, then
    decays with the inverse square root of the step number."""

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = torch.tensor(d_model, dtype=torch.float32)
        self.warmup_steps = warmup_steps
        self.step = 1.

    def __call__(self):
        arg1 = self.step ** -0.5                        # inverse-sqrt decay term
        arg2 = self.step * (self.warmup_steps ** -1.5)  # linear warmup term
        self.step += 1.
        return (self.d_model ** -0.5) * min(arg1, arg2)
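
# A quick sanity check of the schedule (illustrative only, not executed during
# training): with d_model=512 and the default 4000 warmup steps, the rate
# peaks on the 4000th call at roughly 512 ** -0.5 * 4000 ** -0.5 ≈ 7e-4.
#
#   sched = CustomSchedule(512)
#   rates = [sched() for _ in range(8000)]
#   assert max(rates) == rates[3999]
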
def train_model(config):
data_loader = LoadSentenceClassificationDataset(config.train_corpus_file_paths,
my_tokenizer,
batch_size=config.batch_size,
min_freq=config.min_freq,
max_sen_len=config.max_sen_len)
train_iter, test_iter = data_loader.load_train_val_test_data(
config.train_corpus_file_paths, config.test_corpus_file_paths)
classification_model = ClassificationModel(vocab_size=len(data_loader.vocab),
d_model=config.d_model,
nhead=config.num_head,
num_encoder_layers=config.num_encoder_layers,
dim_feedforward=config.dim_feedforward,
dim_classification=config.dim_classification,
num_classification=config.num_class,
dropout=config.dropout)
    # Xavier-initialize all weight matrices; 1-D parameters (biases,
    # LayerNorm gains) keep their default initialization.
    for p in classification_model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    model_save_path = os.path.join(config.model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        # map_location keeps loading robust when the checkpoint was saved on a different device
        loaded_paras = torch.load(model_save_path, map_location=config.device)
        classification_model.load_state_dict(loaded_paras)
        print("## Successfully loaded existing model, resuming training...")
classification_model = classification_model.to(config.device)
loss_fn = torch.nn.CrossEntropyLoss()
    learning_rate = CustomSchedule(config.d_model)
    # lr starts at 0 here; the scheduled rate is written into the
    # optimizer's param groups before every step (see the training loop).
    optimizer = torch.optim.Adam(classification_model.parameters(),
                                 lr=0.,
                                 betas=(config.beta1, config.beta2),
                                 eps=config.epsilon)
classification_model.train()
max_test_acc = 0
for epoch in range(config.epochs):
losses = 0
start_time = time.time()
for idx, (sample, label) in enumerate(train_iter):
sample = sample.to(config.device) # [src_len, batch_size]
label = label.to(config.device)
            # True marks PAD positions; shape [batch_size, src_len]
            padding_mask = (sample == data_loader.PAD_IDX).transpose(0, 1)
            logits = classification_model(sample,
                                          src_key_padding_mask=padding_mask)  # [batch_size, num_class]
optimizer.zero_grad()
loss = loss_fn(logits, label)
loss.backward()
            lr = learning_rate()  # scheduled rate for this step
            for p in optimizer.param_groups:
                p['lr'] = lr
            optimizer.step()
losses += loss.item()
acc = (logits.argmax(1) == label).float().mean()
if idx % 10 == 0:
                print(f"Epoch: {epoch}, Batch[{idx}/{len(train_iter)}], "
                      f"Train loss: {loss.item():.3f}, Train acc: {acc:.3f}")
end_time = time.time()
train_loss = losses / len(train_iter)
print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")
        if (epoch + 1) % config.model_save_per_epoch == 0:
            acc = evaluate(test_iter, classification_model, config.device, data_loader.PAD_IDX)
            print(f"Accuracy on test {acc:.3f}, max acc on test {max_test_acc:.3f}")
            if acc > max_test_acc:
                max_test_acc = acc
                torch.save(classification_model.state_dict(), model_save_path)


def evaluate(data_iter, model, device, pad_idx):
    model.eval()
    with torch.no_grad():
        acc_sum, n = 0.0, 0
        for x, y in data_iter:
            x, y = x.to(device), y.to(device)
            # Apply the same padding mask as in training so that PAD
            # positions are ignored by attention during evaluation too.
            padding_mask = (x == pad_idx).transpose(0, 1)
            logits = model(x, src_key_padding_mask=padding_mask)
            acc_sum += (logits.argmax(1) == y).float().sum().item()
            n += len(y)
    model.train()
    return acc_sum / n
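
# Illustrative standalone evaluation (a sketch; assumes a checkpoint saved by
# train_model above and the same Config/data pipeline, and is not executed as
# part of this script):
#
#   config = Config()
#   data_loader = LoadSentenceClassificationDataset(config.train_corpus_file_paths,
#                                                   my_tokenizer,
#                                                   batch_size=config.batch_size,
#                                                   min_freq=config.min_freq,
#                                                   max_sen_len=config.max_sen_len)
#   _, test_iter = data_loader.load_train_val_test_data(
#       config.train_corpus_file_paths, config.test_corpus_file_paths)
#   model = ClassificationModel(vocab_size=len(data_loader.vocab), ...)  # same kwargs as in train_model
#   model.load_state_dict(torch.load(os.path.join(config.model_save_dir, 'model.pt'),
#                                    map_location=config.device))
#   acc = evaluate(test_iter, model.to(config.device), config.device, data_loader.PAD_IDX)
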
if __name__ == '__main__':
config = Config()
    train_model(config)

# Sample training output:
"""
Epoch: 9, Batch: [410/469], Train loss 0.186, Train acc: 0.938
Epoch: 9, Batch: [420/469], Train loss 0.150, Train acc: 0.938
Epoch: 9, Batch: [430/469], Train loss 0.269, Train acc: 0.941
Epoch: 9, Batch: [440/469], Train loss 0.197, Train acc: 0.925
Epoch: 9, Batch: [450/469], Train loss 0.245, Train acc: 0.917
Epoch: 9, Batch: [460/469], Train loss 0.272, Train acc: 0.902
Accuracy on test 0.886
"""