
Fix infinite loop caused by incorrect timestamp tokens prediction #914

Merged 2 commits on Feb 1, 2023
whisper/decoding.py (7 additions, 1 deletion)

@@ -412,7 +412,8 @@ def apply(self, logits: Tensor, tokens: Tensor):

     # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
     for k in range(tokens.shape[0]):
-        seq = [t for t in tokens[k, self.sample_begin :].tolist()]
+        sampled_tokens = tokens[k, self.sample_begin :]
+        seq = [t for t in sampled_tokens.tolist()]
         last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
         penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

@@ -422,6 +423,11 @@ def apply(self, logits: Tensor, tokens: Tensor):
             else:  # cannot be normal text tokens
                 logits[k, : self.tokenizer.eot] = -np.inf

+        timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+        if timestamps.numel() > 0:
+            # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
+            logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
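
For readers of this thread, here is a minimal standalone sketch (not part of the PR) of what the added mask does. The vocabulary size, sample_begin, and token indices are toy assumptions, with index 100 standing in for <|0.00|>:

    # Toy sketch (not from the PR): all indices are assumptions, 100 stands in for <|0.00|>
    import numpy as np
    import torch

    timestamp_begin = 100                   # hypothetical index of <|0.00|>
    sample_begin = 0
    tokens = torch.tensor([[100, 7, 108]])  # <|0.00|>, a text token, <|0.40|>
    logits = torch.zeros(1, 120)            # toy logits over a 120-token vocabulary

    for k in range(tokens.shape[0]):
        sampled_tokens = tokens[k, sample_begin:]
        timestamps = sampled_tokens[sampled_tokens.ge(timestamp_begin)]
        if timestamps.numel() > 0:
            # forbid timestamp tokens smaller than the last one seen (108)
            logits[k, timestamp_begin : timestamps[-1]] = -np.inf

    print(torch.isinf(logits[0, timestamp_begin:]).nonzero().flatten())
    # tensor([0, 1, ..., 7]): timestamps 100..107 are masked, 108 itself stays allowed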


This is not enough to prevent the infinite loop (see discussion #924), because it does not prevent the model from always outputting <|0.00|>: when the last timestamp emitted is <|0.00|> itself, the slice logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] is empty and masks nothing.

Suggestion:

                timestamp_last = max(timestamps[-1], self.tokenizer.timestamp_begin + 1)  # avoid emitting <|0.00|> again
                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
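
To see both the failure and the proposed fix concretely, here is a toy check (again not from the PR; index 100 stands in for <|0.00|> as in the sketch above):

    # Toy check of the edge case (not from the PR): 100 stands in for <|0.00|>
    import numpy as np
    import torch

    timestamp_begin = 100
    tokens = torch.tensor([100])   # the only timestamp emitted so far is <|0.00|> itself
    logits = torch.zeros(120)
    timestamps = tokens[tokens.ge(timestamp_begin)]

    # merged PR code: the slice 100:100 is empty, so nothing gets masked
    logits[timestamp_begin : timestamps[-1]] = -np.inf
    print(torch.isinf(logits).any())             # tensor(False) -> <|0.00|> can repeat forever

    # suggested fix: clamp the upper bound so <|0.00|> itself is always forbidden again
    timestamp_last = max(timestamps[-1], timestamp_begin + 1)
    logits[timestamp_begin : timestamp_last] = -np.inf
    print(torch.isinf(logits[timestamp_begin]))  # tensor(True) -> <|0.00|> is now masked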


I think a better suggestion is:

                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1
                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

This forces timestamps to be strictly increasing within a speech segment, while still allowing the timestamp that starts the next segment to equal the one that ended the previous segment.
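
A hypothetical helper (not from the PR, same assumed indices as the sketches above) showing how the refined condition changes the lowest timestamp token the mask would still allow:

    # Hypothetical helper illustrating the refined rule; 100 stands in for <|0.00|>
    import torch

    timestamp_begin = 100

    def first_allowed_timestamp(seq):
        # returns the lowest timestamp token the refined mask would leave unmasked
        sampled_tokens = torch.tensor(seq)
        last_was_timestamp = len(seq) >= 1 and seq[-1] >= timestamp_begin
        penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= timestamp_begin
        timestamps = sampled_tokens[sampled_tokens.ge(timestamp_begin)]
        if last_was_timestamp and not penultimate_was_timestamp:
            return int(timestamps[-1])      # a segment just closed: its end time may repeat as the next start
        return int(timestamps[-1]) + 1      # otherwise the next timestamp must be strictly larger

    print(first_allowed_timestamp([100, 7, 108]))  # 108: the next segment may start at <|0.40|> again
    print(first_allowed_timestamp([108, 7]))       # 109: a segment that started at <|0.40|> must end strictly later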

Is anyone looking at this?


Great solution @Jeronymous! I checked it and it works.

With your permission, I'm going to create a new PR to get this change merged faster, and I'll mention your suggestion.


     if tokens.shape[1] == self.sample_begin:
         # suppress generating non-timestamp tokens at the beginning
         logits[:, : self.tokenizer.timestamp_begin] = -np.inf