word-level timestamps in transcribe() #869
Changes from 14 commits
requirements.txt:

```diff
@@ -1,3 +1,4 @@
+numba
 numpy
 torch
 tqdm
```
tests/conftest.py (new file, @@ -0,0 +1,10 @@):

```python
import random as rand

import numpy
import pytest


@pytest.fixture
def random():
    rand.seed(42)
    numpy.random.seed(42)
```
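The fixture seeds both Python's and numpy's global RNGs, so any test that names `random` as a parameter gets reproducible randomness. A hypothetical example test (the name and body are illustrative, not part of the diff):

```python
import numpy as np


def test_draws_are_reproducible(random):
    # the `random` fixture has already seeded numpy's global RNG with 42
    first = np.random.random(5)

    np.random.seed(42)  # re-seeding reproduces the same draw
    second = np.random.random(5)

    assert np.allclose(first, second)
```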
tests/test_timing.py (new file, @@ -0,0 +1,82 @@):

```python
import pytest
import numpy as np
import scipy.ndimage
import torch

from whisper.timing import dtw_cpu, dtw_cuda, median_filter


sizes = [
    (10, 20), (32, 16), (123, 1500), (234, 189),
]
shapes = [
    (4, 5, 20, 345), (6, 12, 240, 512),
]


@pytest.mark.parametrize("N, M", sizes)
def test_dtw(N: int, M: int):
    # build a random monotonic step sequence, carve a cheap path through x along it,
    # then check that dtw_cpu recovers exactly that path
    steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
    np.random.shuffle(steps)
    x = np.random.random((N, M)).astype(np.float32)

    i, j, k = 0, 0, 0
    trace = []
    while True:
        x[i, j] -= 1
        trace.append((i, j))

        if k == len(steps):
            break

        if k + 1 < len(steps) and steps[k] != steps[k + 1]:
            i += 1
            j += 1
            k += 2
            continue

        if steps[k] == 0:
            i += 1
        if steps[k] == 1:
            j += 1
        k += 1

    trace = np.array(trace).T
    dtw_trace = dtw_cpu(x)

    assert np.allclose(trace, dtw_trace)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_cuda_equivalence(N: int, M: int):
    x_numpy = np.random.randn(N, M).astype(np.float32)
    x_cuda = torch.from_numpy(x_numpy).cuda()

    trace_cpu = dtw_cpu(x_numpy)
    trace_cuda = dtw_cuda(x_cuda)

    assert np.allclose(trace_cpu, trace_cuda)


@pytest.mark.parametrize("shape", shapes)
def test_median_filter(shape):
    x = torch.randn(*shape)

    for filter_width in [3, 5, 7, 13]:
        filtered = median_filter(x, filter_width)
        scipy_filtered = scipy.ndimage.median_filter(x, (1, 1, 1, filter_width), mode="nearest")

        assert np.allclose(filtered, scipy_filtered)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("shape", shapes)
def test_median_filter_equivalence(shape):
    x = torch.randn(*shape)

    for filter_width in [3, 5, 7, 13]:
        filtered_cpu = median_filter(x, filter_width)
        filtered_gpu = median_filter(x.cuda(), filter_width).cpu()

        assert np.allclose(filtered_cpu, filtered_gpu)
```

Review comment on test_median_filter:
  Reviewer: Question: Is there a licensing issue using …
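`requires_cuda` is a custom marker, so pytest has to be told about it, and machines without a GPU need those tests skipped or deselected. One common way to wire that up, sketched as a hypothetical conftest.py addition (not part of this diff):

```python
# conftest.py (hypothetical addition, not in this PR)
import pytest
import torch


def pytest_configure(config):
    # register the marker so `pytest --strict-markers` accepts it
    config.addinivalue_line("markers", "requires_cuda: test needs a CUDA-capable GPU")


def pytest_collection_modifyitems(config, items):
    if torch.cuda.is_available():
        return

    skip_cuda = pytest.mark.skip(reason="CUDA is not available")
    for item in items:
        if "requires_cuda" in item.keywords:
            item.add_marker(skip_cuda)
```

Alternatively, `pytest -m "not requires_cuda"` deselects these tests at the command line.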
whisper/timing.py (new file, @@ -0,0 +1,198 @@):

```python
from typing import List, TYPE_CHECKING

import numba
import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors"
    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"

    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode='replicate')
    if x.is_cuda:
        from .triton_ops import median_filter_cuda
        return median_filter_cuda(x, filter_width)

    # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
    return x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
```
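On CPU, the unfold/sort line computes a sliding-window median without any dedicated kernel: `unfold(-1, filter_width, 1)` turns the replicate-padded last dimension into overlapping windows, and element `filter_width // 2` of each sorted window is that window's median. A minimal sketch of the same trick on a toy 4D tensor (the values are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[1.0, 9.0, 2.0, 8.0, 3.0]]]])  # shape (1, 1, 1, 5)
w = 3

# replicate-pad the last dimension by w // 2 on each side: [1, 1, 9, 2, 8, 3, 3]
padded = F.pad(x, (w // 2, w // 2, 0, 0), mode="replicate")

# overlapping windows of width w, then the middle element of each sorted window
windows = padded.unfold(-1, w, 1)         # shape (1, 1, 1, 5, 3)
medians = windows.sort()[0][..., w // 2]  # shape (1, 1, 1, 5)

print(medians)  # tensor([[[[1., 2., 8., 3., 3.]]]])
```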
```python
@numba.jit
def backtrace(trace: np.ndarray):
    # trace codes: 0 = diagonal (i - 1, j - 1), 1 = up (i - 1, j), 2 = left (i, j - 1)
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T
```
```python
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)
```

Review thread on lines +87 to +88 (the `c1` and `c2` candidates):
  Reviewer: Maybe one of these two is not needed, as it doesn't really make sense to attribute the same timestamp to several tokens.
  Author: Yeah, it totally makes sense to force a token to have at least one timestamp, which is only about 20 milliseconds. I left this as-is to handle some failure cases like the repetition looping you mentioned; zero-length segments are removed in post-processing, and those were usually cases where generation got stuck in a repetition loop.
  Reviewer: We have seen words repeating with the same start and end timestamps; the segments were fine. I think we have to move the segment-level code where you handle duplication (repetition looping) to words too.
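The recurrence is easy to verify by hand on a tiny matrix: each interior cell adds its local cost to the cheapest of its three predecessors, and `backtrace` replays the recorded moves from the bottom-right corner. A quick sanity check (assumes the package is importable as `whisper`):

```python
import numpy as np
from whisper.timing import dtw_cpu

# a 2x3 cost matrix whose cheapest monotonic path is obvious
x = np.array(
    [
        [0.0, 1.0, 9.0],
        [9.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)

print(dtw_cpu(x))
# [[0 1 1]
#  [0 1 2]]   i.e. the path (0, 0) -> (1, 1) -> (1, 2)
```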
```python
def dtw_cuda(x, BLOCK_SIZE=1024):
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    cost = cost.cuda()
    trace = torch.zeros_like(cost, dtype=torch.int32)

    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
    return backtrace(trace.cpu().numpy())
```
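The pad/flatten/reshape at the top is a skew transform: each row i of `x` lands shifted right by i positions, so the anti-diagonals of the original matrix become columns of `x_skew`. That is what lets a single Triton program sweep one anti-diagonal (wavefront) at a time, since every cell on an anti-diagonal depends only on the two previous anti-diagonals. A small numpy illustration of the same index arithmetic (not the kernel itself):

```python
import numpy as np

M, N = 3, 4
x = np.arange(M * N, dtype=np.float32).reshape(M, N)

# pad each row with M + 1 inf entries, then re-slice: row i shifts right by i
padded = np.pad(x, ((0, 0), (0, M + 1)), constant_values=np.inf)
x_skew = padded.flatten()[: M * (N + M)].reshape(M, N + M)

print(x_skew)
# [[ 0.  1.  2.  3. inf inf inf]
#  [inf  4.  5.  6.  7. inf inf]
#  [inf inf  8.  9. 10. 11. inf]]
# column k of x_skew holds the anti-diagonal i + j == k of x
```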
```python
def dtw(x: torch.Tensor) -> np.ndarray:
    if x.is_cuda:
        return dtw_cuda(x)

    return dtw_cpu(x.double().cpu().numpy())
```
```python
def add_word_timestamps(
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    segments: List[dict],
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
):
    if len(segments) == 0:
        return

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.timestamp_begin,
            *[t for segment in segments for t in segment["tokens"]],
            tokenizer.timestamp_begin + mel.shape[-1] // 2,
            tokenizer.eot,
        ]
    ).to(model.device)

    with torch.no_grad():
        model(mel.unsqueeze(0), tokens.unsqueeze(0))

    for hook in hooks:
        hook.remove()

    weights = torch.cat(QKs[-6:])  # layers * heads * tokens * frames
    weights = weights[:, :, :, : num_frames // 2]
    weights = median_filter(weights, medfilt_width)
    weights = (weights * qk_scale).softmax(dim=-1)
    weights = weights / weights.norm(dim=-2, keepdim=True)
    matrix = weights.mean(axis=(0, 1)).neg()

    text_indices, time_indices = dtw(matrix)

    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND

    if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
        # These languages don't typically use spaces, so it is difficult to split words
        # without morpheme analysis. Here, we instead split words at any
        # position where the tokens are decoded as valid unicode points
        split_tokens = tokenizer.split_tokens_on_unicode
    else:
        split_tokens = tokenizer.split_tokens_on_spaces

    words, word_tokens = split_tokens(tokens[1:].tolist())

    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
    token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
    start_times = time_offset + jump_times[word_boundaries[:-1]]
    end_times = time_offset + jump_times[word_boundaries[1:]]

    for segment in segments:
        segment["words"] = []

    for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
        if word.startswith("<|") or word.strip() in ".,!?、。":  # TODO: expand
            continue

        segment = segments[token_sources[word_boundaries[i]]]
        segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))

    # adjust the segment-level timestamps based on the word-level timestamps
    for segment in segments:
        if len(segment["words"]) > 0:
            segment["start"] = segment["words"][0]["start"]
            segment["end"] = segment["words"][-1]["end"]
```

Review thread on the final timestamp token and `tokenizer.eot` appended to `tokens`:
  Reviewer: I don't think these last two tokens are needed to estimate word timestamps.
  Author: This was a trick to nudge the DTW path to go along these tokens so that the last few words have more accurate timestamps. It's still not perfect, but I settled with using …
  Reviewer: I understand that it's important to have the attention weights used to predict the timestamp token for the end of the speech segment, but those attention weights are the ones you get when the input is the last predicted (sub)word token. I think that's enough. When the input token is the final timestamp, the decoder is already focusing on predicting the next thing.
  Reviewer: We are anecdotally seeing that too in our tests; the timestamps lag by a word, for example. We have no empirical proof (very anecdotal).
  Reviewer: I don't see the same observations. I did notice in earlier commits that the next token after a comma could lag, as if the comma was taking up too much time. That seems to have become more accurate in later commits.
  Reviewer: Hmm, we could not come up with any empirical evidence either. Maybe it was the previous version.

Review thread on `weights = torch.cat(QKs[-6:])`:
  Reviewer: Why only consider (at most) the last 6 layers?
  Author: This was because the attention weights in the later layers were more indicative of the time alignment. I've updated this part, and now it uses a mask to select which layers and heads to find the alignment.
  Reviewer: OK, interesting. I have to check this masking trick.

Review thread on the per-word dicts:
  Reviewer: Would it be possible to add a confidence score based on the average log probability for each word?
  Author: Great point! Added in 5fa4356
  Reviewer: Awesome 👍
Review thread on the PR:
  Reviewer: Was your reason for not using the dtw library licensing concerns, or just speedup?
  Author: dtw-python is GPL, as mentioned here: #869 (comment)