word-level timestamps in transcribe() #869
timing.py
New file added in this commit (109 lines):

```python
from typing import List, TYPE_CHECKING

import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors"
    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"

    padded = F.pad(x, (0, 0, filter_width // 2, filter_width // 2), mode='replicate')
    slices = padded.unfold(-1, filter_width, 1)
    return slices.median(dim=-1).values


def add_word_timestamps(
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    segments: List[dict],
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
):
    if len(segments) == 0:
        return

    from dtw import dtw

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.timestamp_begin,
            *[t for segment in segments for t in segment["tokens"]],
            tokenizer.timestamp_begin + mel.shape[-1] // 2,
            tokenizer.eot,
        ]
    ).to(model.device)

    with torch.no_grad():
        model(mel.unsqueeze(0), tokens.unsqueeze(0))

    for hook in hooks:
        hook.remove()

    weights = torch.cat(QKs)  # layers * heads * tokens * frames
    weights = weights[:, :, :, : num_frames // 2]
    weights = median_filter(weights, medfilt_width)
    weights = (weights * qk_scale).softmax(dim=-1)

    w = weights / weights.norm(dim=-2, keepdim=True)
    matrix = w.mean(axis=(0, 1)).neg().double().cpu().numpy()

    alignment = dtw(matrix)

    jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
    jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND

    if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
        # These languages don't typically use spaces, so it is difficult to split words
        # without morpheme analysis. Here, we instead split words at any
        # position where the tokens are decoded as valid unicode points
        split_tokens = tokenizer.split_tokens_on_unicode
    else:
        split_tokens = tokenizer.split_tokens_on_spaces

    words, word_tokens = split_tokens(tokens[1:].tolist())

    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
    token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
    start_times = time_offset + jump_times[word_boundaries[:-1]]
    end_times = time_offset + jump_times[word_boundaries[1:]]

    for segment in segments:
        segment["words"] = []

    for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
        if word.startswith("<|") or word.strip() in ".,!?、。":
            continue

        segment = segments[token_sources[word_boundaries[i]]]
        segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))

    # adjust the segment-level timestamps based on the word-level timestamps
    for segment in segments:
        if len(segment["words"]) > 0:
            segment["start"] = segment["words"][0]["start"]
            segment["end"] = segment["words"][-1]["end"]
```
Would it be possible to add a confidence score based on the average log probability for each word?

Great point! Added in 5fa4356

Awesome 👍
I don't think these last two tokens are needed to estimate word timestamps.

This was a trick to nudge the DTW path to go along these tokens so that the last few words get more accurate timestamps. It's still not perfect, but I settled on using the <|no_timestamps|> token and no timestamp tokens in the recent commit.
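To make the two layouts concrete, here is a sketch contrasting the timestamp anchors used in this commit with the <|no_timestamps|> variant described above; the second layout is an approximation of the later approach, not code from the recent commit:

```python
text_tokens = [t for segment in segments for t in segment["tokens"]]

# this commit: bracket the text with timestamp tokens to nudge the DTW path
tokens_with_anchors = [
    *tokenizer.sot_sequence,
    tokenizer.timestamp_begin,                       # <|0.00|>
    *text_tokens,
    tokenizer.timestamp_begin + mel.shape[-1] // 2,  # timestamp at the window end
    tokenizer.eot,
]

# later approach, as described above: no timestamp tokens at all
tokens_no_timestamps = [
    *tokenizer.sot_sequence,
    tokenizer.no_timestamps,  # the <|no_timestamps|> token
    *text_tokens,
    tokenizer.eot,
]
```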
I understand that it's important to have the attention weights used to predict the timestamp token for the end of the speech segment, but those attention weights are the ones you get when the input is the last predicted (sub)word token, and I think that's enough. When the input token is the final timestamp, the decoder is already focusing on predicting the next thing.

I wonder if things are shifted by one, because that was a problem I saw with your notebook (the timestamps were assigned to the token before the one they should be).
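One way to probe that hypothesis (purely a sketch, not code from this PR) is to drop the leading row of the token axis before running DTW, so that the frames attended to while predicting a token are attributed to that token rather than to the input position before it, and then compare the resulting timings:

```python
import numpy as np
from dtw import dtw

TOKENS_PER_SECOND = 50  # as in whisper.audio: SAMPLE_RATE / (2 * HOP_LENGTH)

def token_jump_times(matrix: np.ndarray, shift: int = 0) -> np.ndarray:
    """`matrix` is the negative token-vs-frame similarity built above, shaped
    (tokens, frames). With shift=1, row i is attributed to token i + 1, which
    tests whether the alignment lags by one token."""
    alignment = dtw(matrix[shift:])
    jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
    return alignment.index2s[jumps] / TOKENS_PER_SECOND
```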
We are anecdotally seeing that too in our tests; the timestamps lag by a word, for example. We have no empirical proof (very anecdotal).

I don't see the same behavior. I did notice in earlier commits that the next token after a comma could lag, as if the comma was taking up too much time, but that seems to have become more accurate in later commits.

Hmm, we could not come up with any empirical evidence either. Maybe it was the previous version.