
Commit
Formatted decode script, added dels to benchmark and updated train to save best model
SeanNaren committed Jun 29, 2017
1 parent fb9452c commit 8439586
Showing 4 changed files with 24 additions and 9 deletions.
benchmark.py: 2 changes (2 additions, 0 deletions)
@@ -74,6 +74,8 @@ def iteration(input_data):
optimizer.step()
torch.cuda.synchronize()
end = time.time()
del loss
del out
return start, end


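Note on the benchmark.py change: inside a benchmark loop, the local names `loss` and `out` keep their CUDA tensors alive until the next iteration rebinds them, so the caching allocator cannot reuse that memory while the next step runs. Deleting them drops the last references as soon as timing ends. A minimal, runnable sketch of the pattern; the linear model and random data are illustrative, not the benchmark's own setup:

```python
import time
import torch
import torch.nn as nn

def iteration(model, criterion, optimizer, inputs, targets):
    start = time.time()
    optimizer.zero_grad()
    out = model(inputs)            # forward pass; activations stay alive via `out`
    loss = criterion(out, targets)
    loss.backward()
    optimizer.step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # flush queued kernels before stopping the clock
    end = time.time()
    del loss                       # drop the last references so the CUDA caching
    del out                        # allocator can hand the memory to the next step
    return start, end

model = nn.Linear(128, 10)
iteration(model, nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.1),
          torch.randn(4, 128), torch.randn(4, 10))
```
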
decoder.py: 4 changes (3 additions, 1 deletion)
@@ -19,12 +19,14 @@
import torch
from enum import Enum
from six.moves import xrange

try:
from pytorch_ctc import CTCBeamDecoder as CTCBD
from pytorch_ctc import Scorer, KenLMScorer
except ImportError:
print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.")


class Decoder(object):
"""
Basic decoder class from which all other decoders inherit. Implements several
@@ -144,7 +146,6 @@ def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, sp
self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width,
blank_index=blank_index, space_index=space_index, merge_repeated=False)


def decode(self, probs, sizes=None):
sizes = sizes.cpu() if sizes is not None else None
out, conf, seq_len = self._decoder.decode(probs.cpu(), sizes)
@@ -153,6 +154,7 @@ def decode(self, probs, sizes=None):
strings = self.convert_to_strings(out[0], sizes=seq_len[0])
return self.process_strings(strings)


class GreedyDecoder(Decoder):
def decode(self, probs, sizes=None):
"""
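Note on the decoder.py change: the edits are PEP 8 formatting (two blank lines between top-level definitions) around an optional-import guard. Beam search needs the pytorch_ctc extension; when it is absent, the module still imports and only greedy decoding works. A small self-contained sketch of that guard; the CTC_AVAILABLE flag and the helper are illustrative, not from the repository:

```python
try:
    from pytorch_ctc import CTCBeamDecoder as CTCBD  # optional native extension
    CTC_AVAILABLE = True
except ImportError:
    print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.")
    CTC_AVAILABLE = False


def require_beam_decoder():
    """Fail loudly only when beam decoding is actually requested."""
    if not CTC_AVAILABLE:
        raise RuntimeError("beam decoding requires pytorch_ctc; fall back to GreedyDecoder")
    return CTCBD
```
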
generate_lm_trie.py: 2 changes (2 additions, 0 deletions)
@@ -8,6 +8,7 @@
parser.add_argument('--kenlm', help='path to binary kenlm language model', default="lm.kenlm")
parser.add_argument('--trie', help='path of trie to output', default='vocab.trie')


def main():
args = parser.parse_args()
with open(args.labels, "r") as fh:
@@ -17,5 +18,6 @@ def main():

pytorch_ctc.generate_lm_trie(args.dictionary, args.kenlm, args.trie, labels, labels.index('_'), labels.index(' '))


if __name__ == '__main__':
main()
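Note on generate_lm_trie.py: the script loads the label set and asks pytorch_ctc to build a vocabulary trie, keyed by the indices of the blank ('_') and space (' ') labels. A hedged end-to-end sketch with placeholder paths; it assumes the labels file holds a JSON list of characters, as deepspeech.pytorch's labels.json does:

```python
import json
import pytorch_ctc

with open("labels.json", "r") as fh:   # placeholder path
    labels = json.load(fh)             # e.g. ["_", "'", "A", "B", ..., " "]

# Argument order as used in the script: dictionary, KenLM binary,
# trie output path, labels, blank index, space index.
pytorch_ctc.generate_lm_trie("vocab.txt", "lm.kenlm", "vocab.trie",
                             labels, labels.index('_'), labels.index(' '))
```
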
train.py: 25 changes (17 additions, 8 deletions)
@@ -37,9 +37,10 @@
parser.add_argument('--checkpoint', dest='checkpoint', action='store_true', help='Enables checkpoint saving of model')
parser.add_argument('--checkpoint_per_batch', default=0, type=int, help='Save checkpoint per batch. 0 means never save')
parser.add_argument('--visdom', dest='visdom', action='store_true', help='Turn on visdom graphing')
parser.add_argument('--visdom_id', default='Deepspeech training', help='Identifier for visdom graph')
parser.add_argument('--save_folder', default='models/', help='Location to save epoch models')
parser.add_argument('--final_model_path', default='models/deepspeech_final.pth.tar',
help='Location to save final model')
parser.add_argument('--model_path', default='models/deepspeech_final.pth.tar',
help='Location to save best validation model')
parser.add_argument('--continue_from', default='', help='Continue from checkpoint model')
parser.add_argument('--rnn_type', default='lstm', help='Type of the RNN. rnn|gru|lstm are supported')
parser.add_argument('--augment', dest='augment', action='store_true', help='Use random tempo and gain perturbations.')
@@ -57,6 +58,8 @@
help='Turn off bucketing and sample from dataset based on sequence length (smallest to largest)')
parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False, tensorboard=False,
log_params=False, no_bucketing=False)


def to_np(x):
return x.data.cpu().numpy()

@@ -86,13 +89,14 @@ def main():

loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
args.epochs)
best_wer = None
if args.visdom:
from visdom import Visdom
viz = Visdom()

opts = [dict(title='Loss', ylabel='Loss', xlabel='Epoch'),
dict(title='WER', ylabel='WER', xlabel='Epoch'),
dict(title='CER', ylabel='CER', xlabel='Epoch')]
opts = [dict(title=args.visdom_id + ' Loss', ylabel='Loss', xlabel='Epoch'),
dict(title=args.visdom_id + ' WER', ylabel='WER', xlabel='Epoch'),
dict(title=args.visdom_id + ' CER', ylabel='CER', xlabel='Epoch')]

viz_windows = [None, None, None]
epochs = torch.arange(1, args.epochs + 1)
@@ -169,7 +173,7 @@ def main():
start_iter += 1
avg_loss = int(package.get('avg_loss', 0))
loss_results, cer_results, wer_results = package['loss_results'], package[
'cer_results'], package['wer_results']
'cer_results'], package['wer_results']
if args.visdom and \
package['loss_results'] is not None and start_epoch > 0: # Add previous scores to visdom graph
x_axis = epochs[0:start_epoch]
@@ -372,6 +376,13 @@ def main():
optimizer.load_state_dict(optim_state)
print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))

if best_wer is None or best_wer > wer:
    print("Found better validated model, saving to %s" % args.model_path)
    torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
                                    wer_results=wer_results, cer_results=cer_results),
               args.model_path)
    best_wer = wer

avg_loss = 0
if not args.no_bucketing and epoch == 0:
print("Switching to bucketing sampler for following epochs")
@@ -381,8 +392,6 @@ def main():
sampler = BucketingSampler(train_dataset)
train_loader.sampler = sampler

torch.save(DeepSpeech.serialize(model, optimizer=optimizer), args.final_model_path)


if __name__ == '__main__':
main()
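Note on the train.py change: the new bookkeeping is deliberately simple. WER is lower-is-better, best_wer starts as None so the first validated epoch always saves, and each later improvement overwrites the same checkpoint path instead of only saving once at the end of training. A self-contained toy version of the pattern; the class and the scores are illustrative, not from the repository:

```python
class BestMetricTracker:
    """Track a lower-is-better metric such as word error rate."""

    def __init__(self):
        self.best = None

    def improved(self, value):
        if self.best is None or value < self.best:
            self.best = value   # None means no score yet, so the first one wins
            return True
        return False


tracker = BestMetricTracker()
for epoch, wer in enumerate([42.0, 37.5, 39.1, 35.2], start=1):  # toy validation WERs
    if tracker.improved(wer):
        print("epoch %d: new best WER %.1f, saving checkpoint" % (epoch, wer))
```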
