diff --git a/README.md b/README.md
index ad8f22e3..e5fa924d 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,16 @@
 cd audio
 python setup.py install
 ```
+If you want to use beam search decoding with an optional language model, install pytorch-ctc:
+```
+git clone --recursive https://github.com/ryanleary/pytorch-ctc.git
+cd pytorch-ctc
+pip install -r requirements.txt
+
+# build the extension and install the python package (requires gcc-5 or later)
+CC=/path/to/gcc-5 CXX=/path/to/g++-5 python setup.py install
+```
+
 Finally:
 ```
 pip install -r requirements.txt
@@ -220,6 +230,17 @@ An example script to output a prediction has been provided:
 python predict.py --model_path models/deepspeech.pth.tar --audio_path /path/to/audio.wav
 ```
 
+### Alternate Decoders
+By default, `test.py` and `predict.py` use a `GreedyDecoder`, which picks the highest-likelihood output label at each timestep; repeated and blank symbols are then removed to give the final output.
+
+A beam search decoder can be used instead by installing the `pytorch-ctc` library as described in the Installation section. Both `test.py` and `predict.py` accept a `--decoder` argument; pass `--decoder beam` to enable it. The beam decoder exposes the following additional parameters:
+- **beam_width**: number of beams to consider at each timestep
+- **lm_path**: optional binary KenLM language model to use for decoding
+- **trie_path**: trie describing the lexicon; required if `lm_path` is supplied
+- **lm_alpha**: weight applied to the language model score
+- **lm_beta1**: bonus weight for words
+- **lm_beta2**: bonus weight for in-vocabulary words
+
 ## Acknowledgements
 
 Thanks to [Egor](https://github.com/EgorLakomkin) and [Ryan](https://github.com/ryanleary) for their contributions!
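For orientation, a beam-search run of the prediction script wired up by this patch might look like the following. This is an illustrative sketch, not part of the patch: the model and audio paths mirror the README example above, `lm.kenlm` and `vocab.trie` are placeholder names matching the defaults of `generate_lm_trie.py` further down, and the numeric values are simply the argparse defaults introduced in `predict.py` below.

```
python predict.py --model_path models/deepspeech.pth.tar --audio_path /path/to/audio.wav \
    --decoder beam --beam_width 10 --lm_path lm.kenlm --trie_path vocab.trie \
    --lm_alpha 0.8 --lm_beta1 1 --lm_beta2 1
```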
diff --git a/decoder.py b/decoder.py
index 62d32e33..8b894ac8 100644
--- a/decoder.py
+++ b/decoder.py
@@ -17,8 +17,13 @@
 import Levenshtein as Lev
 import torch
+from enum import Enum
 from six.moves import xrange
-
+try:
+    from pytorch_ctc import CTCBeamDecoder as CTCBD
+    from pytorch_ctc import Scorer, KenLMScorer
+except ImportError:
+    print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.")
 
 
 class Decoder(object):
     """
@@ -42,13 +47,13 @@ def convert_to_strings(self, sequences, sizes=None):
         """Given a list of numeric sequences, returns the corresponding strings"""
         strings = []
         for x in xrange(len(sequences)):
-            string = self.convert_to_string(sequences[x])
-            string = string[0:int(sizes.data[x])] if sizes is not None else string
+            seq_len = sizes[x] if sizes is not None else len(sequences[x])
+            string = self._convert_to_string(sequences[x], seq_len)
             strings.append(string)
         return strings
 
-    def convert_to_string(self, sequence):
-        return ''.join([self.int_to_char[i] for i in sequence])
+    def _convert_to_string(self, sequence, sizes):
+        return ''.join([self.int_to_char[sequence[i]] for i in range(sizes)])
 
     def process_strings(self, sequences, remove_repetitions=False):
         """
@@ -126,7 +131,29 @@ def decode(self, probs, sizes=None):
         raise NotImplementedError
 
 
-class ArgMaxDecoder(Decoder):
+class BeamCTCDecoder(Decoder):
+    def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, space_index=28):
+        super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
+        self._beam_width = beam_width
+        self._top_n = top_paths
+        try:
+            import pytorch_ctc
+        except ImportError:
+            raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")
+
+        self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width,
+                              blank_index=blank_index, space_index=space_index, merge_repeated=False)
+
+
+    def decode(self, probs, sizes=None):
+        sizes = sizes.cpu() if sizes is not None else None
+        out, conf, seq_len = self._decoder.decode(probs.cpu(), sizes)
+
+        # TODO: support returning multiple paths
+        strings = self.convert_to_strings(out[0], sizes=seq_len[0])
+        return self.process_strings(strings)
+
+class GreedyDecoder(Decoder):
     def decode(self, probs, sizes=None):
         """
         Returns the argmax decoding given the probability matrix. Removes
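To make the decoder.py changes above easier to follow: both decoders share the `decode(probs, sizes=None)` interface, where `probs` is a TxNxH tensor of model outputs and the return value is a list of transcripts. Below is a minimal, hedged sketch of how they are constructed and called; the label string is only an assumed example (with `_` as the CTC blank and a trailing space as the word separator), and the beam path requires the optional `pytorch_ctc` package to have imported successfully.

```
import torch

from decoder import GreedyDecoder, BeamCTCDecoder, Scorer

# Assumed label set: '_' is the CTC blank, the trailing space separates words.
labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "

# Fake TxNxH output (50 timesteps, batch of 1) standing in for model(x).data.
probs = torch.randn(50, 1, len(labels))

greedy = GreedyDecoder(labels, blank_index=labels.index('_'), space_index=labels.index(' '))
print(greedy.decode(probs))

# Beam search without a language model scores beams with the plain Scorer;
# this branch only works when pytorch_ctc is installed.
beam = BeamCTCDecoder(labels, Scorer(), beam_width=10, top_paths=1,
                      blank_index=labels.index('_'), space_index=labels.index(' '))
print(beam.decode(probs))
```

With a KenLM model available, `Scorer()` would be replaced by a configured `KenLMScorer`, exactly as `predict.py` and `test.py` do later in this patch.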
diff --git a/generate_lm_trie.py b/generate_lm_trie.py
new file mode 100644
index 00000000..9f9d41b6
--- /dev/null
+++ b/generate_lm_trie.py
@@ -0,0 +1,21 @@
+import pytorch_ctc
+import json
+import argparse
+
+parser = argparse.ArgumentParser(description='LM Trie Generation')
+parser.add_argument('--labels', help='path to label json file', default='labels.json')
+parser.add_argument('--dictionary', help='path to text dictionary (one word per line)', default='vocab.txt')
+parser.add_argument('--kenlm', help='path to binary kenlm language model', default="lm.kenlm")
+parser.add_argument('--trie', help='path of trie to output', default='vocab.trie')
+
+def main():
+    args = parser.parse_args()
+    with open(args.labels, "r") as fh:
+        label_data = json.load(fh)
+
+    labels = ''.join(label_data)
+
+    pytorch_ctc.generate_lm_trie(args.dictionary, args.kenlm, args.trie, labels, labels.index('_'), labels.index(' '))
+
+if __name__ == '__main__':
+    main()
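A typical invocation of the new script might look like the following (illustrative; the file names are simply the argparse defaults above). The script reads the label file, a one-word-per-line dictionary, and a pre-built binary KenLM model, and writes the trie that the `--trie_path` option of `test.py` and `predict.py` expects.

```
python generate_lm_trie.py --labels labels.json --dictionary vocab.txt --kenlm lm.kenlm --trie vocab.trie
```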
diff --git a/model.py b/model.py
index f3eb2237..e36fc055 100644
--- a/model.py
+++ b/model.py
@@ -3,6 +3,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 supported_rnns = {
     'lstm': nn.LSTM,
@@ -36,6 +37,15 @@ def __repr__(self):
         return tmpstr
 
 
+class InferenceBatchSoftmax(nn.Module):
+    def forward(self, input_):
+        if not self.training:
+            batch_size = input_.size()[0]
+            return torch.stack([F.log_softmax(input_[i]) for i in range(batch_size)], 0)
+        else:
+            return input_
+
+
 class BatchRNN(nn.Module):
     def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):
         super(BatchRNN, self).__init__()
@@ -105,6 +115,7 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer
         self.fc = nn.Sequential(
             SequenceWise(fully_connected),
         )
+        self.softmax = InferenceBatchSoftmax()
 
     def forward(self, x):
         x = self.conv(x)
@@ -116,7 +127,8 @@ def forward(self, x):
         x = self.rnns(x)
 
         x = self.fc(x)
-        x = x.transpose(0, 1)  # Transpose for multi-gpu concat
+        x = x.transpose(0, 1)
+        x = self.softmax(x)
         return x
 
     @classmethod
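A note on the `InferenceBatchSoftmax` module added above: in training mode it passes activations through unchanged, while in eval mode it applies a per-sample log-softmax over the label dimension. The following is a small standalone sketch of that behavior, not part of the patch; the tensor shape is the NxTxH layout produced inside `DeepSpeech.forward`, and the sizes are arbitrary.

```
import torch
from torch.autograd import Variable

from model import InferenceBatchSoftmax

m = InferenceBatchSoftmax()
x = Variable(torch.randn(4, 20, 29))  # N=4 utterances, T=20 steps, H=29 labels

m.train()
print(m(x) is x)         # True: training mode is a pass-through

m.eval()
out = m(x)               # log-softmax applied independently to each sample
print(out.exp().sum(2))  # every (n, t) row now sums to ~1
```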
diff --git a/predict.py b/predict.py
index 5af646d9..56c23760 100644
--- a/predict.py
+++ b/predict.py
@@ -1,10 +1,12 @@
 import argparse
+import sys
+import time
 
 import torch
 from torch.autograd import Variable
 
 from data.data_loader import SpectrogramParser
-from decoder import ArgMaxDecoder
+from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
 from model import DeepSpeech
 
 parser = argparse.ArgumentParser(description='DeepSpeech prediction')
@@ -13,6 +15,14 @@
 parser.add_argument('--audio_path', default='audio.wav',
                     help='Audio file to predict on')
 parser.add_argument('--cuda', action="store_true", help='Use cuda to test model')
+parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
+beam_args = parser.add_argument_group("Beam Decode Options", "Configuration options for the CTC Beam Search decoder")
+beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
+beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
+beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
+beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
+beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
+beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
 args = parser.parse_args()
 
 if __name__ == '__main__':
@@ -22,11 +32,28 @@
 
     labels = DeepSpeech.get_labels(model)
     audio_conf = DeepSpeech.get_audio_conf(model)
-    decoder = ArgMaxDecoder(labels)
+    if args.decoder == "beam":
+        scorer = None
+        if args.lm_path is not None:
+            scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
+            scorer.set_lm_weight(args.lm_alpha)
+            scorer.set_word_weight(args.lm_beta1)
+            scorer.set_valid_word_weight(args.lm_beta2)
+        else:
+            scorer = Scorer()
+        decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1,
+                                 space_index=labels.index(' '), blank_index=labels.index('_'))
+    else:
+        decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))
+
     parser = SpectrogramParser(audio_conf, normalize=True)
+
+    t0 = time.time()
     spect = parser.parse_audio(args.audio_path).contiguous()
     spect = spect.view(1, 1, spect.size(0), spect.size(1))
     out = model(Variable(spect, volatile=True))
     out = out.transpose(0, 1)  # TxNxH
     decoded_output = decoder.decode(out.data)
-    print(decoded_output[0])
+    t1 = time.time()
+    print(decoded_output[0])
+    print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3)*audio_conf['window_stride'], t1-t0), file=sys.stderr)
diff --git a/test.py b/test.py
index 28118689..ec52f01f 100644
--- a/test.py
+++ b/test.py
@@ -5,7 +5,7 @@
 from torch.autograd import Variable
 
 from data.data_loader import SpectrogramDataset, AudioDataLoader
-from decoder import ArgMaxDecoder
+from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
 from model import DeepSpeech
 
 parser = argparse.ArgumentParser(description='DeepSpeech prediction')
@@ -16,6 +16,14 @@
                     help='path to validation manifest csv', default='data/test_manifest.csv')
 parser.add_argument('--batch_size', default=20, type=int, help='Batch size for training')
 parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
+parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
+beam_args = parser.add_argument_group("Beam Decode Options", "Configuration options for the CTC Beam Search decoder")
+beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
+beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
+beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
+beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
+beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
+beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
 args = parser.parse_args()
 
 if __name__ == '__main__':
@@ -24,7 +32,19 @@
 
     labels = DeepSpeech.get_labels(model)
     audio_conf = DeepSpeech.get_audio_conf(model)
-    decoder = ArgMaxDecoder(labels)
+
+    if args.decoder == "beam":
+        scorer = None
+        if args.lm_path is not None:
+            scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
+            scorer.set_lm_weight(args.lm_alpha)
+            scorer.set_word_weight(args.lm_beta1)
+            scorer.set_valid_word_weight(args.lm_beta2)
+        else:
+            scorer = Scorer()
+        decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
+    else:
+        decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))
 
     test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.test_manifest,
                                       labels=labels, normalize=True)
@@ -49,7 +69,7 @@
             out = model(inputs)
             out = out.transpose(0, 1)  # TxNxH
             seq_length = out.size(0)
-            sizes = Variable(input_percentages.mul_(int(seq_length)).int(), volatile=True)
+            sizes = input_percentages.mul_(int(seq_length)).int()
 
             decoded_output = decoder.decode(out.data, sizes)
             target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets))
diff --git a/train.py b/train.py
index f6b1ab86..5015bb6d 100644
--- a/train.py
+++ b/train.py
@@ -10,7 +10,7 @@
 
 from data.bucketing_sampler import BucketingSampler, SpectrogramDatasetWithLength
 from data.data_loader import AudioDataLoader, SpectrogramDataset
-from decoder import ArgMaxDecoder
+from decoder import GreedyDecoder
 from model import DeepSpeech, supported_rnns
 
 parser = argparse.ArgumentParser(description='DeepSpeech training')
@@ -153,7 +153,7 @@ def main():
     parameters = model.parameters()
     optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                                 nesterov=True)
-    decoder = ArgMaxDecoder(labels)
+    decoder = GreedyDecoder(labels)
 
     if args.continue_from:
         print("Loading checkpoint model %s" % args.continue_from)
@@ -304,7 +304,7 @@ def main():
             out = model(inputs)
             out = out.transpose(0, 1)  # TxNxH
             seq_length = out.size(0)
-            sizes = Variable(input_percentages.mul_(int(seq_length)).int(), volatile=True)
+            sizes = input_percentages.mul_(int(seq_length)).int()
 
             decoded_output = decoder.decode(out.data, sizes)
             target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets))
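One last note, on the timing line added to `predict.py`: the audio duration is estimated from the spectrogram width, since each frame advances by `window_stride` seconds. A rough illustration (the 0.01 s stride and the frame count are assumed values for the example, not something this patch sets):

```
window_stride = 0.01   # assumed hop between spectrogram frames, in seconds
num_frames = 1000      # stands in for spect.size(3)
print("{0:.2f} seconds of audio".format(num_frames * window_stride))  # -> 10.00
```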