Skip to content

Commit

Permalink
Fix infinite loop caused by incorrect timestamp tokens prediction (openai#914)
Browse files Browse the repository at this point in the history

* Fix infinite loop caused by incorrect timestamp tokens prediction

openai#810

* Update decoding.py

---------

Co-authored-by: Jong Wook Kim <[email protected]>
  • Loading branch information
2 people authored and ilanit1997 committed May 16, 2023
1 parent 387da57 commit 1d56e13
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ def apply(self, logits: Tensor, tokens: Tensor):

# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

Expand All @@ -426,6 +427,11 @@ def apply(self, logits: Tensor, tokens: Tensor):
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf

timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
if timestamps.numel() > 0:
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf

if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
Expand Down

0 comments on commit 1d56e13

Please sign in to comment.