Skip to content

Commit

Permalink
Integrate latest pytorch_ctc library
Browse files — browse the repository at this point in the history
  • Loading branch information
ryanleary committed Jun 28, 2017
1 parent acf8670 commit b474841
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
18 changes: 13 additions & 5 deletions decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@

import Levenshtein as Lev
import torch
from enum import Enum
from six.moves import xrange

try:
from pytorch_ctc import CTCBeamDecoder as CTCBD
from pytorch_ctc import Scorer, KenLMScorer
except ImportError:
print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.")

class Decoder(object):
"""
Expand Down Expand Up @@ -127,20 +132,23 @@ def decode(self, probs, sizes=None):


class BeamCTCDecoder(Decoder):
def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28):
def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, space_index=28):
super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
self._beam_width = beam_width
self._top_n = top_paths
try:
import pytorch_ctc
except ImportError:
raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")
self._ctc = pytorch_ctc

self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width,
blank_index=blank_index, space_index=space_index, merge_repeated=True)


def decode(self, probs, sizes=None):
sizes = sizes.cpu() if sizes is not None else None
out, conf, seq_len = self._ctc.beam_decode(probs.cpu(), sizes, top_paths=self._top_n,
beam_width=self._beam_width, merge_repeated=False)
out, conf, seq_len = self._decoder.decode(probs.cpu(), sizes)

# TODO: support returning multiple paths
strings = self.convert_to_strings(out[0], sizes=seq_len[0])
return self.process_strings(strings)
Expand Down
19 changes: 16 additions & 3 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.autograd import Variable

from data.data_loader import SpectrogramParser
from decoder import GreedyDecoder, BeamCTCDecoder
from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
from model import DeepSpeech

parser = argparse.ArgumentParser(description='DeepSpeech prediction')
Expand All @@ -18,6 +18,11 @@
parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder")
beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
args = parser.parse_args()

if __name__ == '__main__':
Expand All @@ -28,9 +33,17 @@
audio_conf = DeepSpeech.get_audio_conf(model)

if args.decoder == "beam":
decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, blank_index=labels.index('_'))
scorer = None
if args.lm_path is not None:
scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
scorer.set_lm_weight(args.lm_alpha)
scorer.set_word_weight(args.lm_beta1)
scorer.set_valid_word_weight(args.lm_beta2)
else:
scorer = Scorer()
decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
else:
decoder = GreedyDecoder(labels, blank_index=labels.index('_'))
decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))

parser = SpectrogramParser(audio_conf, normalize=True)

Expand Down
19 changes: 16 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.autograd import Variable

from data.data_loader import SpectrogramDataset, AudioDataLoader
from decoder import GreedyDecoder, BeamCTCDecoder
from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
from model import DeepSpeech

parser = argparse.ArgumentParser(description='DeepSpeech prediction')
Expand All @@ -19,6 +19,11 @@
parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder")
beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
args = parser.parse_args()

if __name__ == '__main__':
Expand All @@ -29,9 +34,17 @@
audio_conf = DeepSpeech.get_audio_conf(model)

if args.decoder == "beam":
decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, blank_index=labels.index('_'))
scorer = None
if args.lm_path is not None:
scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
scorer.set_lm_weight(args.lm_alpha)
scorer.set_word_weight(args.lm_beta1)
scorer.set_valid_word_weight(args.lm_beta2)
else:
scorer = Scorer()
decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
else:
decoder = GreedyDecoder(labels, blank_index=labels.index('_'))
decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))

test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.test_manifest, labels=labels,
normalize=True)
Expand Down

0 comments on commit b474841

Please sign in to comment.