Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add beam search #95

Merged
merged 10 commits into from
Jun 29, 2017
Prev Previous commit
Next Next commit
Rename ArgMaxDecoder and add initial BeamCTCDecoder
  • Loading branch information
ryanleary committed Jun 28, 2017
commit b82ff155ae12a7e28598855dd59c895e89cf6e65
30 changes: 24 additions & 6 deletions decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def convert_to_strings(self, sequences, sizes=None):
"""Given a list of numeric sequences, returns the corresponding strings"""
strings = []
for x in xrange(len(sequences)):
string = self.convert_to_string(sequences[x])
string = string[0:int(sizes.data[x])] if sizes is not None else string
seq_len = sizes[x][0] if sizes is not None else len(sequences[x])
string = self._convert_to_string(sequences[x], seq_len)
strings.append(string)
return strings

def convert_to_string(self, sequence):
return ''.join([self.int_to_char[i] for i in sequence])
def _convert_to_string(self, sequence, sizes):
return ''.join([self.int_to_char[sequence[i]] for i in range(sizes)])

def process_strings(self, sequences, remove_repetitions=False):
"""
Expand Down Expand Up @@ -126,7 +126,24 @@ def decode(self, probs, sizes=None):
raise NotImplementedError


class ArgMaxDecoder(Decoder):
class BeamCTCDecoder(Decoder):
def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28):
super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
self._beam_width = beam_width
self._top_n = top_paths
try:
import pytorch_ctc
except ImportError:
raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")
self._ctc = pytorch_ctc

def decode(self, probs, sizes=None):
out, conf, seq_len = self._ctc.beam_decode(probs, sizes, top_paths=self._top_n,
beam_width=self._beam_width, merge_repeated=False)
strings = self.convert_to_strings(out[0], sizes=seq_len)
return self.process_strings(strings)

class GreedyDecoder(Decoder):
def decode(self, probs, sizes=None):
"""
Returns the argmax decoding given the probability matrix. Removes
Expand All @@ -139,5 +156,6 @@ def decode(self, probs, sizes=None):
strings: sequences of the model's best guess for the transcription on inputs
"""
_, max_probs = torch.max(probs.transpose(0, 1), 2)
strings = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), sizes)
size_data = sizes.data if sizes is not None else None
strings = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), size_data)
return self.process_strings(strings, remove_repetitions=True)