From 7837feb54aa1d3bcf8535544401c259725b24b4b Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Mon, 19 Jun 2017 13:31:53 -0400 Subject: [PATCH 01/10] Add softmax during evaluation --- model.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/model.py b/model.py index f3eb2237..9b30dd0a 100644 --- a/model.py +++ b/model.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F supported_rnns = { 'lstm': nn.LSTM, @@ -35,6 +36,12 @@ def __repr__(self): tmpstr += ')' return tmpstr +class BatchSoftmax(nn.Module): + def forward(self, input_): + output_ = input_.transpose(0,1) + batch_size = output_.size()[0] + output_ = torch.stack([F.log_softmax(output_[i]) for i in range(batch_size)], 0) + return output_ class BatchRNN(nn.Module): def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True): @@ -70,6 +77,7 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer self._rnn_type = rnn_type self._audio_conf = audio_conf or {} self._labels = labels + self._softmax = BatchSoftmax() sample_rate = self._audio_conf.get("sample_rate", 16000) window_size = self._audio_conf.get("window_size", 0.02) @@ -116,7 +124,8 @@ def forward(self, x): x = self.rnns(x) x = self.fc(x) - x = x.transpose(0, 1) # Transpose for multi-gpu concat + if not self.training: + x = self._softmax(x) return x @classmethod From b82ff155ae12a7e28598855dd59c895e89cf6e65 Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Mon, 19 Jun 2017 13:32:40 -0400 Subject: [PATCH 02/10] Rename ArgMaxDecoder and add initial BeamCTCDecoder --- decoder.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/decoder.py b/decoder.py index 62d32e33..1ee47cb8 100644 --- a/decoder.py +++ b/decoder.py @@ -42,13 +42,13 @@ def convert_to_strings(self, sequences, sizes=None): """Given a list of numeric sequences, returns the corresponding strings""" strings = [] for x in xrange(len(sequences)): - string = self.convert_to_string(sequences[x]) - string = string[0:int(sizes.data[x])] if sizes is not None else string + seq_len = sizes[x][0] if sizes is not None else len(sequences[x]) + string = self._convert_to_string(sequences[x], seq_len) strings.append(string) return strings - def convert_to_string(self, sequence): - return ''.join([self.int_to_char[i] for i in sequence]) + def _convert_to_string(self, sequence, sizes): + return ''.join([self.int_to_char[sequence[i]] for i in range(sizes)]) def process_strings(self, sequences, remove_repetitions=False): """ @@ -126,7 +126,24 @@ def decode(self, probs, sizes=None): raise NotImplementedError -class ArgMaxDecoder(Decoder): +class BeamCTCDecoder(Decoder): + def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28): + super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index) + self._beam_width = beam_width + self._top_n = top_paths + try: + import pytorch_ctc + except ImportError: + raise ImportError("BeamCTCDecoder requires pytorch_ctc package.") + self._ctc = pytorch_ctc + + def decode(self, probs, sizes=None): + out, conf, seq_len = self._ctc.beam_decode(probs, sizes, top_paths=self._top_n, + beam_width=self._beam_width, merge_repeated=False) + strings = self.convert_to_strings(out[0], sizes=seq_len) + return self.process_strings(strings) + +class GreedyDecoder(Decoder): def decode(self, probs, sizes=None): """ Returns the argmax decoding given the probability matrix. Removes @@ -139,5 +156,6 @@ def decode(self, probs, sizes=None): strings: sequences of the model's best guess for the transcription on inputs """ _, max_probs = torch.max(probs.transpose(0, 1), 2) - strings = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), sizes) + size_data = sizes.data if sizes is not None else None + strings = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), size_data) return self.process_strings(strings, remove_repetitions=True) From e75bf9ef6328f2899cb6885372e6e4dc42b99c80 Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Mon, 19 Jun 2017 15:45:52 -0400 Subject: [PATCH 03/10] Integrate beam search into test/predict scripts --- decoder.py | 8 ++++---- predict.py | 18 ++++++++++++++++-- test.py | 13 ++++++++++--- train.py | 4 ++-- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/decoder.py b/decoder.py index 1ee47cb8..018f821d 100644 --- a/decoder.py +++ b/decoder.py @@ -42,7 +42,7 @@ def convert_to_strings(self, sequences, sizes=None): """Given a list of numeric sequences, returns the corresponding strings""" strings = [] for x in xrange(len(sequences)): - seq_len = sizes[x][0] if sizes is not None else len(sequences[x]) + seq_len = sizes[x] if sizes is not None else len(sequences[x]) string = self._convert_to_string(sequences[x], seq_len) strings.append(string) return strings @@ -140,7 +140,8 @@ def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_inde def decode(self, probs, sizes=None): out, conf, seq_len = self._ctc.beam_decode(probs, sizes, top_paths=self._top_n, beam_width=self._beam_width, merge_repeated=False) - strings = self.convert_to_strings(out[0], sizes=seq_len) + # TODO: support returning multiple paths + strings = self.convert_to_strings(out[0], sizes=seq_len[0]) return self.process_strings(strings) class GreedyDecoder(Decoder): @@ -156,6 +157,5 @@ def decode(self, probs, sizes=None): strings: sequences of the model's best guess for the transcription on inputs """ _, max_probs = torch.max(probs.transpose(0, 1), 2) - size_data = sizes.data if sizes is not None else None - strings = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), size_data) + strings = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), sizes) return self.process_strings(strings, remove_repetitions=True) diff --git a/predict.py b/predict.py index 5af646d9..6063f787 100644 --- a/predict.py +++ b/predict.py @@ -1,10 +1,12 @@ import argparse +import sys +import time import torch from torch.autograd import Variable from data.data_loader import SpectrogramParser -from decoder import ArgMaxDecoder +from decoder import GreedyDecoder, BeamCTCDecoder from model import DeepSpeech parser = argparse.ArgumentParser(description='DeepSpeech prediction') @@ -13,6 +15,9 @@ parser.add_argument('--audio_path', default='audio.wav', help='Audio file to predict on') parser.add_argument('--cuda', action="store_true", help='Use cuda to test model') +parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use") +beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") +beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use') args = parser.parse_args() if __name__ == '__main__': @@ -22,11 +27,20 @@ labels = DeepSpeech.get_labels(model) audio_conf = DeepSpeech.get_audio_conf(model) - decoder = ArgMaxDecoder(labels) + if args.decoder == "beam": + decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, blank_index=labels.index('_')) + else: + decoder = GreedyDecoder(labels, blank_index=labels.index('_')) + parser = SpectrogramParser(audio_conf, normalize=True) + + t0 = time.time() spect = parser.parse_audio(args.audio_path).contiguous() spect = spect.view(1, 1, spect.size(0), spect.size(1)) out = model(Variable(spect, volatile=True)) out = out.transpose(0, 1) # TxNxH decoded_output = decoder.decode(out.data) + t1 = time.time() + print(decoded_output[0]) + print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3)*audio_conf['window_stride'], t1-t0), file=sys.stderr) diff --git a/test.py b/test.py index 28118689..646429bb 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ from torch.autograd import Variable from data.data_loader import SpectrogramDataset, AudioDataLoader -from decoder import ArgMaxDecoder +from decoder import GreedyDecoder, BeamCTCDecoder from model import DeepSpeech parser = argparse.ArgumentParser(description='DeepSpeech prediction') @@ -16,6 +16,9 @@ help='path to validation manifest csv', default='data/test_manifest.csv') parser.add_argument('--batch_size', default=20, type=int, help='Batch size for training') parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading') +parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use") +beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") +beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use') args = parser.parse_args() if __name__ == '__main__': @@ -24,7 +27,11 @@ labels = DeepSpeech.get_labels(model) audio_conf = DeepSpeech.get_audio_conf(model) - decoder = ArgMaxDecoder(labels) + + if args.decoder == "beam": + decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, blank_index=labels.index('_')) + else: + decoder = GreedyDecoder(labels, blank_index=labels.index('_')) test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.test_manifest, labels=labels, normalize=True) @@ -49,7 +56,7 @@ out = model(inputs) out = out.transpose(0, 1) # TxNxH seq_length = out.size(0) - sizes = Variable(input_percentages.mul_(int(seq_length)).int(), volatile=True) + sizes = input_percentages.mul_(int(seq_length)).int() decoded_output = decoder.decode(out.data, sizes) target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets)) diff --git a/train.py b/train.py index f6b1ab86..a4f86324 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,7 @@ from data.bucketing_sampler import BucketingSampler, SpectrogramDatasetWithLength from data.data_loader import AudioDataLoader, SpectrogramDataset -from decoder import ArgMaxDecoder +from decoder import GreedyDecoder from model import DeepSpeech, supported_rnns parser = argparse.ArgumentParser(description='DeepSpeech training') @@ -304,7 +304,7 @@ def main(): out = model(inputs) out = out.transpose(0, 1) # TxNxH seq_length = out.size(0) - sizes = Variable(input_percentages.mul_(int(seq_length)).int(), volatile=True) + sizes = input_percentages.mul_(int(seq_length)).int() decoded_output = decoder.decode(out.data, sizes) target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets)) From 3eac8cb1f90fe5a92556406e922801b1ad2d9a01 Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Tue, 20 Jun 2017 01:31:19 -0400 Subject: [PATCH 04/10] Fix minor bugs --- decoder.py | 2 +- train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/decoder.py b/decoder.py index 018f821d..a516597b 100644 --- a/decoder.py +++ b/decoder.py @@ -138,7 +138,7 @@ def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_inde self._ctc = pytorch_ctc def decode(self, probs, sizes=None): - out, conf, seq_len = self._ctc.beam_decode(probs, sizes, top_paths=self._top_n, + out, conf, seq_len = self._ctc.beam_decode(probs.cpu(), sizes.cpu(), top_paths=self._top_n, beam_width=self._beam_width, merge_repeated=False) # TODO: support returning multiple paths strings = self.convert_to_strings(out[0], sizes=seq_len[0]) diff --git a/train.py b/train.py index a4f86324..5015bb6d 100644 --- a/train.py +++ b/train.py @@ -153,7 +153,7 @@ def main(): parameters = model.parameters() optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, nesterov=True) - decoder = ArgMaxDecoder(labels) + decoder = GreedyDecoder(labels) if args.continue_from: print("Loading checkpoint model %s" % args.continue_from) From 5c19b985d3959580b98a62e0bb52235b5468178c Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Tue, 20 Jun 2017 22:30:31 -0400 Subject: [PATCH 05/10] Add alternative softmax implementation --- model.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/model.py b/model.py index 9b30dd0a..e36fc055 100644 --- a/model.py +++ b/model.py @@ -36,12 +36,15 @@ def __repr__(self): tmpstr += ')' return tmpstr -class BatchSoftmax(nn.Module): + +class InferenceBatchSoftmax(nn.Module): def forward(self, input_): - output_ = input_.transpose(0,1) - batch_size = output_.size()[0] - output_ = torch.stack([F.log_softmax(output_[i]) for i in range(batch_size)], 0) - return output_ + if not self.training: + batch_size = input_.size()[0] + return torch.stack([F.log_softmax(input_[i]) for i in range(batch_size)], 0) + else: + return input_ + class BatchRNN(nn.Module): def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True): @@ -77,7 +80,6 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer self._rnn_type = rnn_type self._audio_conf = audio_conf or {} self._labels = labels - self._softmax = BatchSoftmax() sample_rate = self._audio_conf.get("sample_rate", 16000) window_size = self._audio_conf.get("window_size", 0.02) @@ -113,6 +115,7 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer self.fc = nn.Sequential( SequenceWise(fully_connected), ) + self.softmax = InferenceBatchSoftmax() def forward(self, x): x = self.conv(x) @@ -124,8 +127,8 @@ def forward(self, x): x = self.rnns(x) x = self.fc(x) - if not self.training: - x = self._softmax(x) + x = x.transpose(0, 1) + x = self.softmax(x) return x @classmethod From acf867009a322767d3b320345da638ed7ed95c3a Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Tue, 20 Jun 2017 22:34:21 -0400 Subject: [PATCH 06/10] Add decoder changes --- decoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/decoder.py b/decoder.py index a516597b..81f31a62 100644 --- a/decoder.py +++ b/decoder.py @@ -138,7 +138,8 @@ def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_inde self._ctc = pytorch_ctc def decode(self, probs, sizes=None): - out, conf, seq_len = self._ctc.beam_decode(probs.cpu(), sizes.cpu(), top_paths=self._top_n, + sizes = sizes.cpu() if sizes is not None else None + out, conf, seq_len = self._ctc.beam_decode(probs.cpu(), sizes, top_paths=self._top_n, beam_width=self._beam_width, merge_repeated=False) # TODO: support returning multiple paths strings = self.convert_to_strings(out[0], sizes=seq_len[0]) From b474841757a0e388a22f453886febd1b8089a350 Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Sat, 24 Jun 2017 23:19:06 -0400 Subject: [PATCH 07/10] Integrate latest pytorch_ctc library --- decoder.py | 18 +++++++++++++----- predict.py | 19 ++++++++++++++++--- test.py | 19 ++++++++++++++++--- 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/decoder.py b/decoder.py index 81f31a62..b6c548b0 100644 --- a/decoder.py +++ b/decoder.py @@ -17,8 +17,13 @@ import Levenshtein as Lev import torch +from enum import Enum from six.moves import xrange - +try: + from pytorch_ctc import CTCBeamDecoder as CTCBD + from pytorch_ctc import Scorer, KenLMScorer +except ImportError: + print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.") class Decoder(object): """ @@ -127,7 +132,7 @@ def decode(self, probs, sizes=None): class BeamCTCDecoder(Decoder): - def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28): + def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, space_index=28): super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index) self._beam_width = beam_width self._top_n = top_paths @@ -135,12 +140,15 @@ def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_inde import pytorch_ctc except ImportError: raise ImportError("BeamCTCDecoder requires pytorch_ctc package.") - self._ctc = pytorch_ctc + + self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width, + blank_index=blank_index, space_index=space_index, merge_repeated=True) + def decode(self, probs, sizes=None): sizes = sizes.cpu() if sizes is not None else None - out, conf, seq_len = self._ctc.beam_decode(probs.cpu(), sizes, top_paths=self._top_n, - beam_width=self._beam_width, merge_repeated=False) + out, conf, seq_len = self._decoder.decode(probs.cpu(), sizes) + # TODO: support returning multiple paths strings = self.convert_to_strings(out[0], sizes=seq_len[0]) return self.process_strings(strings) diff --git a/predict.py b/predict.py index 6063f787..56c23760 100644 --- a/predict.py +++ b/predict.py @@ -6,7 +6,7 @@ from torch.autograd import Variable from data.data_loader import SpectrogramParser -from decoder import GreedyDecoder, BeamCTCDecoder +from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer from model import DeepSpeech parser = argparse.ArgumentParser(description='DeepSpeech prediction') @@ -18,6 +18,11 @@ parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use") beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use') +beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') +beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') +beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight') +beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)') +beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)') args = parser.parse_args() if __name__ == '__main__': @@ -28,9 +33,17 @@ audio_conf = DeepSpeech.get_audio_conf(model) if args.decoder == "beam": - decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, blank_index=labels.index('_')) + scorer = None + if args.lm_path is not None: + scorer = KenLMScorer(labels, args.lm_path, args.trie_path) + scorer.set_lm_weight(args.lm_alpha) + scorer.set_word_weight(args.lm_beta1) + scorer.set_valid_word_weight(args.lm_beta2) + else: + scorer = Scorer() + decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_')) else: - decoder = GreedyDecoder(labels, blank_index=labels.index('_')) + decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_')) parser = SpectrogramParser(audio_conf, normalize=True) diff --git a/test.py b/test.py index 646429bb..ec52f01f 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ from torch.autograd import Variable from data.data_loader import SpectrogramDataset, AudioDataLoader -from decoder import GreedyDecoder, BeamCTCDecoder +from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer from model import DeepSpeech parser = argparse.ArgumentParser(description='DeepSpeech prediction') @@ -19,6 +19,11 @@ parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use") beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use') +beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') +beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') +beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight') +beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)') +beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)') args = parser.parse_args() if __name__ == '__main__': @@ -29,9 +34,17 @@ audio_conf = DeepSpeech.get_audio_conf(model) if args.decoder == "beam": - decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, blank_index=labels.index('_')) + scorer = None + if args.lm_path is not None: + scorer = KenLMScorer(labels, args.lm_path, args.trie_path) + scorer.set_lm_weight(args.lm_alpha) + scorer.set_word_weight(args.lm_beta1) + scorer.set_valid_word_weight(args.lm_beta2) + else: + scorer = Scorer() + decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_')) else: - decoder = GreedyDecoder(labels, blank_index=labels.index('_')) + decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_')) test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.test_manifest, labels=labels, normalize=True) From 4848334e1b0f656f2fb47c73d02982e2f45b5498 Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Sun, 25 Jun 2017 20:02:39 -0400 Subject: [PATCH 08/10] Change merge parameter --- decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decoder.py b/decoder.py index b6c548b0..8b894ac8 100644 --- a/decoder.py +++ b/decoder.py @@ -142,7 +142,7 @@ def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, sp raise ImportError("BeamCTCDecoder requires pytorch_ctc package.") self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width, - blank_index=blank_index, space_index=space_index, merge_repeated=True) + blank_index=blank_index, space_index=space_index, merge_repeated=False) def decode(self, probs, sizes=None): From 7f0914167c438b51245e6a55333e5eaa7994b27c Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Mon, 26 Jun 2017 11:42:24 -0400 Subject: [PATCH 09/10] Add utility script for creating LM trie --- generate_lm_trie.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 generate_lm_trie.py diff --git a/generate_lm_trie.py b/generate_lm_trie.py new file mode 100644 index 00000000..9f9d41b6 --- /dev/null +++ b/generate_lm_trie.py @@ -0,0 +1,21 @@ +import pytorch_ctc +import json +import argparse + +parser = argparse.ArgumentParser(description='LM Trie Generation') +parser.add_argument('--labels', help='path to label json file', default='labels.json') +parser.add_argument('--dictionary', help='path to text dictionary (one word per line)', default='vocab.txt') +parser.add_argument('--kenlm', help='path to binary kenlm language model', default="lm.kenlm") +parser.add_argument('--trie', help='path of trie to output', default='vocab.trie') + +def main(): + args = parser.parse_args() + with open(args.labels, "r") as fh: + label_data = json.load(fh) + + labels = ''.join(label_data) + + pytorch_ctc.generate_lm_trie(args.dictionary, args.kenlm, args.trie, labels, labels.index('_'), labels.index(' ')) + +if __name__ == '__main__': + main() From 617edf3619343dd6474b8045bd559c04f68979aa Mon Sep 17 00:00:00 2001 From: Ryan Leary Date: Mon, 26 Jun 2017 13:14:52 -0400 Subject: [PATCH 10/10] Update README --- README.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/README.md b/README.md index ad8f22e3..e5fa924d 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,16 @@ cd audio python setup.py install ``` +If you want decoding to support beam search with an optional language model, install pytorch-ctc: +``` +git clone --recursive https://github.com/ryanleary/pytorch-ctc.git +cd pytorch-ctc +pip install -r requirements.txt + +# build the extension and install python package (requires gcc-5 or later) +CC=/path/to/gcc-5 CXX=/path/to/g++-5 python setup.py install +``` + Finally: ``` pip install -r requirements.txt @@ -220,6 +230,17 @@ An example script to output a prediction has been provided: python predict.py --model_path models/deepspeech.pth.tar --audio_path /path/to/audio.wav ``` +### Alternate Decoders +By default, `test.py` and `predict.py` use a `GreedyDecoder` which picks the highest-likelihood output label at each timestep. Repeated and blank symbols are then filtered to give the final output. + +A beam search decoder can optionally be used with the installation of the `pytorch-ctc` library as described in the Installation section. The `test` and `predict` scripts have a `--decoder` argument. To use the beam decoder, add `--decoder beam`. The beam decoder enables additional decoding parameters: +- **beam_width** how many beams to consider at each timestep +- **lm_path** optional binary KenLM language model to use for decoding +- **trie_path** trie describing lexicon. required if `lm_path` is supplied +- **lm_alpha** weight for language model +- **lm_beta1** bonus weight for words +- **lm_beta2** bonus weight for in-vocabulary words + ## Acknowledgements Thanks to [Egor](https://github.com/EgorLakomkin) and [Ryan](https://github.com/ryanleary) for their contributions!