forked from SeanNaren/deepspeech.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
state.py
168 lines (143 loc) · 5.65 KB
/
state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import torch
from model import DeepSpeech
from utils import remove_parallel_wrapper
class ResultState:
    """Per-epoch tracker for loss, WER and CER results.

    The three containers are pre-allocated by the caller and indexable
    by epoch number (typically 1-D tensors, one slot per epoch).
    """

    def __init__(self,
                 loss_results,
                 wer_results,
                 cer_results):
        """Keep references to the epoch-indexed result containers."""
        self.loss_results = loss_results
        self.wer_results = wer_results
        self.cer_results = cer_results

    def add_results(self,
                    epoch,
                    loss_result,
                    wer_result,
                    cer_result):
        """Record one finished epoch's metrics at index ``epoch``."""
        updates = ((self.loss_results, loss_result),
                   (self.wer_results, wer_result),
                   (self.cer_results, cer_result))
        for container, value in updates:
            container[epoch] = value

    def serialize_state(self):
        """Return a flat dict of the tracked containers (checkpoint format)."""
        return {
            'loss_results': self.loss_results,
            'wer_results': self.wer_results,
            'cer_results': self.cer_results,
        }
class TrainingState:
    """
    Wraps around training model and states for saving/loading convenience.
    For backwards compatibility there are more states being saved than necessary.
    """

    def __init__(self,
                 model,
                 result_state=None,
                 optim_state=None,
                 amp_state=None,
                 best_wer=None,
                 avg_loss=0,
                 epoch=0,
                 training_step=0):
        """Bundle the model together with optimizer/amp/result bookkeeping."""
        self.model = model
        self.result_state = result_state
        self.optim_state = optim_state
        self.amp_state = amp_state
        self.best_wer = best_wer
        self.avg_loss = avg_loss
        self.epoch = epoch
        self.training_step = training_step

    def track_optim_state(self, optimizer):
        """Snapshot the optimizer's state_dict for later checkpointing."""
        self.optim_state = optimizer.state_dict()

    def track_amp_state(self, amp):
        """Snapshot the mixed-precision (amp) state for later checkpointing."""
        self.amp_state = amp.state_dict()

    def init_results_tracking(self, epochs):
        """Allocate fresh epoch-indexed result tensors (values uninitialized).

        NOTE(review): ``torch.IntTensor`` truncates non-integer loss/WER/CER
        values on assignment — confirm integer storage is intended here.
        """
        make = torch.IntTensor
        self.result_state = ResultState(loss_results=make(epochs),
                                        wer_results=make(epochs),
                                        cer_results=make(epochs))

    def add_results(self,
                    epoch,
                    loss_result,
                    wer_result,
                    cer_result):
        """Forward one epoch's metrics to the result tracker."""
        self.result_state.add_results(epoch=epoch,
                                      loss_result=loss_result,
                                      wer_result=wer_result,
                                      cer_result=cer_result)

    def init_finetune_states(self, epochs):
        """
        Resets the training environment, but keeps model specific states in tact.
        This is when fine-tuning a model on another dataset where training is to be reset but the model
        weights are to be loaded
        :param epochs: Number of epochs fine-tuning.
        """
        self.init_results_tracking(epochs)
        self._reset_amp_state()
        self._reset_optim_state()
        self._reset_epoch()
        self.reset_training_step()
        self._reset_best_wer()

    def serialize_state(self, epoch, iteration):
        """Build one flat checkpoint dict (model + training + results).

        The structure is kept flat for backwards compatibility with older
        checkpoint consumers.
        """
        unwrapped = remove_parallel_wrapper(self.model)
        state_dict = dict(unwrapped.serialize_state())
        state_dict.update(self._serialize_training_state(epoch=epoch,
                                                         iteration=iteration))
        state_dict.update(self.result_state.serialize_state())
        return state_dict

    def _serialize_training_state(self, epoch, iteration):
        """Return the training-progress portion of the checkpoint dict."""
        return {
            'optim_dict': self.optim_state,
            'amp': self.amp_state,
            'avg_loss': self.avg_loss,
            'best_wer': self.best_wer,
            'epoch': epoch + 1,  # increment for readability
            'iteration': iteration,
        }

    @classmethod
    def load_state(cls, state_path):
        """Reconstruct a TrainingState from a checkpoint produced by
        ``serialize_state``; loads onto CPU via the map_location hook."""
        print("Loading state from model %s" % state_path)
        state = torch.load(state_path, map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(state)
        epoch = int(state.get('epoch', 1)) - 1  # Index start at 0 for training
        training_step = state.get('iteration', None)
        if training_step is None:
            # Checkpoint was written at an epoch boundary: resume at the
            # start of the next epoch.
            epoch += 1
            training_step = 0
        else:
            training_step += 1
        result_state = ResultState(loss_results=state['loss_results'],
                                   wer_results=state['wer_results'],
                                   cer_results=state['cer_results'])
        return cls(model=model,
                   result_state=result_state,
                   optim_state=state['optim_dict'],
                   amp_state=state['amp'],
                   best_wer=state.get('best_wer'),
                   avg_loss=int(state.get('avg_loss', 0)),
                   epoch=epoch,
                   training_step=training_step)

    def set_epoch(self, epoch):
        """Set the current epoch index."""
        self.epoch = epoch

    def set_best_wer(self, wer):
        """Set the best word error rate seen so far."""
        self.best_wer = wer

    def set_training_step(self, training_step):
        """Set the current step within the epoch."""
        self.training_step = training_step

    def reset_training_step(self):
        """Restart step counting from zero."""
        self.training_step = 0

    def reset_avg_loss(self):
        """Clear the running average loss."""
        self.avg_loss = 0

    def _reset_amp_state(self):
        """Drop any tracked amp state."""
        self.amp_state = None

    def _reset_optim_state(self):
        """Drop any tracked optimizer state."""
        self.optim_state = None

    def _reset_epoch(self):
        """Restart epoch counting from zero."""
        self.epoch = 0

    def _reset_best_wer(self):
        """Forget the best WER (e.g. when fine-tuning on new data)."""
        self.best_wer = None