# encoding = utf-8
import os
import pdb
import time
import numpy as np
import torch
from torch import optim
from torch.autograd import Variable
from dataloader.dataloaders import train_dataloader, val_dataloader
from network import get_model
from eval import evaluate
from options import opt, config
from options.helper import init_log, load_meta, save_meta
from utils import seed_everything
from scheduler import schedulers
from mscv.summary import create_summary_writer, write_meters_loss, write_image
from mscv.image import tensor2im
# from utils.send_sms import send_notification
import misc_utils as utils
import random
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
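
# train.py is the training entry point of this repo: it seeds the RNGs, prepares the
# checkpoint/log directories, builds the model named by config.MODEL.NAME, optionally
# restores a checkpoint, then runs the epoch/iteration loop with periodic TensorBoard
# logging, checkpoint saving and validation.
#
# A typical invocation might look like the following (the exact command-line flags are
# defined in the options package, so treat this as an illustrative sketch only):
#
#     python train.py --tag my_experiment --epochs 50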
# Initialization
with torch.no_grad():
    # Set the random seed
    if 'RANDOM_SEED' in config.MISC:
        seed_everything(config.MISC.RANDOM_SEED)
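    # seed_everything comes from this repo's utils package; presumably it seeds the
    # Python, NumPy and PyTorch RNGs so that runs with the same config are reproducible.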
    # Initialize the output paths
    save_root = os.path.join('checkpoints', opt.tag)
    log_root = os.path.join('logs', opt.tag)

    utils.try_make_dir(save_root)
    utils.try_make_dir(log_root)

    # dataloaders
    train_dataloader = train_dataloader
    val_dataloader = val_dataloader

    # Initialize logging
    logger = init_log(training=True)

    # Initialize the meta information for this training run
    meta = load_meta(new=True)
    save_meta(meta)
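    # The meta record tracks each run; load_meta(new=True) appends a fresh entry and
    # save_meta persists it, so a finish time can later be written into meta[-1]
    # (see the end of this script).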
    # Initialize the model
    Model = get_model(config.MODEL.NAME)
    model = Model(config, logger=logger)

    # Multi-GPU is not supported yet
    # if len(opt.gpu_ids):
    #     model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    model = model.to(device=opt.device)

    if opt.load:
        load_epoch = model.load(opt.load)
        start_epoch = load_epoch + 1 if opt.resume or 'RESUME' in config.MISC else 1
    elif 'LOAD' in config.MODEL:
        load_epoch = model.load(config.MODEL.LOAD)
        start_epoch = load_epoch + 1 if opt.resume or 'RESUME' in config.MISC else 1
    else:
        start_epoch = 1
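    # Checkpoint loading: if opt.load or config.MODEL.LOAD points to a checkpoint, it
    # is restored through model.load(); when opt.resume is set (or MISC.RESUME appears
    # in the config) training continues from the epoch after the loaded one, otherwise
    # the epoch count restarts at 1.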
    model.train()

    # Start training
    print('Start training...')

    start_step = (start_epoch - 1) * len(train_dataloader)
    global_step = start_step
    total_steps = opt.epochs * len(train_dataloader)
    start = time.time()
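    # Step bookkeeping for the ETA estimate below: global_step counts batches across
    # all epochs, start_step is where this (possibly resumed) run begins, and
    # total_steps is the number of batches in a full opt.epochs run.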
    # Scheduler (provided by the model)
    scheduler = model.scheduler

    # TensorBoard writer
    writer = create_summary_writer(log_root)

    start_time = time.time()

    # Log the data transforms
    logger.info('train_transforms: ' + str(train_dataloader.dataset.transforms))
    logger.info('===========================================')
    if val_dataloader is not None:
        logger.info('val_transforms: ' + str(val_dataloader.dataset.transforms))
        logger.info('===========================================')

    # Log the scheduler
    if config.OPTIMIZE.SCHEDULER in schedulers:
        logger.info('scheduler: (Lambda scheduler)\n' + str(schedulers[config.OPTIMIZE.SCHEDULER]))
        logger.info('===========================================')
# Training loop
try:
    eval_result = ''
    for epoch in range(start_epoch, opt.epochs + 1):
        for iteration, sample in enumerate(train_dataloader):
            global_step += 1

            # Estimate the remaining time
            rate = (global_step - start_step) / (time.time() - start)
            remaining = (total_steps - global_step) / rate

            # In --debug mode only train on 10 batches
            if opt.debug and iteration > 10:
                break

            sample['global_step'] = global_step

            # Update the network parameters
            updated = model.update(sample)
            predicted = updated.get('predicted')

            pre_msg = 'Epoch:%d' % epoch

            # Show the progress bar
            msg = f'lr:{round(scheduler.get_lr()[0], 6) : .6f} (loss) {str(model.avg_meters)} ETA: {utils.format_time(remaining)}'
            utils.progress_bar(iteration, len(train_dataloader), pre_msg, msg)
            # print(pre_msg, msg)

            if global_step % 1000 == 0:  # write the losses to TensorBoard every 1000 steps
                write_meters_loss(writer, 'train', model.avg_meters, global_step)
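            # write_meters_loss comes from the mscv package; judging from the call it
            # logs the running averages in model.avg_meters under the 'train' tag at
            # the current global_step.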
        # Write the training log
        logger.info(f'Train epoch: {epoch}, lr: {round(scheduler.get_lr()[0], 6) : .6f}, (loss) ' + str(model.avg_meters))

        if epoch % config.MISC.SAVE_FREQ == 0 or epoch == opt.epochs:  # always save the last epoch
            model.save(epoch)

        # Validate during training
        if not opt.no_val and epoch % config.MISC.VAL_FREQ == 0:
            model.eval()
            evaluate(model, val_dataloader, epoch, writer, logger, data_name='val')
            model.train()

        if scheduler is not None:
            scheduler.step()
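        # The learning-rate scheduler is stepped once per epoch here, not per
        # iteration, so scheduler.get_lr() stays constant within an epoch.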
    # Record the finishing information
    if opt.tag != 'default':
        with open('run_log.txt', 'a') as f:
            f.writelines(' Accuracy:' + eval_result + '\n')
    meta = load_meta()
    meta[-1]['finishtime'] = utils.get_time_stamp()
    save_meta(meta)

except Exception as e:
    # if not opt.debug:  # no SMS notification in debug mode; 12 is the SMS template's character limit
    #     send_notification([opt.tag[:12], str(e)[:12]], template='error')

    if opt.tag != 'default':
        with open('run_log.txt', 'a') as f:
            f.writelines(' Error: ' + str(e)[:120] + '\n')

    meta = load_meta()
    meta[-1]['finishtime'] = utils.get_time_stamp()
    save_meta(meta)
    # print(e)
    raise Exception('Error')  # raise again so that the original traceback is also printed

except:  # other exceptions, e.g. a keyboard interrupt
    meta = load_meta()
    meta[-1]['finishtime'] = utils.get_time_stamp()
    save_meta(meta)