From 671ac5a4ceb8980403554da94feb995161fb5fef Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Tue, 14 Mar 2023 00:34:09 +0100
Subject: [PATCH] Fix alignment between the segments and the list of words
 (#1087)

* Fix alignment between the segments and the list of words

* Ensure the word index does not overflow
---
 whisper/timing.py | 52 +++++++++++++++++++++++++++--------------------
 1 file changed, 30 insertions(+), 22 deletions(-)

diff --git a/whisper/timing.py b/whisper/timing.py
index 7bc2b9a6..1f8f4cf8 100644
--- a/whisper/timing.py
+++ b/whisper/timing.py
@@ -1,3 +1,4 @@
+import itertools
 import subprocess
 import warnings
 from dataclasses import dataclass
@@ -290,34 +291,41 @@ def add_word_timestamps(
     if len(segments) == 0:
         return
 
-    text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
+    text_tokens_per_segment = [
+        [token for token in segment["tokens"] if token < tokenizer.eot]
+        for segment in segments
+    ]
+
+    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    segment_lengths = [len(s["tokens"]) for s in segments]
-    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
-
-    for segment in segments:
-        segment["words"] = []
-
-    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
-    for i, timing in enumerate(alignment):
-        if timing.word:
-            segment = segments[token_sources[word_boundaries[i]]]
-            start = round(time_offset + timing.start, 2)
-            end = round(time_offset + timing.end, 2)
-            segment["words"].append(
-                dict(
-                    word=timing.word,
-                    start=start,
-                    end=end,
-                    probability=timing.probability,
+    word_index = 0
+
+    for segment, text_tokens in zip(segments, text_tokens_per_segment):
+        saved_tokens = 0
+        words = []
+
+        while word_index < len(alignment) and saved_tokens < len(text_tokens):
+            timing = alignment[word_index]
+
+            if timing.word:
+                words.append(
+                    dict(
+                        word=timing.word,
+                        start=round(time_offset + timing.start, 2),
+                        end=round(time_offset + timing.end, 2),
+                        probability=timing.probability,
+                    )
                 )
-            )
 
-    for segment in segments:
-        if len(words := segment["words"]) > 0:
+            saved_tokens += len(timing.tokens)
+            word_index += 1
+
+        if len(words) > 0:
             # adjust the segment-level timestamps based on the word-level timestamps
             segment["start"] = words[0]["start"]
             segment["end"] = words[-1]["end"]
+
+        segment["words"] = words
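
For context, below is a minimal runnable sketch of the per-segment walk this patch introduces: words are handed to each segment until the consumed token count covers that segment's tokens. WordTiming here is a simplified stand-in for the dataclass in whisper/timing.py, and assign_words_to_segments is a hypothetical helper for illustration, not part of the library.

from dataclasses import dataclass
from typing import List


@dataclass
class WordTiming:
    # Simplified stand-in: one aligned word plus the text tokens it spans.
    word: str
    tokens: List[int]


def assign_words_to_segments(
    alignment: List[WordTiming],
    text_tokens_per_segment: List[List[int]],
) -> List[List[str]]:
    # Walk the flat alignment once, advancing a single word_index across
    # all segments -- the same token-counting idea as the patched
    # add_word_timestamps().
    word_index = 0
    words_per_segment = []

    for text_tokens in text_tokens_per_segment:
        saved_tokens = 0
        words = []

        # Bounding word_index by len(alignment) guards against the index
        # overflow addressed by the second commit in this PR.
        while word_index < len(alignment) and saved_tokens < len(text_tokens):
            timing = alignment[word_index]
            if timing.word:  # merged punctuation leaves empty-word timings
                words.append(timing.word)
            saved_tokens += len(timing.tokens)
            word_index += 1

        words_per_segment.append(words)

    return words_per_segment


# Two segments covering tokens [1, 2, 3] and [4, 5]; "Hello" spans two tokens.
alignment = [
    WordTiming("Hello", [1, 2]),
    WordTiming("world", [3]),
    WordTiming("Good", [4]),
    WordTiming("bye", [5]),
]
print(assign_words_to_segments(alignment, [[1, 2, 3], [4, 5]]))
# -> [['Hello', 'world'], ['Good', 'bye']]

Because the token lists are built per segment and then chained for the single find_alignment() call, this walk keeps the word list and the segment list in lockstep even when a word's tokens straddle the position where the old np.repeat() bookkeeping drifted out of alignment.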