diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c5b4eecc..f06bff79 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,5 @@ jobs: - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch - uses: actions/checkout@v2 - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - - run: pip install pytest - - run: pip install . - - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' + - run: pip install .["dev"] + - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' diff --git a/requirements.txt b/requirements.txt index a4614035..45ecbe76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +numba numpy torch tqdm diff --git a/setup.py b/setup.py index 0e822ab9..a548c8d3 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os +import sys import pkg_resources from setuptools import setup, find_packages @@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"): return locals()["__version__"] +requirements = [] +if sys.platform.startswith("linux"): + triton_requirement = "triton>=2.0.0.dev20221202" + try: + import re + import subprocess + version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1] + major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0] + if (int(major), int(minor)) < (11, 4): + # the last version supporting CUDA < 11.4 + triton_requirement = "triton==2.0.0.dev20221011" + except (IndexError, OSError, subprocess.SubprocessError): + pass + requirements.append(triton_requirement) + setup( name="openai-whisper", py_modules=["whisper"], @@ -22,7 +38,7 @@ def read_version(fname="whisper/version.py"): url="https://github.com/openai/whisper", license="MIT", packages=find_packages(exclude=["tests*"]), - install_requires=[ + install_requires=requirements + [ str(r) for r in pkg_resources.parse_requirements( open(os.path.join(os.path.dirname(__file__), "requirements.txt")) @@ -32,5 +48,5 @@ def read_version(fname="whisper/version.py"): "console_scripts": ["whisper=whisper.transcribe:cli"], }, include_package_data=True, - extras_require={"dev": ["pytest"]}, + extras_require={"dev": ["pytest", "scipy"]}, ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..31f1d6b4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,14 @@ +import random as rand + +import numpy +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "requires_cuda") + + +@pytest.fixture +def random(): + rand.seed(42) + numpy.random.seed(42) diff --git a/tests/test_timing.py b/tests/test_timing.py new file mode 100644 index 00000000..50a2583f --- /dev/null +++ b/tests/test_timing.py @@ -0,0 +1,87 @@ +import pytest +import numpy as np +import scipy.ndimage +import torch + +from whisper.timing import dtw_cpu, dtw_cuda, median_filter + + +sizes = [ + (10, 20), (32, 16), (123, 1500), (234, 189), +] +shapes = [ + (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512), +] + + +@pytest.mark.parametrize("N, M", sizes) +def test_dtw(N: int, M: int): + steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)]) + np.random.shuffle(steps) + x = np.random.random((N, M)).astype(np.float32) + + i, j, k = 0, 0, 0 + trace = [] + while True: + x[i, j] -= 1 + trace.append((i, j)) + + if k == len(steps): + break + + if k + 1 < len(steps) and steps[k] != steps[k + 1]: + i += 
1 + j += 1 + k += 2 + continue + + if steps[k] == 0: + i += 1 + if steps[k] == 1: + j += 1 + k += 1 + + trace = np.array(trace).T + dtw_trace = dtw_cpu(x) + + assert np.allclose(trace, dtw_trace) + + +@pytest.mark.requires_cuda +@pytest.mark.parametrize("N, M", sizes) +def test_dtw_cuda_equivalence(N: int, M: int): + x_numpy = np.random.randn(N, M).astype(np.float32) + x_cuda = torch.from_numpy(x_numpy).cuda() + + trace_cpu = dtw_cpu(x_numpy) + trace_cuda = dtw_cuda(x_cuda) + + assert np.allclose(trace_cpu, trace_cuda) + + +@pytest.mark.parametrize("shape", shapes) +def test_median_filter(shape): + x = torch.randn(*shape) + + for filter_width in [3, 5, 7, 13]: + filtered = median_filter(x, filter_width) + + # using np.pad to reflect-pad, because Scipy's behavior is different near the edges. + pad_width = filter_width // 2 + padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect") + scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width]) + scipy_filtered = scipy_filtered[..., pad_width:-pad_width] + + assert np.allclose(filtered, scipy_filtered) + + +@pytest.mark.requires_cuda +@pytest.mark.parametrize("shape", shapes) +def test_median_filter_equivalence(shape): + x = torch.randn(*shape) + + for filter_width in [3, 5, 7, 13]: + filtered_cpu = median_filter(x, filter_width) + filtered_gpu = median_filter(x.cuda(), filter_width).cpu() + + assert np.allclose(filtered_cpu, filtered_gpu) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index f5d66c37..9802f734 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -13,10 +13,22 @@ def test_transcribe(model_name: str): audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") language = "en" if model_name.endswith(".en") else None - result = model.transcribe(audio_path, language=language, temperature=0.0) + result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True) assert result["language"] == "en" transcription = result["text"].lower() assert "my fellow americans" in transcription assert "your country" in transcription assert "do for you" in transcription + + timing_checked = False + for segment in result["segments"]: + for timing in segment["words"]: + assert timing["start"] < timing["end"] + if timing["word"].strip(" ,") == "Americans": + assert timing["start"] <= 1.8 + assert timing["end"] >= 1.8 + print(timing) + timing_checked = True + + assert timing_checked diff --git a/whisper/__init__.py b/whisper/__init__.py index cb334065..26d1e0ea 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -29,6 +29,23 @@ "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", } +# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are +# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens. 
+_ALIGNMENT_HEADS = { + "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00", + "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", + "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00", + "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00", + "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", + "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj', + "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj', +} + + def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: os.makedirs(root, exist_ok=True) @@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow if name in _MODELS: checkpoint_file = _download(_MODELS[name], download_root, in_memory) + alignment_heads = _ALIGNMENT_HEADS[name] elif os.path.isfile(name): checkpoint_file = open(name, "rb").read() if in_memory else name + alignment_heads = None else: raise RuntimeError(f"Model {name} not found; available models = {available_models()}") @@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow model = Whisper(dims) model.load_state_dict(checkpoint["model_state_dict"]) + if alignment_heads is not None: + model.set_alignment_heads(alignment_heads) + return model.to(device) diff --git a/whisper/audio.py b/whisper/audio.py index de8a1951..964d4157 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -18,6 +18,10 @@ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each) +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each) + def load_audio(file: str, sr: int = SAMPLE_RATE): """ diff --git a/whisper/model.py b/whisper/model.py index be73a4a8..a1ab2e34 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -1,3 +1,5 @@ +import base64 +import gzip from dataclasses import dataclass from typing import Dict from typing import Iterable, Optional @@ -8,8 +10,8 @@ from torch import Tensor from torch import nn -from .transcribe import transcribe as transcribe_function from .decoding import detect_language as detect_language_function, decode as decode_function +from .transcribe import transcribe as transcribe_function @dataclass @@ -213,6 +215,15 @@ def __init__(self, dims: ModelDimensions): self.dims.n_text_head, self.dims.n_text_layer, ) + # use the last half layers for alignment by default; see `set_alignment_heads()` below + all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool) + all_heads[self.dims.n_text_layer // 2:] = True + self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) + + def set_alignment_heads(self, dump: bytes): + array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy() + mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head) + self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) def embed_audio(self, mel: torch.Tensor): return self.encoder(mel) diff --git a/whisper/timing.py b/whisper/timing.py new file mode 100644 index 00000000..98927aa0 --- /dev/null +++ b/whisper/timing.py @@ -0,0 
+1,305 @@ +import subprocess +import warnings +from dataclasses import dataclass +from typing import List, TYPE_CHECKING + +import numba +import numpy as np +import torch +import torch.nn.functional as F + +from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND +from .tokenizer import Tokenizer + +if TYPE_CHECKING: + from .model import Whisper + + +def median_filter(x: torch.Tensor, filter_width: int): + """Apply a median filter of width `filter_width` along the last dimension of `x`""" + pad_width = filter_width // 2 + if x.shape[-1] <= pad_width: + # F.pad requires the padding width to be smaller than the input dimension + return x + + if (ndim := x.ndim) <= 2: + # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D + x = x[None, None, :] + + assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number" + + result = None + x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") + if x.is_cuda: + try: + from .triton_ops import median_filter_cuda + result = median_filter_cuda(x, filter_width) + except (RuntimeError, subprocess.CalledProcessError): + warnings.warn( + "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " + "falling back to a slower median kernel implementation..." + ) + + if result is None: + # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) + result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] + + if ndim <= 2: + result = result[0, 0] + + return result + +@numba.jit +def backtrace(trace: np.ndarray): + i = trace.shape[0] - 1 + j = trace.shape[1] - 1 + trace[0, :] = 2 + trace[:, 0] = 1 + + result = [] + while i > 0 or j > 0: + result.append((i - 1, j - 1)) + + if trace[i, j] == 0: + i -= 1 + j -= 1 + elif trace[i, j] == 1: + i -= 1 + elif trace[i, j] == 2: + j -= 1 + else: + raise ValueError("Unexpected trace[i, j]") + + result = np.array(result) + return result[::-1, :].T + + +@numba.jit(nopython=True, parallel=True) +def dtw_cpu(x: np.ndarray): + N, M = x.shape + cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf + trace = -np.ones((N + 1, M + 1), dtype=np.float32) + + cost[0, 0] = 0 + for j in range(1, M + 1): + for i in range(1, N + 1): + c0 = cost[i - 1, j - 1] + c1 = cost[i - 1, j] + c2 = cost[i, j - 1] + + if c0 < c1 and c0 < c2: + c, t = c0, 0 + elif c1 < c0 and c1 < c2: + c, t = c1, 1 + else: + c, t = c2, 2 + + cost[i, j] = x[i - 1, j - 1] + c + trace[i, j] = t + + return backtrace(trace) + + +def dtw_cuda(x, BLOCK_SIZE=1024): + from .triton_ops import dtw_kernel + + M, N = x.shape + assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" + + x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) + x_skew = x_skew.T.contiguous() + cost = torch.ones(N + M + 2, M + 2) * np.inf + cost[0, 0] = 0 + cost = cost.cuda() + trace = torch.zeros_like(cost, dtype=torch.int32) + + dtw_kernel[(1,)]( + cost, + trace, + x_skew, + x_skew.stride(0), + cost.stride(0), + trace.stride(0), + N, + M, + BLOCK_SIZE=BLOCK_SIZE + ) + + trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1] + return backtrace(trace.cpu().numpy()) + + +def dtw(x: torch.Tensor) -> np.ndarray: + if x.is_cuda: + try: + return dtw_cuda(x) + except (RuntimeError, subprocess.CalledProcessError): + warnings.warn( + "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " + "falling back to a slower DTW implementation..." 
+ ) + + return dtw_cpu(x.double().cpu().numpy()) + + +@dataclass +class WordTiming: + word: str + tokens: List[int] + start: float + end: float + probability: float + + +def find_alignment( + model: "Whisper", + tokenizer: Tokenizer, + text_tokens: List[int], + mel: torch.Tensor, + num_frames: int, + *, + medfilt_width: int = 7, + qk_scale: float = 1.0, +) -> List[WordTiming]: + tokens = torch.tensor( + [ + *tokenizer.sot_sequence, + tokenizer.no_timestamps, + *text_tokens, + tokenizer.eot, + ] + ).to(model.device) + + # install hooks on the cross attention layers to retrieve the attention weights + QKs = [None] * model.dims.n_text_layer + hooks = [ + block.cross_attn.register_forward_hook( + lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) + ) + for i, block in enumerate(model.decoder.blocks) + ] + + with torch.no_grad(): + logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] + token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1) + text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist() + + for hook in hooks: + hook.remove() + + # heads * tokens * frames + weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T]) + weights = weights[:, :, : num_frames // 2] + weights = (weights * qk_scale).softmax(dim=-1) + std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) + weights = (weights - mean) / std + weights = median_filter(weights, medfilt_width) + + matrix = weights.mean(axis=0) + matrix = matrix[len(tokenizer.sot_sequence):-1] + text_indices, time_indices = dtw(-matrix) + + words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] / TOKENS_PER_SECOND + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + # hack: ensure the first and second word is not longer than twice the median word duration. + # a better segmentation algorithm based on VAD should be able to replace this. 
+ word_durations = end_times - start_times + word_durations = word_durations[word_durations.nonzero()] + if len(word_durations) > 0: + median_duration = np.median(word_durations) + max_duration = median_duration * 2 + if len(word_durations) >= 2 and word_durations[1] > max_duration: + end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration) + if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration: + start_times[0] = max(0, end_times[0] - max_duration) + + return [ + WordTiming(word, tokens, start, end, probability) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + + +def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str): + # merge prepended punctuations + i = len(alignment) - 2 + j = len(alignment) - 1 + while i >= 0: + previous = alignment[i] + following = alignment[j] + if previous.word.startswith(" ") and previous.word.strip() in prepended: + # prepend it to the following word + following.word = previous.word + following.word + following.tokens = previous.tokens + following.tokens + previous.word = "" + previous.tokens = [] + else: + j = i + i -= 1 + + # merge appended punctuations + i = 0 + j = 1 + while j < len(alignment): + previous = alignment[i] + following = alignment[j] + if not previous.word.endswith(" ") and following.word in appended: + # append it to the previous word + previous.word = previous.word + following.word + previous.tokens = previous.tokens + following.tokens + following.word = "" + following.tokens = [] + else: + i = j + j += 1 + + +def add_word_timestamps( + *, + segments: List[dict], + model: "Whisper", + tokenizer: Tokenizer, + mel: torch.Tensor, + num_frames: int, + prepend_punctuations: str = "\"\'“¿([{-", + append_punctuations: str = "\"\'.。,,!!??::”)]}、", + **hyperparams, +): + if len(segments) == 0: + return + + text_tokens = [t for segment in segments for t in segment["tokens"]] + alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams) + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + + time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE + token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) + + for segment in segments: + segment["words"] = [] + + word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0)) + for i, timing in enumerate(alignment): + if timing.word: + segment = segments[token_sources[word_boundaries[i]]] + start = round(time_offset + timing.start, 2) + end = round(time_offset + timing.end, 2) + segment["words"].append( + dict(word=timing.word, start=start, end=end, probability=timing.probability) + ) + + for segment in segments: + if len(words := segment["words"]) > 0: + # adjust the segment-level timestamps based on the word-level timestamps + segment["start"] = words[0]["start"] + segment["end"] = words[-1]["end"] diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index 7efa2d4a..ea1117f7 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -1,4 +1,5 @@ import os +import string from dataclasses import dataclass from functools import lru_cache, cached_property from typing import List, Optional, Tuple, Union @@ -265,6 +266,48 @@ def _get_single_token_id(self, text) -> int: assert len(tokens) == 1, f"{text} is not encoded as a single token" return tokens[0] + def split_to_word_tokens(self, tokens: List[int]): + if self.language in {"zh", "ja", "th", 
"lo", "my"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + return self.split_tokens_on_unicode(tokens) + + return self.split_tokens_on_spaces(tokens) + + def split_tokens_on_unicode(self, tokens: List[int]): + words = [] + word_tokens = [] + current_tokens = [] + + for token in tokens: + current_tokens.append(token) + decoded = self.decode_with_timestamps(current_tokens) + if "\ufffd" not in decoded: + words.append(decoded) + word_tokens.append(current_tokens) + current_tokens = [] + + return words, word_tokens + + def split_tokens_on_spaces(self, tokens: List[int]): + subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) + words = [] + word_tokens = [] + + for subword, subword_tokens in zip(subwords, subword_tokens_list): + special = subword_tokens[0] >= self.eot + with_space = subword.startswith(" ") + punctuation = subword.strip() in string.punctuation + if special or with_space or punctuation or len(words) == 0: + words.append(subword) + word_tokens.append(subword_tokens) + else: + words[-1] = words[-1] + subword + word_tokens[-1].extend(subword_tokens) + + return words, word_tokens + @lru_cache(maxsize=None) def build_tokenizer(name: str = "gpt2"): diff --git a/whisper/transcribe.py b/whisper/transcribe.py index d155b20d..d7c04874 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -7,8 +7,9 @@ import torch import tqdm -from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram +from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult +from .timing import add_word_timestamps from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer @@ -27,6 +28,9 @@ def transcribe( no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, + word_timestamps: bool = False, + prepend_punctuations: str = "\"\'“¿([{-", + append_punctuations: str = "\"\'.。,,!!??::”)]}、", **decode_options, ): """ @@ -63,6 +67,21 @@ def transcribe( disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + word_timestamps: bool + Extract word-level timestamps using the cross-attention pattern and dynamic time warping, + and include the timestamps for each word in each segment. + + prepend_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the next word + + append_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the previous word + + initial_prompt: Optional[str] + Optional text to provide as a prompt for the first window. This can be used to provide, or + "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns + to make it more likely to predict those word correctly. + decode_options: dict Keyword arguments to construct `DecodingOptions` instances @@ -90,16 +109,19 @@ def transcribe( else: if verbose: print("Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language") - segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) - _, probs = model.detect_language(segment) + mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) + _, probs = model.detect_language(mel_segment) decode_options["language"] = max(probs, key=probs.get) if verbose is not None: print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") - language = decode_options["language"] - task = decode_options.get("task", "transcribe") + language: str = decode_options["language"] + task: str = decode_options.get("task", "transcribe") tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) + if word_timestamps and task == "translate": + warnings.warn("Word-level timestamps on translations may not be reliable.") + def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature decode_result = None @@ -145,42 +167,35 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: else: initial_prompt_tokens = [] - def add_segment( - *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult + def new_segment( + *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult ): - text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot]) - if len(text.strip()) == 0: # skip empty text output - return - - all_segments.append( - { - "id": len(all_segments), - "seek": seek, - "start": start, - "end": end, - "text": text, - "tokens": text_tokens.tolist(), - "temperature": result.temperature, - "avg_logprob": result.avg_logprob, - "compression_ratio": result.compression_ratio, - "no_speech_prob": result.no_speech_prob, - } - ) - if verbose: - print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")) - - # show the progress bar when verbose is False (otherwise the transcribed text will be printed) + text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot] + return { + "id": len(all_segments), + "seek": seek, + "start": start, + "end": end, + "text": tokenizer.decode(text_tokens), + "tokens": text_tokens, + "temperature": result.temperature, + "avg_logprob": result.avg_logprob, + "compression_ratio": result.compression_ratio, + "no_speech_prob": result.no_speech_prob, + } + + # show the progress bar when verbose is False (if True, transcribed text will be printed) num_frames = mel.shape[-1] - previous_seek_value = seek - with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: while seek < num_frames: - timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) - segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE + time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) + mel_segment = mel[:, seek:] + segment_size = min(mel_segment.shape[-1], N_FRAMES) + segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE + mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(segment) + result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) if no_speech_threshold is not None: @@ -191,29 +206,36 @@ def add_segment( should_skip = False if should_skip: - seek += segment.shape[-1] # fast-forward to the next segment boundary + seek += 
segment_size # fast-forward to the next segment boundary continue + previous_seek = seek + current_segments = [] + current_tokens = [] + timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]: consecutive = consecutive.tolist() + [len(tokens)] + last_slice = 0 for current_slice in consecutive: sliced_tokens = tokens[last_slice:current_slice] start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin - add_segment( - start=timestamp_offset + start_timestamp_pos * time_precision, - end=timestamp_offset + end_timestamp_pos * time_precision, - text_tokens=sliced_tokens[1:-1], + current_segments.append(new_segment( + start=time_offset + start_timestamp_pos * time_precision, + end=time_offset + end_timestamp_pos * time_precision, + tokens=sliced_tokens, result=result, - ) + )) + current_tokens.append(sliced_tokens.tolist()) last_slice = current_slice + if ended_with_single_timestamp: # single timestamp at the end means no speech after the last timestamp. - seek += segment.shape[-1] + seek += segment_size else: # otherwise, ignore the unfinished segment and seek to the last timestamp last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin @@ -227,23 +249,54 @@ def add_segment( last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin duration = last_timestamp_pos * time_precision - add_segment( - start=timestamp_offset, - end=timestamp_offset + duration, - text_tokens=tokens, + current_segments.append(new_segment( + start=time_offset, + end=time_offset + duration, + tokens=tokens, result=result, - ) - - seek += segment.shape[-1] - all_tokens.extend(tokens.tolist()) + )) + current_tokens.append(tokens.tolist()) + seek += segment_size if not condition_on_previous_text or result.temperature > 0.5: # do not feed the prompt tokens if a high temperature was used prompt_reset_since = len(all_tokens) + if word_timestamps: + add_word_timestamps( + segments=current_segments, + model=model, + tokenizer=tokenizer, + mel=mel_segment, + num_frames=segment_size, + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + ) + word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] + if len(consecutive) > 0 and len(word_end_timestamps) > 0: + seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND) + if seek_shift > 0: + seek = previous_seek + seek_shift + + if verbose: + for segment in current_segments: + start, end, text = segment["start"], segment["end"], segment["text"] + line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" + print(make_safe(line)) + + # if a segment is instantaneous or does not contain text, clear it + for i, segment in enumerate(current_segments): + if segment["start"] == segment["end"] or segment["text"].strip() == "": + segment["text"] = "" + segment["tokens"] = [] + segment["words"] = [] + current_tokens[i] = [] + + all_segments.extend(current_segments) + all_tokens.extend([token for segment in current_tokens for token in segment]) + # update progress bar - pbar.update(min(num_frames, seek) - previous_seek_value) - previous_seek_value = seek + pbar.update(min(num_frames, seek) - previous_seek) return dict( 
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]), @@ -282,6 +335,9 @@ def cli(): parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") + parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") + parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") + parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") args = parser.parse_args().__dict__ diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py new file mode 100644 index 00000000..d829e204 --- /dev/null +++ b/whisper/triton_ops.py @@ -0,0 +1,92 @@ +import math + +import numpy as np +import torch +from functools import lru_cache + +try: + import triton + import triton.language as tl +except ImportError: + raise RuntimeError("triton import failed; try `pip install --pre triton`") + + +@triton.jit +def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < M + + for k in range(1, N + M + 1): # k = i + j + tl.debug_barrier() + + p0 = cost + (k - 1) * cost_stride + p1 = cost + k * cost_stride + p2 = cost + k * cost_stride + 1 + + c0 = tl.load(p0 + offsets, mask=mask) + c1 = tl.load(p1 + offsets, mask=mask) + c2 = tl.load(p2 + offsets, mask=mask) + + x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0) + cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2) + + cost_ptr = cost + (k + 1) * cost_stride + 1 + tl.store(cost_ptr + offsets, cost_row, mask=mask) + + trace_ptr = trace + (k + 1) * trace_stride + 1 + tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) + tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) + tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) + + +@lru_cache(maxsize=None) +def median_kernel(filter_width: int): + @triton.jit + def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr): # x.shape[-1] == filter_width + row_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < y_stride + + x_ptr = x + row_idx * x_stride + y_ptr = y + row_idx * y_stride + + LOAD_ALL_ROWS_HERE + + BUBBLESORT_HERE + + tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) + + kernel = triton.JITFunction(kernel.fn) + kernel.src = kernel.src.replace(" LOAD_ALL_ROWS_HERE", "\n".join([ + f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" + for i in range(filter_width) + ])) + kernel.src = kernel.src.replace(" BUBBLESORT_HERE", "\n\n".join([ + "\n\n".join([ + "\n".join([ + f" smaller = tl.where(row{j} < row{j + 
1}, row{j}, row{j + 1})", + f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", + f" row{j} = smaller", + f" row{j + 1} = larger", + ]) + for j in range(filter_width - i - 1) + ]) + for i in range(filter_width // 2 + 1) + ])) + kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") + + return kernel + + +def median_filter_cuda(x: torch.Tensor, filter_width: int): + """Apply a median filter of given width along the last dimension of x""" + slices = x.contiguous().unfold(-1, filter_width, 1) + grid = np.prod(slices.shape[:-2]) + + kernel = median_kernel(filter_width) + y = torch.empty_like(slices[..., 0]) + + BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length() + kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE) + + return y diff --git a/whisper/utils.py b/whisper/utils.py index 5dacc173..8ee91293 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -85,34 +85,63 @@ def write_result(self, result: dict, file: TextIO): print(segment['text'].strip(), file=file, flush=True) -class WriteVTT(ResultWriter): +class SubtitlesWriter(ResultWriter): + always_include_hours: bool + decimal_marker: str + + def iterate_result(self, result: dict): + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment['text'].strip().replace('-->', '->') + + if word_timings := segment.get("words", None): + all_words = [timing["word"] for timing in word_timings] + all_words[0] = all_words[0].strip() # remove the leading space, if any + last = segment_start + for i, this_word in enumerate(word_timings): + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, segment_text + + yield start, end, "".join( + [f"{word}" if j == i else word for j, word in enumerate(all_words)] + ) + last = end + + if last != segment_end: + yield last, segment_end, segment_text + else: + yield segment_start, segment_end, segment_text + + def format_timestamp(self, seconds: float): + return format_timestamp( + seconds=seconds, + always_include_hours=self.always_include_hours, + decimal_marker=self.decimal_marker, + ) + + +class WriteVTT(SubtitlesWriter): extension: str = "vtt" + always_include_hours: bool = False + decimal_marker: str = '.' def write_result(self, result: dict, file: TextIO): print("WEBVTT\n", file=file) - for segment in result["segments"]: - print( - f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) + for start, end, text in self.iterate_result(result): + print(f"{start} --> {end}\n{text}\n", file=file, flush=True) -class WriteSRT(ResultWriter): +class WriteSRT(SubtitlesWriter): extension: str = "srt" + always_include_hours: bool = True + decimal_marker: str = ',' def write_result(self, result: dict, file: TextIO): - for i, segment in enumerate(result["segments"], start=1): - # write srt lines - print( - f"{i}\n" - f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " - f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) + for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): + print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) class WriteTSV(ResultWriter):
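For reference, `Whisper.set_alignment_heads()` in `whisper/model.py` expects a gzip-compressed, base85-encoded boolean array of shape `(n_text_layer, n_text_head)`, which is the format the `_ALIGNMENT_HEADS` constants are stored in. A small round-trip sketch using made-up layer/head counts (not a real model's dimensions), encoding one byte per boolean so it matches the decoder used above:

```python
import base64
import gzip

import numpy as np
import torch

n_text_layer, n_text_head = 4, 6  # illustrative sizes, not a real model's dimensions
mask = np.zeros((n_text_layer, n_text_head), dtype=bool)
mask[2, 3] = mask[3, 1] = True  # pretend these heads track the word-level alignment

# encode: gzip the raw boolean bytes, then base85-encode them
dump = base64.b85encode(gzip.compress(mask.tobytes()))

# decode exactly as Whisper.set_alignment_heads() does
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
decoded = torch.from_numpy(array).reshape(n_text_layer, n_text_head)
assert (decoded.numpy() == mask).all()
```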
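`dtw_cpu()` in `whisper/timing.py` fills an `(N + 1, M + 1)` cost matrix with three moves (match, advance the text index, advance the time index) and `backtrace()` then walks the trace matrix back from the bottom-right corner. A plain-NumPy restatement of the same recurrence, without the numba decorators, run on a toy cost matrix; it mirrors the patch's logic rather than replacing it:

```python
import numpy as np

def dtw_numpy(x: np.ndarray) -> np.ndarray:
    """Plain-NumPy version of dtw_cpu + backtrace from whisper/timing.py (no numba)."""
    N, M = x.shape
    cost = np.full((N + 1, M + 1), np.inf, dtype=np.float64)
    trace = -np.ones((N + 1, M + 1), dtype=np.int32)
    cost[0, 0] = 0
    for i in range(1, N + 1):
        for j in range(1, M + 1):
            c0 = cost[i - 1, j - 1]  # diagonal: advance both indices
            c1 = cost[i - 1, j]      # advance the text index only
            c2 = cost[i, j - 1]      # advance the time index only
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # backtrace from the bottom-right corner, exactly as backtrace() does
    trace[0, :] = 2
    trace[:, 0] = 1
    i, j = N, M
    path = []
    while i > 0 or j > 0:
        path.append((i - 1, j - 1))
        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    return np.array(path)[::-1].T  # shape (2, path_length): text indices, time indices

# toy cost matrix: the cheapest path hugs the zero diagonal
x = np.array([[0.0, 1.0, 1.0],
              [1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0]])
text_idx, time_idx = dtw_numpy(x)
print(text_idx.tolist(), time_idx.tolist())  # [0, 1, 2] [0, 1, 2]
```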
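`median_filter()` avoids `torch.median` by reflect-padding the last dimension, unfolding it into windows, sorting each window, and taking the middle element. The CPU path can be sanity-checked against `scipy.ndimage.median_filter` the same way `test_median_filter` does; a compact, self-contained version of that check (the tensor shape and filter width are arbitrary):

```python
import numpy as np
import scipy.ndimage
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 240)  # arbitrary (batch, time) shape
filter_width = 7
pad = filter_width // 2

# reflect-pad along the last axis, then take the middle of each sorted window
# (sort() instead of torch.median, as in whisper/timing.py)
padded = F.pad(x[None, None], (pad, pad, 0, 0), mode="reflect")[0, 0]
filtered = padded.unfold(-1, filter_width, 1).sort()[0][..., pad]

# same computation with scipy on an identically reflect-padded input,
# cropping the edges so only fully interior windows are compared
padded_np = np.pad(x.numpy(), [(0, 0), (pad, pad)], mode="reflect")
scipy_filtered = scipy.ndimage.median_filter(padded_np, size=(1, filter_width))[:, pad:-pad]

assert np.allclose(filtered.numpy(), scipy_filtered)
```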
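Finally, a minimal usage sketch for the new `word_timestamps` option added to `transcribe()`; the model name and audio path are placeholders, and the `words` entries follow the dicts assembled in `add_word_timestamps()`:

```python
# Usage sketch for word-level timestamps (model name and audio path are placeholders;
# any model returned by whisper.available_models() should work).
import whisper

model = whisper.load_model("tiny")
result = model.transcribe("audio.wav", word_timestamps=True)

for segment in result["segments"]:
    for word in segment["words"]:
        # each entry carries "word", "start", "end", and "probability"
        print(f'[{word["start"]:6.2f} --> {word["end"]:6.2f}] {word["word"]}')
```

The same behavior is exposed on the command line through the new flag, e.g. `whisper audio.wav --word_timestamps True`.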