[Whisper] Add word timestamps and confidence scores #201

Merged 8 commits on Jan 7, 2024
21 changes: 18 additions & 3 deletions whisper/README.md
@@ -2,7 +2,7 @@

Speech recognition with Whisper in MLX. Whisper is a set of open source speech
recognition models from OpenAI, ranging from 39 million to 1.5 billion
parameters[^1].
parameters.[^1]

### Setup

@@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/):
brew install ffmpeg
```

Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
Next, download the Whisper PyTorch checkpoint and convert the weights to the
MLX format. For example, to convert the `tiny` model use:

```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
@@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there.

Transcribe audio with:

```
```python
import whisper

text = whisper.transcribe(speech_file)["text"]
```

The `transcribe` function also supports word-level timestamps. You can generate
these with:

```python
output = whisper.transcribe(speech_file, word_timestamps=True)
print(output["segments"][0]["words"])
```
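Each entry in the word list carries the word text, its start and end time in
seconds, and a confidence score. A quick usage sketch (hypothetical audio
path; the keys shown are the ones asserted in this PR's test):

```python
output = whisper.transcribe("speech.wav", word_timestamps=True)
for word in output["segments"][0]["words"]:
    print(f'{word["start"]:5.2f}-{word["end"]:5.2f}s  {word["word"]}  (p={word["probability"]:.2f})')
```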

To see more transcription options use:

```
>>> help(whisper.transcribe)
```

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
4 changes: 4 additions & 0 deletions whisper/convert.py
@@ -199,6 +199,10 @@ def convert_rblock(model, rules):
mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)

if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())

return mlx_model
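For context, `alignment_heads` in the openai-whisper model is a sparse boolean buffer of shape `(n_text_layer, n_text_head)` marking the cross-attention heads used for timestamp alignment, which is why the conversion takes `.indices().T`. A standalone sketch with made-up dimensions:

```python
import torch

# Made-up dimensions for illustration only
n_text_layer, n_text_head = 4, 6
mask = torch.zeros(n_text_layer, n_text_head, dtype=torch.bool)
mask[2, 3] = mask[3, 1] = True  # pretend these two heads track time
alignment_heads = mask.to_sparse()

# .indices() is 2 x K ([layer_row; head_row]); .T gives K x 2 (layer, head)
# pairs, which is what set_alignment_heads receives above.
print(alignment_heads.indices().T.numpy())  # e.g. [[2 3] [3 1]]
```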


131 changes: 131 additions & 0 deletions whisper/test.py
@@ -311,6 +311,137 @@ def check_segment(seg, expected):
check_segment(result["segments"][5], expected_5)
check_segment(result["segments"][73], expected_73)

def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
TEST_AUDIO,
model_path=MLX_FP16_MODEL_PATH,
word_timestamps=True,
)

# result predicted with openai-whisper
expected_0 = [
{
"word": " Then",
"start": 0.0,
"end": 0.94,
"probability": 0.855542778968811,
},
{
"word": " the",
"start": 0.94,
"end": 1.12,
"probability": 0.6500106453895569,
},
{
"word": " good",
"start": 1.12,
"end": 1.32,
"probability": 0.5503873825073242,
},
{
"word": " soul",
"start": 1.32,
"end": 1.56,
"probability": 0.46757155656814575,
},
{
"word": " openly",
"start": 1.56,
"end": 2.0,
"probability": 0.9840946793556213,
},
{
"word": " sorted",
"start": 2.0,
"end": 2.38,
"probability": 0.24167272448539734,
},
{
"word": " the",
"start": 2.38,
"end": 2.58,
"probability": 0.9875414967536926,
},
{
"word": " boat",
"start": 2.58,
"end": 2.8,
"probability": 0.5856029391288757,
},
{
"word": " and",
"start": 2.8,
"end": 2.98,
"probability": 0.913351833820343,
},
{
"word": " she",
"start": 2.98,
"end": 3.1,
"probability": 0.9913808703422546,
},
{
"word": " had",
"start": 3.1,
"end": 3.32,
"probability": 0.9952940344810486,
},
{
"word": " buoyed",
"start": 3.32,
"end": 3.58,
"probability": 0.6411589980125427,
},
{
"word": " so",
"start": 3.58,
"end": 3.8,
"probability": 0.9682658314704895,
},
{
"word": " long",
"start": 3.8,
"end": 4.06,
"probability": 0.9953522682189941,
},
{
"word": " in",
"start": 4.06,
"end": 4.26,
"probability": 0.6745936870574951,
},
{
"word": " secret",
"start": 4.26,
"end": 4.56,
"probability": 0.9905064702033997,
},
{
"word": " and",
"start": 4.56,
"end": 4.9,
"probability": 0.856008768081665,
},
{
"word": " bravely",
"start": 4.9,
"end": 5.28,
"probability": 0.8477402329444885,
},
]

def check_words(words, expected_words):
for word, expected_word in zip(words, expected_words):
for k, v in expected_word.items():
if isinstance(v, float):
self.assertAlmostEqual(word[k], v, places=1)
else:
self.assertEqual(word[k], v)

        # Spot-check the words of the first segment
check_words(result["segments"][0]["words"], expected_0)


class TestAudio(unittest.TestCase):
def test_load(self):
2 changes: 1 addition & 1 deletion whisper/whisper/decoding.py
@@ -141,7 +141,7 @@ def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]

logits, self.kv_cache = self.model.decoder(
logits, self.kv_cache, _ = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits.astype(mx.float32)
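(The added `_` discards a new third return value from the decoder: the per-layer cross-attention weights. The word-timestamp code retrieves them via `forward_with_cross_qk` in `timing.py`; the sampling loop here has no use for them.)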
123 changes: 33 additions & 90 deletions whisper/whisper/timing.py
@@ -1,15 +1,13 @@
# Copyright © 2023 Apple Inc.

import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List

import mlx.core as mx
import numba
import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
@@ -18,7 +16,7 @@
from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
def median_filter(x: np.ndarray, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
@@ -33,22 +31,12 @@ def median_filter(x: torch.Tensor, filter_width: int):
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

    x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")

    # todo: more efficient version in mlx
    result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[
        ..., pad_width:-pad_width
    ]

if ndim <= 2:
result = result[0, 0]
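The NumPy path above is easy to sanity-check in isolation. A minimal sketch (standalone, not repo code) mirroring the pad-filter-crop sequence:

```python
import numpy as np
from scipy import signal

filter_width = 7
pad_width = filter_width // 2

x = np.arange(16, dtype=np.float32).reshape(1, 1, 16)  # (batch, heads, frames)
padded = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
out = signal.medfilt(padded, kernel_size=(1, 1, filter_width))[..., pad_width:-pad_width]

assert out.shape == x.shape  # same length after reflect-pad and crop
```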
@@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray):
return backtrace(trace)


def dtw_cuda(x, BLOCK_SIZE=1024):
from .triton_ops import dtw_kernel

M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
trace = torch.zeros_like(cost, dtype=torch.int32)

dtw_kernel[(1,)](
cost,
trace,
x_skew,
x_skew.stride(0),
cost.stride(0),
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE,
)

trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())


def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)

return dtw_cpu(x.double().cpu().numpy())
def dtw(x: np.ndarray) -> np.ndarray:
# todo: more efficient version in mlx
return dtw_cpu(x)
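Usage sketch for the new `dtw` (illustrative shapes, and assuming the package is importable as `whisper`): in `find_alignment` the input is the negated attention matrix, and the output is the minimum-cost monotonic path through the token-by-frame cost matrix as two index arrays.

```python
import numpy as np
from whisper.timing import dtw  # assuming the mlx whisper package is on the path

cost = np.random.rand(4, 10)             # rows: text tokens, cols: audio frames
text_indices, time_indices = dtw(cost)   # two equal-length, monotonic index arrays
```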


@dataclass
@@ -166,7 +113,7 @@ def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
mel: torch.Tensor,
mel: mx.array,
num_frames: int,
*,
medfilt_width: int = 7,
@@ -175,41 +122,36 @@
if len(text_tokens) == 0:
return []

    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
        token_probs = sampled_logits.softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
        text_token_probs = text_token_probs.tolist()

    for hook in hooks:
        hook.remove()

    tokens = mx.array(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    )

    logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
    # consider only the logits associated with predicting text
    sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
        sampled_logits.dtype
    )
    text_token_probs = mx.take_along_axis(
        token_probs, mx.array(text_tokens)[:, None], axis=1
    ).squeeze(1)
    text_token_probs = np.array(text_token_probs)

# heads * tokens * frames
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = mx.stack(
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
)
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = mx.softmax(weights * qk_scale, axis=-1)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
weights = median_filter(np.array(weights), medfilt_width)

matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
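For readers unfamiliar with `mx.take_along_axis`, the probability gather earlier in this hunk picks, for each text position, the probability assigned to the token that was actually emitted. A toy-shape sketch:

```python
import mlx.core as mx

token_probs = mx.softmax(mx.random.normal((3, 5)), axis=-1)  # 3 positions, vocab of 5
text_tokens = mx.array([2, 0, 4])                            # emitted token ids

# picks token_probs[i, text_tokens[i]] for each row i -> shape (3,)
picked = mx.take_along_axis(token_probs, text_tokens[:, None], axis=1).squeeze(1)
```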
@@ -281,7 +223,7 @@ def add_word_timestamps(
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
mel: mx.array,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
@@ -301,6 +243,7 @@
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2
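    # e.g. a raw median of 1.9 s becomes min(0.7, 1.9) = 0.7 s here, capping
    # max_duration at 1.4 s so outlier-length words get truncated further below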

# hack: truncate long words at sentence boundaries.