Decoding improvements #1033

Merged · 2 commits · Mar 6, 2023
8 changes: 7 additions & 1 deletion whisper/decoding.py
@@ -549,7 +549,13 @@ def _get_suppress_tokens(self) -> Tuple[int]:
         assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

         suppress_tokens.extend(
-            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
+            [
+                self.tokenizer.transcribe,
+                self.tokenizer.translate,
+                self.tokenizer.sot,
+                self.tokenizer.sot_prev,
+                self.tokenizer.sot_lm
+            ]
         )
         if self.tokenizer.no_speech is not None:
             # no-speech probability is collected separately
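
Note: adding <|transcribe|> and <|translate|> here means the decoder can no longer re-emit the task tokens mid-text. The suppression itself takes effect at sampling time by masking the corresponding logits; the sketch below shows that mechanism for a plain (batch, vocab) logits tensor. Whisper wires this up through a logit-filter class in decoding.py, so this standalone function is illustrative only, not the PR's code:

import torch

def suppress_token_logits(logits: torch.Tensor, suppress_tokens: list) -> torch.Tensor:
    # a -inf logit becomes zero probability after softmax, so a suppressed
    # token can never be sampled or win a greedy argmax
    logits[:, suppress_tokens] = float("-inf")
    return logits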
8 changes: 8 additions & 0 deletions whisper/tokenizer.py
@@ -160,6 +160,14 @@ def decode_with_timestamps(self, tokens) -> str:
     def eot(self) -> int:
         return self.tokenizer.eos_token_id

+    @cached_property
+    def transcribe(self) -> int:
+        return self._get_single_token_id("<|transcribe|>")
+
+    @cached_property
+    def translate(self) -> int:
+        return self._get_single_token_id("<|translate|>")
+
     @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
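
The two new properties mirror the existing sot lookup: each resolves one special-token string to its integer id, and @cached_property makes the lookup run only once per tokenizer instance. For context, a rough sketch of what _get_single_token_id is assumed to do (the real helper lives elsewhere in tokenizer.py and may differ in detail):

def _get_single_token_id(self, text: str) -> int:
    # special tokens like "<|transcribe|>" must map to exactly one id
    tokens = self.tokenizer.encode(text)
    assert len(tokens) == 1, f"{text} is not encoded as a single token"
    return tokens[0]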
30 changes: 15 additions & 15 deletions whisper/transcribe.py
@@ -197,35 +197,35 @@ def add_segment(
         timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
         consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
         if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+            if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
+                consecutive = consecutive.tolist() + [len(tokens)]
             last_slice = 0
             for current_slice in consecutive:
                 sliced_tokens = tokens[last_slice:current_slice]
-                start_timestamp_position = (
-                    sliced_tokens[0].item() - tokenizer.timestamp_begin
-                )
-                end_timestamp_position = (
-                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                )
+                start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
+                end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
                 add_segment(
-                    start=timestamp_offset + start_timestamp_position * time_precision,
-                    end=timestamp_offset + end_timestamp_position * time_precision,
+                    start=timestamp_offset + start_timestamp_pos * time_precision,
+                    end=timestamp_offset + end_timestamp_pos * time_precision,
                     text_tokens=sliced_tokens[1:-1],
                     result=result,
                 )
                 last_slice = current_slice
-            last_timestamp_position = (
-                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-            )
-            seek += last_timestamp_position * input_stride
+            if ended_with_single_timestamp:
+                # single timestamp at the end means no speech after the last timestamp.
+                seek += segment.shape[-1]
+            else:
+                # otherwise, ignore the unfinished segment and seek to the last timestamp
+                last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                seek += last_timestamp_pos * input_stride
             all_tokens.extend(tokens[: last_slice + 1].tolist())
         else:
             duration = segment_duration
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
             if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                 # no consecutive timestamps but it has a timestamp; use the last one.
-                # single timestamp at the end means no speech after the last timestamp.
-                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
-                duration = last_timestamp_position * time_precision
+                last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
+                duration = last_timestamp_pos * time_precision

             add_segment(
                 start=timestamp_offset,
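
The behavioral change in this hunk is the ended_with_single_timestamp case: a lone timestamp token at the very end of the window means the model found no speech after that point, so seek jumps past the whole segment instead of rewinding to the last timestamp and re-decoding the tail. A self-contained toy run of the slicing logic (token ids are made up; timestamp_begin = 100 is an assumed value, not the real vocabulary offset):

import torch

timestamp_begin = 100  # hypothetical; the real offset depends on the vocabulary
tokens = torch.tensor([100, 7, 8, 105, 105, 9, 110])  # <|t0|> text <|t1|> <|t1|> text <|t2|>

timestamp_tokens = tokens.ge(timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)

# tail reads [False, True]: the window ends with a single timestamp
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
    # append one more slice boundary so the final "text <|t2|>" run
    # becomes a segment instead of being dropped as unfinished
    consecutive = consecutive.tolist() + [len(tokens)]

print(consecutive)                  # [4, 7] -> segments tokens[0:4] and tokens[4:7]
print(ended_with_single_timestamp)  # True   -> seek += segment.shape[-1] (skip the window)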