From 8f9357fa99780dd480db4d88f51490a62900865a Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Fri, 20 Jan 2023 12:30:50 -0800 Subject: [PATCH 01/18] word-level timestamps in `transcribe()` --- requirements.txt | 1 + whisper/audio.py | 4 + whisper/tokenizer.py | 34 +++++++++ whisper/transcribe.py | 166 +++++++++++++++++++++++++++++++++++------- 4 files changed, 178 insertions(+), 27 deletions(-) diff --git a/requirements.txt b/requirements.txt index a4614035..3fe06aae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ tqdm more-itertools transformers>=4.19.0 ffmpeg-python==0.2.0 +dtw-python==1.3.0 diff --git a/whisper/audio.py b/whisper/audio.py index de8a1951..964d4157 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -18,6 +18,10 @@ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each) +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each) + def load_audio(file: str, sr: int = SAMPLE_RATE): """ diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index a27cb359..3cc499a3 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -1,4 +1,5 @@ import os +import string from dataclasses import dataclass from functools import lru_cache from typing import List, Optional, Tuple, Union @@ -269,6 +270,39 @@ def _get_single_token_id(self, text) -> int: assert len(tokens) == 1, f"{text} is not encoded as a single token" return tokens[0] + def split_tokens_on_unicode(self, tokens: List[int]): + words = [] + word_tokens = [] + current_tokens = [] + + for token in tokens: + current_tokens.append(token) + decoded = self.decode_with_timestamps(current_tokens) + if "\ufffd" not in decoded: + words.append(decoded) + word_tokens.append(current_tokens) + current_tokens = [] + + return words, word_tokens + + def split_tokens_on_spaces(self, tokens: List[int]): + subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) + words = [] + word_tokens = [] + + for subword, subword_tokens in zip(subwords, subword_tokens_list): + special = subword_tokens[0] >= self.eot + with_space = subword.startswith(" ") + punctuation = subword.strip() in string.punctuation + if special or with_space or punctuation: + words.append(subword) + word_tokens.append(subword_tokens) + else: + words[-1] = words[-1] + subword + word_tokens[-1].extend(subword_tokens) + + return words, word_tokens + @lru_cache(maxsize=None) def build_tokenizer(name: str = "gpt2"): diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 62ef5fe5..98d30411 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -8,15 +8,104 @@ import torch import tqdm -from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram +from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult -from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer +from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt if TYPE_CHECKING: from .model import Whisper +def add_word_timestamps( + model: "Whisper", + 
tokenizer: Tokenizer, + mel: torch.Tensor, + num_frames: int, + segments: List[dict], + *, + medfilt_width: int = 7, + qk_scale: float = 1.0, +): + if len(segments) == 0: + return + + from dtw import dtw + from scipy.ndimage import median_filter + + # install hooks on the cross attention layers to retrieve the attention weights + QKs = [None] * model.dims.n_text_layer + hooks = [ + block.cross_attn.register_forward_hook( + lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1]) + ) + for i, block in enumerate(model.decoder.blocks) + ] + + tokens = torch.tensor( + [ + *tokenizer.sot_sequence, + tokenizer.timestamp_begin, + *[t for segment in segments for t in segment["tokens"]], + tokenizer.timestamp_begin + mel.shape[-1] // 2, + tokenizer.eot, + ] + ).to(model.device) + + with torch.no_grad(): + model(mel.unsqueeze(0), tokens.unsqueeze(0)) + + for hook in hooks: + hook.remove() + + weights = torch.cat(QKs) # layers * heads * tokens * frames + weights = weights[:, :, :, : num_frames // 2].cpu() + weights = median_filter(weights, (1, 1, 1, medfilt_width)) + weights = torch.tensor(weights * qk_scale).softmax(dim=-1) + + w = weights / weights.norm(dim=-2, keepdim=True) + matrix = w.mean(axis=(0, 1)) + + alignment = dtw(-matrix.double().numpy()) + + jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool) + jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND + + if tokenizer.language in {"zh", "ja", "th", "lo", "my"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + split_tokens = tokenizer.split_tokens_on_unicode + else: + split_tokens = tokenizer.split_tokens_on_spaces + + words, word_tokens = split_tokens(tokens[1:].tolist()) + + token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) + token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources) + + time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0)) + start_times = time_offset + jump_times[word_boundaries[:-1]] + end_times = time_offset + jump_times[word_boundaries[1:]] + + for segment in segments: + segment["words"] = [] + + for i, (word, start, end) in enumerate(zip(words, start_times, end_times)): + if word.startswith("<|") or word.strip() in ".,!?、。": + continue + + segment = segments[token_sources[word_boundaries[i]]] + segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2))) + + # adjust the segment-level timestamps based on the word-level timestamps + for segment in segments: + if len(segment["words"]) > 0: + segment["start"] = segment["words"][0]["start"] + segment["end"] = segment["words"][-1]["end"] + + def transcribe( model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], @@ -27,6 +116,7 @@ def transcribe( logprob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, + word_level_timestamps: bool = False, **decode_options, ): """ @@ -90,14 +180,14 @@ def transcribe( else: if verbose: print("Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language") - segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) - _, probs = model.detect_language(segment) + mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) + _, probs = model.detect_language(mel_segment) decode_options["language"] = max(probs, key=probs.get) if verbose is not None: print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") - language = decode_options["language"] - task = decode_options.get("task", "transcribe") + language: str = decode_options["language"] + task: str = decode_options.get("task", "transcribe") tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: @@ -147,7 +237,8 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: def add_segment( *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult ): - text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot]) + text_tokens = [token for token in text_tokens.tolist() if token < tokenizer.eot] + text = tokenizer.decode(text_tokens) if len(text.strip()) == 0: # skip empty text output return @@ -158,32 +249,28 @@ def add_segment( "start": start, "end": end, "text": text, - "tokens": text_tokens.tolist(), + "tokens": text_tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, "no_speech_prob": result.no_speech_prob, } ) - if verbose: - line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n" - # compared to just `print(line)`, this replaces any character not representable using - # the system default encoding with an '?', avoiding UnicodeEncodeError. - sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace")) - sys.stdout.flush() # show the progress bar when verbose is False (otherwise the transcribed text will be printed) num_frames = mel.shape[-1] - previous_seek_value = seek + previous_seek = seek with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: while seek < num_frames: - timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) - segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE + time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) + mel_segment = mel[:, seek:] + segment_size = min(mel_segment.shape[-1], N_FRAMES) + segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE + mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(segment) + result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) if no_speech_threshold is not None: @@ -194,9 +281,10 @@ def add_segment( should_skip = False if should_skip: - seek += segment.shape[-1] # fast-forward to the next segment boundary + seek += segment_size # fast-forward to the next segment boundary continue + last_segment_index = len(all_segments) timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens @@ -210,8 +298,8 @@ def add_segment( sliced_tokens[-1].item() - tokenizer.timestamp_begin ) add_segment( - start=timestamp_offset + start_timestamp_position * 
time_precision, - end=timestamp_offset + end_timestamp_position * time_precision, + start=time_offset + start_timestamp_position * time_precision, + end=time_offset + end_timestamp_position * time_precision, text_tokens=sliced_tokens[1:-1], result=result, ) @@ -231,22 +319,45 @@ def add_segment( duration = last_timestamp_position * time_precision add_segment( - start=timestamp_offset, - end=timestamp_offset + duration, + start=time_offset, + end=time_offset + duration, text_tokens=tokens, result=result, ) - seek += segment.shape[-1] + seek += segment_size all_tokens.extend(tokens.tolist()) if not condition_on_previous_text or result.temperature > 0.5: # do not feed the prompt tokens if a high temperature was used prompt_reset_since = len(all_tokens) + if word_level_timestamps: + current_segments = all_segments[last_segment_index:] + add_word_timestamps( + model, + tokenizer, + mel=mel_segment, + num_frames=segment_size, + segments=current_segments, + ) + word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] + if len(word_end_timestamps) > 0: + seek_shift = (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND + seek = previous_seek + round(seek_shift) + + if verbose: + for segment in all_segments[last_segment_index:]: + start, end, text = segment["start"], segment["end"], segment["text"] + line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n" + # compared to just `print(line)`, this replaces any character not representable using + # the system default encoding with an '?', avoiding UnicodeEncodeError. + sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace")) + sys.stdout.flush() + # update progress bar - pbar.update(min(num_frames, seek) - previous_seek_value) - previous_seek_value = seek + pbar.update(min(num_frames, seek) - previous_seek) + previous_seek = seek return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) @@ -280,6 +391,7 @@ def cli(): parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") + parser.add_argument("--word_level_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") args = parser.parse_args().__dict__ From 46ea501da224c5e6b454e33dca99cc82326805b2 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Sat, 21 Jan 2023 01:15:13 -0800 Subject: [PATCH 02/18] moving to `timing.py` --- whisper/timing.py | 109 ++++++++++++++++++++++++++++++++++++++++ whisper/transcribe.py | 112 ++++++------------------------------------ 2 files changed, 124 insertions(+), 97 deletions(-) create mode 100644 whisper/timing.py diff --git a/whisper/timing.py b/whisper/timing.py new file mode 100644 index 00000000..768d5b78 --- /dev/null +++ b/whisper/timing.py @@ -0,0 +1,109 @@ +from typing import List, 
TYPE_CHECKING + +import numpy as np +import torch +import torch.nn.functional as F + +from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND +from .tokenizer import Tokenizer + +if TYPE_CHECKING: + from .model import Whisper + + +def median_filter(x: torch.Tensor, filter_width: int): + """Apply a median filter of width `filter_width` along the last dimension of `x`""" + assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors" + assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number" + + padded = F.pad(x, (0, 0, filter_width // 2, filter_width // 2), mode='replicate') + slices = padded.unfold(-1, filter_width, 1) + return slices.median(dim=-1).values + + +def add_word_timestamps( + model: "Whisper", + tokenizer: Tokenizer, + mel: torch.Tensor, + num_frames: int, + segments: List[dict], + *, + medfilt_width: int = 7, + qk_scale: float = 1.0, +): + if len(segments) == 0: + return + + from dtw import dtw + + # install hooks on the cross attention layers to retrieve the attention weights + QKs = [None] * model.dims.n_text_layer + hooks = [ + block.cross_attn.register_forward_hook( + lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1]) + ) + for i, block in enumerate(model.decoder.blocks) + ] + + tokens = torch.tensor( + [ + *tokenizer.sot_sequence, + tokenizer.timestamp_begin, + *[t for segment in segments for t in segment["tokens"]], + tokenizer.timestamp_begin + mel.shape[-1] // 2, + tokenizer.eot, + ] + ).to(model.device) + + with torch.no_grad(): + model(mel.unsqueeze(0), tokens.unsqueeze(0)) + + for hook in hooks: + hook.remove() + + weights = torch.cat(QKs) # layers * heads * tokens * frames + weights = weights[:, :, :, : num_frames // 2] + weights = median_filter(weights, medfilt_width) + weights = (weights * qk_scale).softmax(dim=-1) + + w = weights / weights.norm(dim=-2, keepdim=True) + matrix = w.mean(axis=(0, 1)).neg().double().cpu().numpy() + + alignment = dtw(matrix) + + jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool) + jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND + + if tokenizer.language in {"zh", "ja", "th", "lo", "my"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. 
Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + split_tokens = tokenizer.split_tokens_on_unicode + else: + split_tokens = tokenizer.split_tokens_on_spaces + + words, word_tokens = split_tokens(tokens[1:].tolist()) + + token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) + token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources) + + time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0)) + start_times = time_offset + jump_times[word_boundaries[:-1]] + end_times = time_offset + jump_times[word_boundaries[1:]] + + for segment in segments: + segment["words"] = [] + + for i, (word, start, end) in enumerate(zip(words, start_times, end_times)): + if word.startswith("<|") or word.strip() in ".,!?、。": + continue + + segment = segments[token_sources[word_boundaries[i]]] + segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2))) + + # adjust the segment-level timestamps based on the word-level timestamps + for segment in segments: + if len(segment["words"]) > 0: + segment["start"] = segment["words"][0]["start"] + segment["end"] = segment["words"][-1]["end"] diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 98d30411..4f01210d 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -2,110 +2,24 @@ import os import sys import warnings -from typing import List, Optional, Tuple, Union, TYPE_CHECKING +from typing import Optional, Tuple, Union, TYPE_CHECKING import numpy as np import torch import tqdm -from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim +from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, \ + pad_or_trim from .decoding import DecodingOptions, DecodingResult -from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer -from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt +from .timing import add_word_timestamps +from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer +from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, \ + write_vtt, write_srt if TYPE_CHECKING: from .model import Whisper -def add_word_timestamps( - model: "Whisper", - tokenizer: Tokenizer, - mel: torch.Tensor, - num_frames: int, - segments: List[dict], - *, - medfilt_width: int = 7, - qk_scale: float = 1.0, -): - if len(segments) == 0: - return - - from dtw import dtw - from scipy.ndimage import median_filter - - # install hooks on the cross attention layers to retrieve the attention weights - QKs = [None] * model.dims.n_text_layer - hooks = [ - block.cross_attn.register_forward_hook( - lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1]) - ) - for i, block in enumerate(model.decoder.blocks) - ] - - tokens = torch.tensor( - [ - *tokenizer.sot_sequence, - tokenizer.timestamp_begin, - *[t for segment in segments for t in segment["tokens"]], - tokenizer.timestamp_begin + mel.shape[-1] // 2, - tokenizer.eot, - ] - ).to(model.device) - - with torch.no_grad(): - model(mel.unsqueeze(0), tokens.unsqueeze(0)) - - for hook in hooks: - hook.remove() - - weights = torch.cat(QKs) # layers * heads * tokens * frames - weights = weights[:, :, :, : num_frames // 2].cpu() - weights = median_filter(weights, (1, 1, 1, medfilt_width)) - weights 
= torch.tensor(weights * qk_scale).softmax(dim=-1) - - w = weights / weights.norm(dim=-2, keepdim=True) - matrix = w.mean(axis=(0, 1)) - - alignment = dtw(-matrix.double().numpy()) - - jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool) - jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND - - if tokenizer.language in {"zh", "ja", "th", "lo", "my"}: - # These languages don't typically use spaces, so it is difficult to split words - # without morpheme analysis. Here, we instead split words at any - # position where the tokens are decoded as valid unicode points - split_tokens = tokenizer.split_tokens_on_unicode - else: - split_tokens = tokenizer.split_tokens_on_spaces - - words, word_tokens = split_tokens(tokens[1:].tolist()) - - token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) - token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources) - - time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE - word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0)) - start_times = time_offset + jump_times[word_boundaries[:-1]] - end_times = time_offset + jump_times[word_boundaries[1:]] - - for segment in segments: - segment["words"] = [] - - for i, (word, start, end) in enumerate(zip(words, start_times, end_times)): - if word.startswith("<|") or word.strip() in ".,!?、。": - continue - - segment = segments[token_sources[word_boundaries[i]]] - segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2))) - - # adjust the segment-level timestamps based on the word-level timestamps - for segment in segments: - if len(segment["words"]) > 0: - segment["start"] = segment["words"][0]["start"] - segment["end"] = segment["words"][-1]["end"] - - def transcribe( model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], @@ -116,7 +30,7 @@ def transcribe( logprob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, - word_level_timestamps: bool = False, + word_timestamps: bool = False, **decode_options, ): """ @@ -153,6 +67,10 @@ def transcribe( disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + word_timestamps: bool + Extract word-level timestamps using the cross-attention pattern and dynamic time warping, + and include the timestamps for each word in each segment. 
+ decode_options: dict Keyword arguments to construct `DecodingOptions` instances @@ -332,7 +250,7 @@ def add_segment( # do not feed the prompt tokens if a high temperature was used prompt_reset_since = len(all_tokens) - if word_level_timestamps: + if word_timestamps: current_segments = all_segments[last_segment_index:] add_word_timestamps( model, @@ -342,7 +260,7 @@ def add_segment( segments=current_segments, ) word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] - if len(word_end_timestamps) > 0: + if len(consecutive) > 0 and len(word_end_timestamps) > 0: seek_shift = (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND seek = previous_seek + round(seek_shift) @@ -391,7 +309,7 @@ def cli(): parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") - parser.add_argument("--word_level_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them") + parser.add_argument("--word_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") args = parser.parse_args().__dict__ From 742d2f4c88df1e2b64b2e68e742a6427a4000b4a Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Sat, 21 Jan 2023 16:07:37 -0800 Subject: [PATCH 03/18] numba implementation for dtw, replacing dtw-python --- requirements.txt | 2 +- whisper/timing.py | 54 +++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3fe06aae..45ecbe76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ +numba numpy torch tqdm more-itertools transformers>=4.19.0 ffmpeg-python==0.2.0 -dtw-python==1.3.0 diff --git a/whisper/timing.py b/whisper/timing.py index 768d5b78..66fa2ac8 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -1,5 +1,6 @@ from typing import List, TYPE_CHECKING +import numba import numpy as np import torch import torch.nn.functional as F @@ -21,6 +22,48 @@ def median_filter(x: torch.Tensor, filter_width: int): return slices.median(dim=-1).values +@numba.jit(nopython=True, parallel=True) +def dtw(x: np.ndarray): + N, M = x.shape + cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf + trace = -np.ones((N + 1, M + 1), dtype=np.float32) + + i, j = 0, 0 + cost[0, 0] = 0 + for j in range(1, M + 1): + for i in range(1, N + 1): + c0 = cost[i - 1, j - 1] + c1 = cost[i - 1, j] + c2 = cost[i, j - 1] + + if c0 < c1 and c0 < c2: + c, t = c0, 0 + elif c1 < c0 and c1 < c2: + c, t = c1, 1 + else: + c, t = c2, 2 + + cost[i, j] = x[i - 1, j - 1] + c + trace[i, j] = t + + result = [] + while i > 0 and j > 0: + result.append((i - 1, j - 1)) + + if trace[i, j] == 0: + i -= 1 + j -= 1 + elif trace[i, j] == 1: + i -= 1 + elif trace[i, j] == 2: + j -= 1 + else: + raise ValueError("Unexpected P[i, 
j]") + + result = np.array(result) + return result[::-1, :].T + + def add_word_timestamps( model: "Whisper", tokenizer: Tokenizer, @@ -34,8 +77,6 @@ def add_word_timestamps( if len(segments) == 0: return - from dtw import dtw - # install hooks on the cross attention layers to retrieve the attention weights QKs = [None] * model.dims.n_text_layer hooks = [ @@ -67,12 +108,11 @@ def add_word_timestamps( weights = (weights * qk_scale).softmax(dim=-1) w = weights / weights.norm(dim=-2, keepdim=True) - matrix = w.mean(axis=(0, 1)).neg().double().cpu().numpy() - - alignment = dtw(matrix) + matrix = w.mean(axis=(0, 1)).neg().cpu().numpy() + text_indices, time_indices = dtw(matrix) - jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool) - jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] / TOKENS_PER_SECOND if tokenizer.language in {"zh", "ja", "th", "lo", "my"}: # These languages don't typically use spaces, so it is difficult to split words From 80331c0c67535193ce7b4d6907c726b6456b66c1 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 23 Jan 2023 11:40:37 -0800 Subject: [PATCH 04/18] triton implementation for dtw --- setup.py | 7 +++- whisper/timing.py | 87 +++++++++++++++++++++++++++++++++---------- whisper/triton_ops.py | 33 ++++++++++++++++ 3 files changed, 106 insertions(+), 21 deletions(-) create mode 100644 whisper/triton_ops.py diff --git a/setup.py b/setup.py index 0e822ab9..f8fe358f 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os +import sys import pkg_resources from setuptools import setup, find_packages @@ -9,6 +10,10 @@ def read_version(fname="whisper/version.py"): return locals()["__version__"] +requirements = [] +if sys.platform.startswith("linux"): + requirements.append("triton>=2.0.0.dev20221202") + setup( name="openai-whisper", py_modules=["whisper"], @@ -22,7 +27,7 @@ def read_version(fname="whisper/version.py"): url="https://github.com/openai/whisper", license="MIT", packages=find_packages(exclude=["tests*"]), - install_requires=[ + install_requires=requirements + [ str(r) for r in pkg_resources.parse_requirements( open(os.path.join(os.path.dirname(__file__), "requirements.txt")) diff --git a/whisper/timing.py b/whisper/timing.py index 66fa2ac8..2f4dbd17 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -22,13 +22,37 @@ def median_filter(x: torch.Tensor, filter_width: int): return slices.median(dim=-1).values +@numba.jit +def backtrace(trace: np.ndarray): + i = trace.shape[0] - 1 + j = trace.shape[1] - 1 + trace[0, :] = 2 + trace[:, 0] = 1 + + result = [] + while i > 0 or j > 0: + result.append((i - 1, j - 1)) + + if trace[i, j] == 0: + i -= 1 + j -= 1 + elif trace[i, j] == 1: + i -= 1 + elif trace[i, j] == 2: + j -= 1 + else: + raise ValueError("Unexpected trace[i, j]") + + result = np.array(result) + return result[::-1, :].T + + @numba.jit(nopython=True, parallel=True) -def dtw(x: np.ndarray): +def dtw_cpu(x: np.ndarray): N, M = x.shape cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf trace = -np.ones((N + 1, M + 1), dtype=np.float32) - i, j = 0, 0 cost[0, 0] = 0 for j in range(1, M + 1): for i in range(1, N + 1): @@ -46,22 +70,45 @@ def dtw(x: np.ndarray): cost[i, j] = x[i - 1, j - 1] + c trace[i, j] = t - result = [] - while i > 0 and j > 0: - result.append((i - 1, j - 1)) + return backtrace(trace) - if trace[i, j] == 0: - i -= 1 - j -= 1 - elif trace[i, j] == 1: - i -= 1 - elif trace[i, 
j] == 2: - j -= 1 - else: - raise ValueError("Unexpected P[i, j]") - result = np.array(result) - return result[::-1, :].T +def dtw_cuda(x, BLOCK_SIZE=1024): + from .triton_ops import dtw_kernel + + M, N = x.shape + # assert M < N, f"{M=} should be smaller than {N=}" + assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" + + x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) + x_skew = x_skew.T.contiguous() + cost = torch.ones(N + M + 2, M + 2) * np.inf + cost[0, 0] = 0 + cost = cost.cuda() + trace = torch.zeros_like(cost, dtype=torch.int32) + + dtw_kernel[(1,)]( + cost, + trace, + x_skew, + x_skew.stride(0), + cost.stride(0), + trace.stride(0), + N, + M, + BLOCK_SIZE=BLOCK_SIZE + ) + + trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1] + + return backtrace(trace.cpu().numpy()) + + +def dtw(x: torch.Tensor) -> np.ndarray: + if x.is_cuda: + return dtw_cuda(x) + + return dtw_cpu(x.double().cpu().numpy()) def add_word_timestamps( @@ -102,13 +149,13 @@ def add_word_timestamps( for hook in hooks: hook.remove() - weights = torch.cat(QKs) # layers * heads * tokens * frames + weights = torch.cat(QKs[-6:]) # layers * heads * tokens * frames weights = weights[:, :, :, : num_frames // 2] weights = median_filter(weights, medfilt_width) weights = (weights * qk_scale).softmax(dim=-1) + weights = weights / weights.norm(dim=-2, keepdim=True) + matrix = weights.mean(axis=(0, 1)).neg() - w = weights / weights.norm(dim=-2, keepdim=True) - matrix = w.mean(axis=(0, 1)).neg().cpu().numpy() text_indices, time_indices = dtw(matrix) jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) @@ -136,7 +183,7 @@ def add_word_timestamps( segment["words"] = [] for i, (word, start, end) in enumerate(zip(words, start_times, end_times)): - if word.startswith("<|") or word.strip() in ".,!?、。": + if word.startswith("<|") or word.strip() in ".,!?、。": # TODO: expand continue segment = segments[token_sources[word_boundaries[i]]] diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py new file mode 100644 index 00000000..0a9b1380 --- /dev/null +++ b/whisper/triton_ops.py @@ -0,0 +1,33 @@ +try: + import triton + import triton.language as tl +except ImportError: + raise RuntimeError("triton import failed; try `pip install triton`") + + +@triton.jit +def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < M + + for k in range(1, N + M + 1): # k = i + j + tl.debug_barrier() + + p0 = cost + (k - 1) * cost_stride + p1 = cost + k * cost_stride + p2 = cost + k * cost_stride + 1 + + c0 = tl.load(p0 + offsets, mask=mask) + c1 = tl.load(p1 + offsets, mask=mask) + c2 = tl.load(p2 + offsets, mask=mask) + + x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0) + cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2) + + cost_ptr = cost + (k + 1) * cost_stride + 1 + tl.store(cost_ptr + offsets, cost_row, mask=mask) + + trace_ptr = trace + (k + 1) * trace_stride + 1 + tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) + tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) + tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) From 1d2ed66d8db2241fcf7a7fb2291f65bcd4edb4d0 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 23 Jan 2023 14:16:14 -0800 Subject: [PATCH 05/18] add test for dtw implementations --- .github/workflows/test.yml | 2 +- tests/test_timing.py | 57 
++++++++++++++++++++++++++++++++++++++ whisper/timing.py | 2 -- whisper/triton_ops.py | 6 ++-- 4 files changed, 61 insertions(+), 6 deletions(-) create mode 100644 tests/test_timing.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c5b4eecc..14a1e85f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,4 +23,4 @@ jobs: - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: pip install pytest - run: pip install . - - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' + - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' diff --git a/tests/test_timing.py b/tests/test_timing.py new file mode 100644 index 00000000..e10ceecc --- /dev/null +++ b/tests/test_timing.py @@ -0,0 +1,57 @@ +import pytest +import numpy as np +import torch + +from whisper.timing import dtw_cpu, dtw_cuda + + +sizes = [ + (10, 20), (32, 16), (123, 1500), (234, 189) +] + + +@pytest.mark.parametrize("N, M", sizes) +def test_dtw(N: int, M: int): + np.random.seed(42) + steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)]) + np.random.shuffle(steps) + x = np.random.random((N, M)).astype(np.float32) + + i, j, k = 0, 0, 0 + trace = [] + while True: + x[i, j] -= 1 + trace.append((i, j)) + + if k == len(steps): + break + + if k + 1 < len(steps) and steps[k] != steps[k + 1]: + i += 1 + j += 1 + k += 2 + continue + + if steps[k] == 0: + i += 1 + if steps[k] == 1: + j += 1 + k += 1 + + trace = np.array(trace).T + dtw_trace = dtw_cpu(x) + + assert np.allclose(trace, dtw_trace) + + +@pytest.mark.requires_cuda +@pytest.mark.parametrize("N, M", sizes) +def test_dtw_cuda_equivalence(N: int, M: int): + np.random.seed(42) + x_numpy = np.random.randn(N, M).astype(np.float32) + x_cuda = torch.from_numpy(x_numpy).cuda() + + trace_cpu = dtw_cpu(x_numpy) + trace_cuda = dtw_cuda(x_cuda) + + assert np.allclose(trace_cpu, trace_cuda) diff --git a/whisper/timing.py b/whisper/timing.py index 2f4dbd17..1eed7c26 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -77,7 +77,6 @@ def dtw_cuda(x, BLOCK_SIZE=1024): from .triton_ops import dtw_kernel M, N = x.shape - # assert M < N, f"{M=} should be smaller than {N=}" assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) @@ -100,7 +99,6 @@ def dtw_cuda(x, BLOCK_SIZE=1024): ) trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1] - return backtrace(trace.cpu().numpy()) diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index 0a9b1380..25d8b1ac 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -2,7 +2,7 @@ import triton import triton.language as tl except ImportError: - raise RuntimeError("triton import failed; try `pip install triton`") + raise RuntimeError("triton import failed; try `pip install --pre triton`") @triton.jit @@ -28,6 +28,6 @@ def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_ tl.store(cost_ptr + offsets, cost_row, mask=mask) trace_ptr = trace + (k + 1) * trace_stride + 1 - tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) - tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) + tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) + tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & 
(c0 <= c2)) From b61e8f4fd1b912b8d13ec13800bbf80d73905894 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Tue, 24 Jan 2023 09:30:12 -0800 Subject: [PATCH 06/18] triton implementation of median_filter --- tests/conftest.py | 10 ++++++++ tests/test_timing.py | 33 +++++++++++++++++++++--- whisper/timing.py | 10 +++++--- whisper/triton_ops.py | 59 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 7 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..920632b0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import random as rand + +import numpy +import pytest + + +@pytest.fixture +def random(): + rand.seed(42) + numpy.random.seed(42) diff --git a/tests/test_timing.py b/tests/test_timing.py index e10ceecc..732f08d6 100644 --- a/tests/test_timing.py +++ b/tests/test_timing.py @@ -1,18 +1,21 @@ import pytest import numpy as np +import scipy.ndimage import torch -from whisper.timing import dtw_cpu, dtw_cuda +from whisper.timing import dtw_cpu, dtw_cuda, median_filter sizes = [ - (10, 20), (32, 16), (123, 1500), (234, 189) + (10, 20), (32, 16), (123, 1500), (234, 189), +] +shapes = [ + (4, 5, 20, 345), (6, 12, 240, 512), ] @pytest.mark.parametrize("N, M", sizes) def test_dtw(N: int, M: int): - np.random.seed(42) steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)]) np.random.shuffle(steps) x = np.random.random((N, M)).astype(np.float32) @@ -47,7 +50,6 @@ def test_dtw(N: int, M: int): @pytest.mark.requires_cuda @pytest.mark.parametrize("N, M", sizes) def test_dtw_cuda_equivalence(N: int, M: int): - np.random.seed(42) x_numpy = np.random.randn(N, M).astype(np.float32) x_cuda = torch.from_numpy(x_numpy).cuda() @@ -55,3 +57,26 @@ def test_dtw_cuda_equivalence(N: int, M: int): trace_cuda = dtw_cuda(x_cuda) assert np.allclose(trace_cpu, trace_cuda) + + +@pytest.mark.parametrize("shape", shapes) +def test_median_filter(shape): + x = torch.randn(*shape) + + for filter_width in [3, 5, 7, 13]: + filtered = median_filter(x, filter_width) + scipy_filtered = scipy.ndimage.median_filter(x, (1, 1, 1, filter_width), mode="nearest") + + assert np.allclose(filtered, scipy_filtered) + + +@pytest.mark.requires_cuda +@pytest.mark.parametrize("shape", shapes) +def test_median_filter_equivalence(shape): + x = torch.randn(*shape) + + for filter_width in [3, 5, 7, 13]: + filtered_cpu = median_filter(x, filter_width) + filtered_gpu = median_filter(x.cuda(), filter_width).cpu() + + assert np.allclose(filtered_cpu, filtered_gpu) diff --git a/whisper/timing.py b/whisper/timing.py index 1eed7c26..f52495b1 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -17,9 +17,13 @@ def median_filter(x: torch.Tensor, filter_width: int): assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors" assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number" - padded = F.pad(x, (0, 0, filter_width // 2, filter_width // 2), mode='replicate') - slices = padded.unfold(-1, filter_width, 1) - return slices.median(dim=-1).values + x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode='replicate') + if x.is_cuda: + from .triton_ops import median_filter_cuda + return median_filter_cuda(x, filter_width) + + # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) + return x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] @numba.jit diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index 
25d8b1ac..d829e204 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -1,3 +1,9 @@ +import math + +import numpy as np +import torch +from functools import lru_cache + try: import triton import triton.language as tl @@ -31,3 +37,56 @@ def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_ tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) + + +@lru_cache(maxsize=None) +def median_kernel(filter_width: int): + @triton.jit + def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr): # x.shape[-1] == filter_width + row_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < y_stride + + x_ptr = x + row_idx * x_stride + y_ptr = y + row_idx * y_stride + + LOAD_ALL_ROWS_HERE + + BUBBLESORT_HERE + + tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) + + kernel = triton.JITFunction(kernel.fn) + kernel.src = kernel.src.replace(" LOAD_ALL_ROWS_HERE", "\n".join([ + f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" + for i in range(filter_width) + ])) + kernel.src = kernel.src.replace(" BUBBLESORT_HERE", "\n\n".join([ + "\n\n".join([ + "\n".join([ + f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})", + f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", + f" row{j} = smaller", + f" row{j + 1} = larger", + ]) + for j in range(filter_width - i - 1) + ]) + for i in range(filter_width // 2 + 1) + ])) + kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") + + return kernel + + +def median_filter_cuda(x: torch.Tensor, filter_width: int): + """Apply a median filter of given width along the last dimension of x""" + slices = x.contiguous().unfold(-1, filter_width, 1) + grid = np.prod(slices.shape[:-2]) + + kernel = median_kernel(filter_width) + y = torch.empty_like(slices[..., 0]) + + BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length() + kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE) + + return y From 54f2901a722ec5e31d26aa20313a69175bbaa48d Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Tue, 24 Jan 2023 09:30:32 -0800 Subject: [PATCH 07/18] a simple word-level timestamps test --- tests/test_transcribe.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index f5d66c37..bb0ee634 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -13,10 +13,21 @@ def test_transcribe(model_name: str): audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") language = "en" if model_name.endswith(".en") else None - result = model.transcribe(audio_path, language=language, temperature=0.0) + result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True) assert result["language"] == "en" transcription = result["text"].lower() assert "my fellow americans" in transcription assert "your country" in transcription assert "do for you" in transcription + + timing_checked = False + for segment in result["segments"]: + for timing in segment["words"]: + assert timing["start"] < timing["end"] + if timing["word"].strip() == "Americans": + assert timing["start"] <= 1.75 + assert timing["end"] >= 2.05 + timing_checked = True + + assert timing_checked From 8ce627736259b0a47dc04672dbefaae254c702a5 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Tue, 24 Jan 2023 09:34:46 -0800 Subject: [PATCH 08/18] 
add scipy as dev dependency --- .github/workflows/test.yml | 3 +-- setup.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 14a1e85f..f06bff79 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,5 @@ jobs: - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch - uses: actions/checkout@v2 - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - - run: pip install pytest - - run: pip install . + - run: pip install .["dev"] - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' diff --git a/setup.py b/setup.py index f8fe358f..af1c79fd 100644 --- a/setup.py +++ b/setup.py @@ -37,5 +37,5 @@ def read_version(fname="whisper/version.py"): "console_scripts": ["whisper=whisper.transcribe:cli"], }, include_package_data=True, - extras_require={"dev": ["pytest"]}, + extras_require={"dev": ["pytest", "scipy"]}, ) From cd5191fdcba8d64a95a2cbd2975f11c9cae31744 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Tue, 24 Jan 2023 11:06:34 -0800 Subject: [PATCH 09/18] installs an older version of Triton if CUDA < 11.4 --- setup.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index af1c79fd..dcdb9312 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,18 @@ def read_version(fname="whisper/version.py"): requirements = [] if sys.platform.startswith("linux"): - requirements.append("triton>=2.0.0.dev20221202") + triton_requirement = "triton>=2.0.0.dev20221202" + try: + import re + import subprocess + version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1] + major, minor = re.findall(rb"cuda_([\d]+)\.([\d]+)", version_line)[0] + if (int(major), int(minor)) < (11, 4): + # the last version supporting CUDA < 11.4 + triton_requirement = "triton==2.0.0.dev20221011" + except (IndexError, OSError, subprocess.SubprocessError): + pass + requirements.append(triton_requirement) setup( name="openai-whisper", From d4f9399506bbff33df4ef0b63faf3d30e0b855ca Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Tue, 24 Jan 2023 14:58:51 -0800 Subject: [PATCH 10/18] fix broken merge --- whisper/transcribe.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 8ed5d6e2..7a605b40 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -27,8 +27,8 @@ def transcribe( logprob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, - word_timestamps: bool = False, initial_prompt: Optional[str] = None, + word_timestamps: bool = False, **decode_options, ): """ @@ -69,6 +69,11 @@ def transcribe( Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment. + initial_prompt: Optional[str] + Optional text to provide as a prompt for the first window. This can be used to provide, or + "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns + to make it more likely to predict those word correctly. 
+ decode_options: dict Keyword arguments to construct `DecodingOptions` instances @@ -174,9 +179,6 @@ def add_segment( } ) - if verbose: - print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")) - # show the progress bar when verbose is False (otherwise the transcribed text will be printed) num_frames = mel.shape[-1] previous_seek = seek @@ -269,11 +271,8 @@ def add_segment( if verbose: for segment in all_segments[last_segment_index:]: start, end, text = segment["start"], segment["end"], segment["text"] - line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n" - # compared to just `print(line)`, this replaces any character not representable using - # the system default encoding with an '?', avoiding UnicodeEncodeError. - sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace")) - sys.stdout.flush() + line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" + print(make_safe(line)) # update progress bar pbar.update(min(num_frames, seek) - previous_seek) From 8e2756bb1f74ed66b7c93b5390afc4e09ad5397b Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Tue, 24 Jan 2023 21:18:39 -0800 Subject: [PATCH 11/18] loosen nvcc version match regex --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index dcdb9312..a548c8d3 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ def read_version(fname="whisper/version.py"): import re import subprocess version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1] - major, minor = re.findall(rb"cuda_([\d]+)\.([\d]+)", version_line)[0] + major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0] if (int(major), int(minor)) < (11, 4): # the last version supporting CUDA < 11.4 triton_requirement = "triton==2.0.0.dev20221011" From 6c431c41b1ac57161060f23d7fd4438690674d22 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Wed, 25 Jan 2023 00:19:25 -0800 Subject: [PATCH 12/18] find_alignment() function --- whisper/timing.py | 77 ++++++++++++++++++++++++++++--------------- whisper/transcribe.py | 9 +++-- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/whisper/timing.py b/whisper/timing.py index f52495b1..410cf577 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -1,5 +1,5 @@ from typing import List, TYPE_CHECKING - +from dataclasses import dataclass import numba import numpy as np import torch @@ -113,18 +113,34 @@ def dtw(x: torch.Tensor) -> np.ndarray: return dtw_cpu(x.double().cpu().numpy()) -def add_word_timestamps( +@dataclass +class Alignment: + words: List[str] + word_tokens: List[List[int]] + start_times: np.ndarray + end_times: np.ndarray + + +def find_alignment( model: "Whisper", tokenizer: Tokenizer, + text_tokens: List[int], mel: torch.Tensor, num_frames: int, - segments: List[dict], *, + max_qk_layers: int = 6, medfilt_width: int = 7, qk_scale: float = 1.0, -): - if len(segments) == 0: - return +) -> Alignment: + tokens = torch.tensor( + [ + *tokenizer.sot_sequence, + tokenizer.timestamp_begin, + *text_tokens, + tokenizer.timestamp_begin + num_frames // 2, + tokenizer.eot, + ] + ).to(model.device) # install hooks on the cross attention layers to retrieve the attention weights QKs = [None] * model.dims.n_text_layer @@ -135,23 +151,13 @@ def add_word_timestamps( for i, block in enumerate(model.decoder.blocks) ] - tokens = torch.tensor( - [ - *tokenizer.sot_sequence, - tokenizer.timestamp_begin, - *[t for segment in segments for t in segment["tokens"]], - 
tokenizer.timestamp_begin + mel.shape[-1] // 2, - tokenizer.eot, - ] - ).to(model.device) - with torch.no_grad(): model(mel.unsqueeze(0), tokens.unsqueeze(0)) for hook in hooks: hook.remove() - weights = torch.cat(QKs[-6:]) # layers * heads * tokens * frames + weights = torch.cat(QKs[-max_qk_layers:]) # layers * heads * tokens * frames weights = weights[:, :, :, : num_frames // 2] weights = median_filter(weights, medfilt_width) weights = (weights * qk_scale).softmax(dim=-1) @@ -167,24 +173,43 @@ def add_word_timestamps( # These languages don't typically use spaces, so it is difficult to split words # without morpheme analysis. Here, we instead split words at any # position where the tokens are decoded as valid unicode points - split_tokens = tokenizer.split_tokens_on_unicode + words, word_tokens = tokenizer.split_tokens_on_unicode(tokens[1:].tolist()) else: - split_tokens = tokenizer.split_tokens_on_spaces + words, word_tokens = tokenizer.split_tokens_on_spaces(tokens[1:].tolist()) - words, word_tokens = split_tokens(tokens[1:].tolist()) + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0)) + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] - token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) - token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources) + return Alignment(words, word_tokens, start_times, end_times) + + +def add_word_timestamps( + segments: List[dict], + model: "Whisper", + tokenizer: Tokenizer, + mel: torch.Tensor, + num_frames: int, + **hyperparams, +): + if len(segments) == 0: + return + + text_tokens = [t for segment in segments for t in segment["tokens"]] + alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams) time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE - word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0)) - start_times = time_offset + jump_times[word_boundaries[:-1]] - end_times = time_offset + jump_times[word_boundaries[1:]] + alignment.start_times = time_offset + alignment.start_times + alignment.end_times = time_offset + alignment.end_times + + token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) + token_sources: List[int] = [None] * len(tokenizer.sot_sequence) + list(token_sources) for segment in segments: segment["words"] = [] - for i, (word, start, end) in enumerate(zip(words, start_times, end_times)): + word_boundaries = np.pad(np.cumsum([len(t) for t in alignment.word_tokens]), (1, 0)) + for i, (word, start, end) in enumerate(zip(alignment.words, alignment.start_times, alignment.end_times)): if word.startswith("<|") or word.strip() in ".,!?、。": # TODO: expand continue diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 7a605b40..d17eb600 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -111,6 +111,9 @@ def transcribe( task: str = decode_options.get("task", "transcribe") tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) + if word_timestamps and task == "translate": + warnings.warn("Word-level timestamps on translations may not be reliable.") + def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature decode_result = None @@ -257,11 +260,11 @@ def add_segment( if word_timestamps: current_segments = all_segments[last_segment_index:] add_word_timestamps( - model, - tokenizer, + 
segments=current_segments, + model=model, + tokenizer=tokenizer, mel=mel_segment, num_frames=segment_size, - segments=current_segments, ) word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] if len(consecutive) > 0 and len(word_end_timestamps) > 0: From 5fa43566f00a3e337f7fb481a2b962118453a96b Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Thu, 2 Feb 2023 01:28:10 -0800 Subject: [PATCH 13/18] miscellaneous improvements --- tests/conftest.py | 4 + tests/test_timing.py | 9 +- tests/test_transcribe.py | 7 +- whisper/__init__.py | 22 +++++ whisper/model.py | 12 ++- whisper/timing.py | 179 ++++++++++++++++++++++++++++----------- whisper/tokenizer.py | 11 ++- whisper/transcribe.py | 83 +++++++++--------- whisper/utils.py | 67 ++++++++++----- 9 files changed, 278 insertions(+), 116 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 920632b0..31f1d6b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,10 @@ import pytest +def pytest_configure(config): + config.addinivalue_line("markers", "requires_cuda") + + @pytest.fixture def random(): rand.seed(42) diff --git a/tests/test_timing.py b/tests/test_timing.py index 732f08d6..50a2583f 100644 --- a/tests/test_timing.py +++ b/tests/test_timing.py @@ -10,7 +10,7 @@ (10, 20), (32, 16), (123, 1500), (234, 189), ] shapes = [ - (4, 5, 20, 345), (6, 12, 240, 512), + (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512), ] @@ -65,7 +65,12 @@ def test_median_filter(shape): for filter_width in [3, 5, 7, 13]: filtered = median_filter(x, filter_width) - scipy_filtered = scipy.ndimage.median_filter(x, (1, 1, 1, filter_width), mode="nearest") + + # using np.pad to reflect-pad, because Scipy's behavior is different near the edges. + pad_width = filter_width // 2 + padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect") + scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width]) + scipy_filtered = scipy_filtered[..., pad_width:-pad_width] assert np.allclose(filtered, scipy_filtered) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index bb0ee634..9802f734 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -25,9 +25,10 @@ def test_transcribe(model_name: str): for segment in result["segments"]: for timing in segment["words"]: assert timing["start"] < timing["end"] - if timing["word"].strip() == "Americans": - assert timing["start"] <= 1.75 - assert timing["end"] >= 2.05 + if timing["word"].strip(" ,") == "Americans": + assert timing["start"] <= 1.8 + assert timing["end"] >= 1.8 + print(timing) timing_checked = True assert timing_checked diff --git a/whisper/__init__.py b/whisper/__init__.py index cb334065..26d1e0ea 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -29,6 +29,23 @@ "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", } +# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are +# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens. 
+_ALIGNMENT_HEADS = { + "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00", + "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", + "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00", + "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00", + "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", + "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj', + "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj', +} + + def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: os.makedirs(root, exist_ok=True) @@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow if name in _MODELS: checkpoint_file = _download(_MODELS[name], download_root, in_memory) + alignment_heads = _ALIGNMENT_HEADS[name] elif os.path.isfile(name): checkpoint_file = open(name, "rb").read() if in_memory else name + alignment_heads = None else: raise RuntimeError(f"Model {name} not found; available models = {available_models()}") @@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow model = Whisper(dims) model.load_state_dict(checkpoint["model_state_dict"]) + if alignment_heads is not None: + model.set_alignment_heads(alignment_heads) + return model.to(device) diff --git a/whisper/model.py b/whisper/model.py index be73a4a8..dff96107 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from typing import Dict from typing import Iterable, Optional - +import gzip +import base64 import numpy as np import torch import torch.nn.functional as F @@ -213,6 +214,15 @@ def __init__(self, dims: ModelDimensions): self.dims.n_text_head, self.dims.n_text_layer, ) + # use the last half layers for alignment by default; see `set_alignment_heads()` below + all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool) + all_heads[self.dims.n_text_layer // 2:] = True + self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) + + def set_alignment_heads(self, dump: bytes): + array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy() + mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head) + self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) def embed_audio(self, mel: torch.Tensor): return self.encoder(mel) diff --git a/whisper/timing.py b/whisper/timing.py index 410cf577..e056f1e9 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -1,5 +1,9 @@ -from typing import List, TYPE_CHECKING +import time +import subprocess +import warnings from dataclasses import dataclass +from typing import List, TYPE_CHECKING + import numba import numpy as np import torch @@ -14,17 +18,31 @@ def median_filter(x: torch.Tensor, filter_width: int): """Apply a median filter of width `filter_width` along the last dimension of `x`""" - assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors" + if (ndim := x.ndim) <= 2: # `F.pad` does not support 1D or 2D inputs for reflect padding + x = x[None, None, :] + assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number" - x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode='replicate') + result = None + x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") if x.is_cuda: - from 
.triton_ops import median_filter_cuda - return median_filter_cuda(x, filter_width) + try: + from .triton_ops import median_filter_cuda + result = median_filter_cuda(x, filter_width) + except (RuntimeError, subprocess.CalledProcessError): + warnings.warn( + "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " + "falling back to a slower median kernel implementation..." + ) + + if result is None: + # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) + result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] - # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) - return x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] + if ndim <= 2: + result = result[0, 0] + return result @numba.jit def backtrace(trace: np.ndarray): @@ -108,17 +126,24 @@ def dtw_cuda(x, BLOCK_SIZE=1024): def dtw(x: torch.Tensor) -> np.ndarray: if x.is_cuda: - return dtw_cuda(x) + try: + return dtw_cuda(x) + except (RuntimeError, subprocess.CalledProcessError): + warnings.warn( + "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " + "falling back to a slower DTW implementation..." + ) return dtw_cpu(x.double().cpu().numpy()) @dataclass -class Alignment: - words: List[str] - word_tokens: List[List[int]] - start_times: np.ndarray - end_times: np.ndarray +class WordTiming: + word: str + tokens: List[int] + start: float + end: float + probability: float def find_alignment( @@ -128,16 +153,14 @@ def find_alignment( mel: torch.Tensor, num_frames: int, *, - max_qk_layers: int = 6, medfilt_width: int = 7, qk_scale: float = 1.0, -) -> Alignment: +) -> List[WordTiming]: tokens = torch.tensor( [ *tokenizer.sot_sequence, - tokenizer.timestamp_begin, + tokenizer.no_timestamps, *text_tokens, - tokenizer.timestamp_begin + num_frames // 2, tokenizer.eot, ] ).to(model.device) @@ -146,50 +169,105 @@ def find_alignment( QKs = [None] * model.dims.n_text_layer hooks = [ block.cross_attn.register_forward_hook( - lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1]) + lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) ) for i, block in enumerate(model.decoder.blocks) ] with torch.no_grad(): - model(mel.unsqueeze(0), tokens.unsqueeze(0)) + logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] + token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1) + text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist() for hook in hooks: hook.remove() - weights = torch.cat(QKs[-max_qk_layers:]) # layers * heads * tokens * frames - weights = weights[:, :, :, : num_frames // 2] - weights = median_filter(weights, medfilt_width) + # heads * tokens * frames + weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T]) + weights = weights[:, :, : num_frames // 2] weights = (weights * qk_scale).softmax(dim=-1) - weights = weights / weights.norm(dim=-2, keepdim=True) - matrix = weights.mean(axis=(0, 1)).neg() + std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) + weights = (weights - mean) / std + weights = median_filter(weights, medfilt_width) + + matrix = weights.mean(axis=0) + matrix = matrix[len(tokenizer.sot_sequence):-1] + text_indices, time_indices = dtw(-matrix) - text_indices, time_indices = dtw(matrix) + words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) jumps = 
np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) jump_times = time_indices[jumps] / TOKENS_PER_SECOND - - if tokenizer.language in {"zh", "ja", "th", "lo", "my"}: - # These languages don't typically use spaces, so it is difficult to split words - # without morpheme analysis. Here, we instead split words at any - # position where the tokens are decoded as valid unicode points - words, word_tokens = tokenizer.split_tokens_on_unicode(tokens[1:].tolist()) - else: - words, word_tokens = tokenizer.split_tokens_on_spaces(tokens[1:].tolist()) - - word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0)) start_times = jump_times[word_boundaries[:-1]] end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + # hack: ensure the first and second word is not longer than twice the median word duration. + # a better segmentation algorithm based on VAD should be able to replace this. + word_durations = end_times - start_times + word_durations = word_durations[word_durations.nonzero()] + if len(word_durations) > 0: + median_duration = np.median(word_durations) + max_duration = median_duration * 2 + if len(word_durations) >= 2 and word_durations[1] > max_duration: + end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration) + if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration: + start_times[0] = max(0, end_times[0] - max_duration) + + return [ + WordTiming(word, tokens, start, end, probability) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + - return Alignment(words, word_tokens, start_times, end_times) +def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str): + # merge prepended punctuations + i = len(alignment) - 2 + j = len(alignment) - 1 + while i >= 0: + previous = alignment[i] + following = alignment[j] + if previous.word.startswith(" ") and previous.word.strip() in prepended: + # prepend it to the following word + following.word = previous.word + following.word + following.tokens = previous.tokens + following.tokens + previous.word = "" + previous.tokens = [] + else: + j = i + i -= 1 + + # merge appended punctuations + i = 0 + j = 1 + while j < len(alignment): + previous = alignment[i] + following = alignment[j] + if not previous.word.endswith(" ") and following.word in appended: + # append it to the previous word + previous.word = previous.word + following.word + previous.tokens = previous.tokens + following.tokens + following.word = "" + following.tokens = [] + else: + i = j + j += 1 def add_word_timestamps( + *, segments: List[dict], model: "Whisper", tokenizer: Tokenizer, mel: torch.Tensor, num_frames: int, + prepend_punctuations: str = "\"\'“¿([{-", + append_punctuations: str = "\"\'.。,,!!??::”)]}、", **hyperparams, ): if len(segments) == 0: @@ -197,27 +275,26 @@ def add_word_timestamps( text_tokens = [t for segment in segments for t in segment["tokens"]] alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams) + merge_punctuations(alignment, prepend_punctuations, append_punctuations) time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE - alignment.start_times = time_offset + alignment.start_times - alignment.end_times = time_offset + alignment.end_times - token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) - 
token_sources: List[int] = [None] * len(tokenizer.sot_sequence) + list(token_sources) for segment in segments: segment["words"] = [] - word_boundaries = np.pad(np.cumsum([len(t) for t in alignment.word_tokens]), (1, 0)) - for i, (word, start, end) in enumerate(zip(alignment.words, alignment.start_times, alignment.end_times)): - if word.startswith("<|") or word.strip() in ".,!?、。": # TODO: expand - continue - - segment = segments[token_sources[word_boundaries[i]]] - segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2))) + word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0)) + for i, timing in enumerate(alignment): + if timing.word: + segment = segments[token_sources[word_boundaries[i]]] + start = round(time_offset + timing.start, 2) + end = round(time_offset + timing.end, 2) + segment["words"].append( + dict(word=timing.word, start=start, end=end, probability=timing.probability) + ) - # adjust the segment-level timestamps based on the word-level timestamps for segment in segments: - if len(segment["words"]) > 0: - segment["start"] = segment["words"][0]["start"] - segment["end"] = segment["words"][-1]["end"] + if len(words := segment["words"]) > 0: + # adjust the segment-level timestamps based on the word-level timestamps + segment["start"] = words[0]["start"] + segment["end"] = words[-1]["end"] diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index 55140f5b..f7b3ab7b 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -258,6 +258,15 @@ def _get_single_token_id(self, text) -> int: assert len(tokens) == 1, f"{text} is not encoded as a single token" return tokens[0] + def split_to_word_tokens(self, tokens: List[int]): + if self.language in {"zh", "ja", "th", "lo", "my"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. 
Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + return self.split_tokens_on_unicode(tokens) + + return self.split_tokens_on_spaces(tokens) + def split_tokens_on_unicode(self, tokens: List[int]): words = [] word_tokens = [] @@ -282,7 +291,7 @@ def split_tokens_on_spaces(self, tokens: List[int]): special = subword_tokens[0] >= self.eot with_space = subword.startswith(" ") punctuation = subword.strip() in string.punctuation - if special or with_space or punctuation: + if special or with_space or punctuation or len(words) == 0: words.append(subword) word_tokens.append(subword_tokens) else: diff --git a/whisper/transcribe.py b/whisper/transcribe.py index d17eb600..3674bb9e 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -159,33 +159,25 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: else: initial_prompt_tokens = [] - def add_segment( - *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult + def new_segment( + *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult ): - text_tokens = [token for token in text_tokens.tolist() if token < tokenizer.eot] - text = tokenizer.decode(text_tokens) - if len(text.strip()) == 0: # skip empty text output - return - - all_segments.append( - { - "id": len(all_segments), - "seek": seek, - "start": start, - "end": end, - "text": text, - "tokens": text_tokens, - "temperature": result.temperature, - "avg_logprob": result.avg_logprob, - "compression_ratio": result.compression_ratio, - "no_speech_prob": result.no_speech_prob, - } - ) - - # show the progress bar when verbose is False (otherwise the transcribed text will be printed) + text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot] + return { + "id": len(all_segments), + "seek": seek, + "start": start, + "end": end, + "text": tokenizer.decode(text_tokens), + "tokens": text_tokens, + "temperature": result.temperature, + "avg_logprob": result.avg_logprob, + "compression_ratio": result.compression_ratio, + "no_speech_prob": result.no_speech_prob, + } + + # show the progress bar when verbose is False (if True, transcribed text will be printed) num_frames = mel.shape[-1] - previous_seek = seek - with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: while seek < num_frames: time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) @@ -209,7 +201,10 @@ def add_segment( seek += segment_size # fast-forward to the next segment boundary continue - last_segment_index = len(all_segments) + previous_seek = seek + current_segments = [] + current_tokens = [] + timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens @@ -222,12 +217,13 @@ def add_segment( end_timestamp_position = ( sliced_tokens[-1].item() - tokenizer.timestamp_begin ) - add_segment( + current_segments.append(new_segment( start=time_offset + start_timestamp_position * time_precision, end=time_offset + end_timestamp_position * time_precision, - text_tokens=sliced_tokens[1:-1], + tokens=sliced_tokens, result=result, - ) + )) + current_tokens.append(sliced_tokens.tolist()) last_slice = current_slice last_timestamp_position = ( tokens[last_slice - 1].item() - tokenizer.timestamp_begin @@ -243,22 +239,20 @@ def add_segment( last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin duration 
= last_timestamp_position * time_precision - add_segment( + current_segments.append(new_segment( start=time_offset, end=time_offset + duration, - text_tokens=tokens, + tokens=tokens, result=result, - ) - + )) + current_tokens.append(tokens.tolist()) seek += segment_size - all_tokens.extend(tokens.tolist()) if not condition_on_previous_text or result.temperature > 0.5: # do not feed the prompt tokens if a high temperature was used prompt_reset_since = len(all_tokens) if word_timestamps: - current_segments = all_segments[last_segment_index:] add_word_timestamps( segments=current_segments, model=model, @@ -268,18 +262,29 @@ def add_segment( ) word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] if len(consecutive) > 0 and len(word_end_timestamps) > 0: - seek_shift = (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND - seek = previous_seek + round(seek_shift) + seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND) + if seek_shift > 0: + seek = previous_seek + seek_shift if verbose: - for segment in all_segments[last_segment_index:]: + for segment in current_segments: start, end, text = segment["start"], segment["end"], segment["text"] line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" print(make_safe(line)) + # if a segment is instantaneous or does not contain text, clear it + for i, segment in enumerate(current_segments): + if segment["start"] == segment["end"] or segment["text"].strip() == "": + segment["text"] = "" + segment["tokens"] = [] + segment["words"] = [] + current_tokens[i] = [] + + all_segments.extend(current_segments) + all_tokens.extend([token for segment in current_tokens for token in segment]) + # update progress bar pbar.update(min(num_frames, seek) - previous_seek) - previous_seek = seek return dict( text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]), diff --git a/whisper/utils.py b/whisper/utils.py index 5dacc173..8ee91293 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -85,34 +85,63 @@ def write_result(self, result: dict, file: TextIO): print(segment['text'].strip(), file=file, flush=True) -class WriteVTT(ResultWriter): +class SubtitlesWriter(ResultWriter): + always_include_hours: bool + decimal_marker: str + + def iterate_result(self, result: dict): + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment['text'].strip().replace('-->', '->') + + if word_timings := segment.get("words", None): + all_words = [timing["word"] for timing in word_timings] + all_words[0] = all_words[0].strip() # remove the leading space, if any + last = segment_start + for i, this_word in enumerate(word_timings): + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, segment_text + + yield start, end, "".join( + [f"{word}" if j == i else word for j, word in enumerate(all_words)] + ) + last = end + + if last != segment_end: + yield last, segment_end, segment_text + else: + yield segment_start, segment_end, segment_text + + def format_timestamp(self, seconds: float): + return format_timestamp( + seconds=seconds, + always_include_hours=self.always_include_hours, + decimal_marker=self.decimal_marker, + ) + + +class WriteVTT(SubtitlesWriter): extension: str = "vtt" + always_include_hours: bool = False + decimal_marker: str = '.' 
def write_result(self, result: dict, file: TextIO): print("WEBVTT\n", file=file) - for segment in result["segments"]: - print( - f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) + for start, end, text in self.iterate_result(result): + print(f"{start} --> {end}\n{text}\n", file=file, flush=True) -class WriteSRT(ResultWriter): +class WriteSRT(SubtitlesWriter): extension: str = "srt" + always_include_hours: bool = True + decimal_marker: str = ',' def write_result(self, result: dict, file: TextIO): - for i, segment in enumerate(result["segments"], start=1): - # write srt lines - print( - f"{i}\n" - f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " - f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) + for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): + print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) class WriteTSV(ResultWriter): From 48537aaf1074ae1be17a7e3df558a7ac95073ba7 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Thu, 2 Feb 2023 09:34:55 -0800 Subject: [PATCH 14/18] skip median filtering when the input is too small --- whisper/model.py | 7 ++++--- whisper/timing.py | 9 +++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/whisper/model.py b/whisper/model.py index dff96107..a1ab2e34 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -1,16 +1,17 @@ +import base64 +import gzip from dataclasses import dataclass from typing import Dict from typing import Iterable, Optional -import gzip -import base64 + import numpy as np import torch import torch.nn.functional as F from torch import Tensor from torch import nn -from .transcribe import transcribe as transcribe_function from .decoding import detect_language as detect_language_function, decode as decode_function +from .transcribe import transcribe as transcribe_function @dataclass diff --git a/whisper/timing.py b/whisper/timing.py index e056f1e9..98927aa0 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -1,4 +1,3 @@ -import time import subprocess import warnings from dataclasses import dataclass @@ -18,7 +17,13 @@ def median_filter(x: torch.Tensor, filter_width: int): """Apply a median filter of width `filter_width` along the last dimension of `x`""" - if (ndim := x.ndim) <= 2: # `F.pad` does not support 1D or 2D inputs for reflect padding + pad_width = filter_width // 2 + if x.shape[-1] <= pad_width: + # F.pad requires the padding width to be smaller than the input dimension + return x + + if (ndim := x.ndim) <= 2: + # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D x = x[None, None, :] assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number" From 8eb29c3ef10559910cbee47b1baedefd8388458a Mon Sep 17 00:00:00 2001 From: ryanheise Date: Fri, 17 Feb 2023 06:59:40 +1100 Subject: [PATCH 15/18] Expose punctuation options in cli and transcribe() (#973) --- whisper/transcribe.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 3674bb9e..d63a2b58 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -29,6 +29,8 @@ def transcribe( condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, word_timestamps: bool = False, + 
prepend_punctuations: str = "\"\'“¿([{-", + append_punctuations: str = "\"\'.。,,!!??::”)]}、", **decode_options, ): """ @@ -69,6 +71,12 @@ def transcribe( Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment. + prepend_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the next word + + append_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the previous word + initial_prompt: Optional[str] Optional text to provide as a prompt for the first window. This can be used to provide, or "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns @@ -259,6 +267,8 @@ def new_segment( tokenizer=tokenizer, mel=mel_segment, num_frames=segment_size, + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, ) word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] if len(consecutive) > 0 and len(word_end_timestamps) > 0: @@ -324,6 +334,8 @@ def cli(): parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") parser.add_argument("--word_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them") + parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="If word_timestamps is True, merge these punctuation symbols with the next word") + parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="If word_timestamps is True, merge these punctuation symbols with the previous word") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") args = parser.parse_args().__dict__ From 31cd418f273ab65615f25293fa21fb69c04fb412 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 6 Mar 2023 13:13:22 -0800 Subject: [PATCH 16/18] fix merge error --- whisper/transcribe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index c88655c9..6bf41126 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -218,19 +218,21 @@ def new_segment( if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]: consecutive = consecutive.tolist() + [len(tokens)] + last_slice = 0 for current_slice in consecutive: sliced_tokens = tokens[last_slice:current_slice] start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin current_segments.append(new_segment( - start=timestamp_offset + start_timestamp_pos * time_precision, - end=timestamp_offset + end_timestamp_pos * time_precision, + start=time_offset + start_timestamp_pos * time_precision, + end=time_offset + end_timestamp_pos * time_precision, tokens=sliced_tokens, result=result, )) current_tokens.append(sliced_tokens.tolist()) last_slice = current_slice + if ended_with_single_timestamp: # single timestamp at 
the end means no speech after the last timestamp. seek += segment.shape[-1] From 145f325d687f3ea9cb0b41fb5d148735d46516b7 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 6 Mar 2023 13:18:57 -0800 Subject: [PATCH 17/18] fix merge error 2 --- whisper/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 6bf41126..bd877a69 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -235,7 +235,7 @@ def new_segment( if ended_with_single_timestamp: # single timestamp at the end means no speech after the last timestamp. - seek += segment.shape[-1] + seek += segment_size else: # otherwise, ignore the unfinished segment and seek to the last timestamp last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin From 2b079c41fa83e850c7750605b5394a80974b697d Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 6 Mar 2023 13:31:55 -0800 Subject: [PATCH 18/18] annotating that word_timestamps is experimental --- whisper/transcribe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index bd877a69..d7c04874 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -335,9 +335,9 @@ def cli(): parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") - parser.add_argument("--word_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them") - parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="If word_timestamps is True, merge these punctuation symbols with the next word") - parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="If word_timestamps is True, merge these punctuation symbols with the previous word") + parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") + parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") + parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") args = parser.parse_args().__dict__
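Taken together, the patches above expose word-level timing through a single flag on transcribe(). A minimal usage sketch follows, assuming a local file audio.mp3 and the "base" model (both placeholders); with word_timestamps=True each segment gains a "words" list whose entries carry "word", "start", "end", and "probability", as built in add_word_timestamps().

import whisper

model = whisper.load_model("base")
result = model.transcribe("audio.mp3", word_timestamps=True)

for segment in result["segments"]:
    for word in segment.get("words", []):  # empty for cleared or text-less segments
        print(f"{word['start']:6.2f} -> {word['end']:6.2f} "
              f"{word['word']} (p={word['probability']:.2f})")

The same behaviour is available on the command line via --word_timestamps True (parsed with str2bool), optionally combined with --prepend_punctuations and --append_punctuations as added in PATCH 15/18.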
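The _ALIGNMENT_HEADS entries are opaque, but their format follows from set_alignment_heads() in whisper/model.py: a gzip-compressed, base85-encoded boolean array of shape (n_text_layer, n_text_head). The round trip below only illustrates that format, with made-up dimensions and head positions rather than a shipped dump; the encoding direction is an assumption inferred from the decoder.

import base64
import gzip

import numpy as np
import torch

# made-up example mask; not one of the shipped _ALIGNMENT_HEADS values
n_text_layer, n_text_head = 4, 6
mask = torch.zeros(n_text_layer, n_text_head, dtype=torch.bool)
mask[2, 3] = True  # pretend this cross-attention head tracks word timing
mask[3, 0] = True

dump = base64.b85encode(gzip.compress(mask.numpy().tobytes()))  # assumed encoding step
# decode exactly as Whisper.set_alignment_heads() does above
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
assert (array.reshape(n_text_layer, n_text_head) == mask.numpy()).all()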
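On the tokenizer side, find_alignment() relies on split_to_word_tokens() to regroup BPE tokens into words before assigning times, splitting on spaces for most languages and on unicode boundaries for languages written without spaces. A small sketch, assuming a multilingual tokenizer and English text; the printed output is illustrative and depends on the installed tokenizer files.

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
text_tokens = tokenizer.encode(" Hello, world!")
# eot is appended here the same way find_alignment() does before splitting
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
print(words)  # e.g. [' Hello', ',', ' world', '!', '<|endoftext|>']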