
word-level timestamps in `transcribe()` #869

Merged · 26 commits · Mar 6, 2023

Commits (the diff below shows changes from 1 commit, 46ea501):
8f9357f
word-level timestamps in `transcribe()`
jongwook Jan 20, 2023
46ea501
moving to `timing.py`
jongwook Jan 21, 2023
cfd2b81
Merge branch 'main' into word-level-timestamps
jongwook Jan 21, 2023
742d2f4
numba implementation for dtw, replacing dtw-python
jongwook Jan 22, 2023
fb12414
Merge branch 'main' into word-level-timestamps
jongwook Jan 22, 2023
80331c0
triton implementation for dtw
jongwook Jan 23, 2023
1d2ed66
add test for dtw implementations
jongwook Jan 23, 2023
b61e8f4
triton implementation of median_filter
jongwook Jan 24, 2023
54f2901
a simple word-level timestamps test
jongwook Jan 24, 2023
8ce6277
add scipy as dev dependency
jongwook Jan 24, 2023
812f446
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
cd5191f
installs an older version of Triton if CUDA < 11.4
jongwook Jan 24, 2023
f64d8bc
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
89133bd
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
d4f9399
fix broken merge
jongwook Jan 24, 2023
040aa04
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
8e2756b
loosen nvcc version match regex
jongwook Jan 25, 2023
6c431c4
find_alignment() function
jongwook Jan 25, 2023
ff6cbfd
Merge branch 'main' into word-level-timestamps
jongwook Feb 2, 2023
5fa4356
miscellaneous improvements
jongwook Feb 2, 2023
48537aa
skip median filtering when the input is too small
jongwook Feb 2, 2023
8eb29c3
Expose punctuation options in cli and transcribe() (#973)
ryanheise Feb 16, 2023
6ed4c11
Merge branch 'main' into word-level-timestamps
jongwook Mar 6, 2023
31cd418
fix merge error
jongwook Mar 6, 2023
145f325
fix merge error 2
jongwook Mar 6, 2023
2b079c4
annotating that word_timestamps is experimental
jongwook Mar 6, 2023
moving to `timing.py`
jongwook committed Jan 21, 2023
commit 46ea501da224c5e6b454e33dca99cc82326805b2
109 changes: 109 additions & 0 deletions whisper/timing.py
@@ -0,0 +1,109 @@
from typing import List, TYPE_CHECKING

import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors"
    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"

    # pad the last (frame) dimension so the filtered output keeps its shape
    padded = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode='replicate')
    slices = padded.unfold(-1, filter_width, 1)
    return slices.median(dim=-1).values
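
A quick sanity-check sketch of the filter above (assuming this branch's `whisper/timing.py` is importable; the shapes mirror the attention weights used below):

import torch
from whisper.timing import median_filter

x = torch.randn(4, 6, 10, 100)  # (layers, heads, tokens, frames), like the attention weights below
y = median_filter(x, filter_width=7)
assert y.shape == x.shape  # median filtering along the last axis preserves the shape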


def add_word_timestamps(
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    segments: List[dict],
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
):
    if len(segments) == 0:
        return

    from dtw import dtw

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.timestamp_begin,
            *[t for segment in segments for t in segment["tokens"]],
            tokenizer.timestamp_begin + mel.shape[-1] // 2,
            tokenizer.eot,

Reviewer:

I don't think these last two tokens are needed to estimate word timestamps.

Collaborator Author (jongwook):

This was a trick to nudge the DTW path to go along these tokens so that the last few words get more accurate timestamps. It's still not perfect, but I settled on using the <|notimestamps|> token and no timestamp tokens in the recent commit.

Reviewer:

I understand that it's important to have the attention weights used to predict the timestamp token for the end of the speech segment, but those attention weights are the ones you get when the input is the last predicted (sub)word token, and I think that's enough. When the input token is the final timestamp, the decoder is already focused on predicting the next thing.

I wonder if things are shifted by one, because that was a problem I saw with your notebook (the timestamps were assigned to the token before the one they should be).

Reviewer:

We are anecdotally seeing that too in our tests: the timestamps lag by a word, for example. We have no empirical proof (very anecdotal).

Contributor:

I haven't observed the same behavior.

I did notice in earlier commits that the next token after a comma could lag, as if the comma was taking up too much time. That seems to have become more accurate in later commits.

Reviewer:

Hmm, we could not come up with any empirical evidence either. Maybe it was the previous version.
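
For reference, a minimal sketch of the jump detection being discussed, using hypothetical DTW path values (`index1s` walks the token axis, `index2s` the frame axis, as in the code below):

import numpy as np

index1s = np.array([0, 0, 1, 1, 1, 2, 3, 3])  # token index along the DTW path
index2s = np.array([0, 1, 2, 3, 4, 5, 6, 7])  # frame index along the DTW path

# a "jump" marks the first frame at which the path advances to a new token
jumps = np.pad(np.diff(index1s), (1, 0), constant_values=1).astype(bool)
jump_times = index2s[jumps] / 50  # TOKENS_PER_SECOND is 50, i.e. 20 ms per frame
print(jump_times)  # [0.   0.04 0.1  0.12] -> start times of tokens 0..3

An off-by-one here would shift every start time to the previous token's boundary, which matches the symptom described above.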

        ]
    ).to(model.device)

    with torch.no_grad():
        model(mel.unsqueeze(0), tokens.unsqueeze(0))

    for hook in hooks:
        hook.remove()

    weights = torch.cat(QKs)  # layers * heads * tokens * frames
    weights = weights[:, :, :, : num_frames // 2]
    weights = median_filter(weights, medfilt_width)
    weights = (weights * qk_scale).softmax(dim=-1)

    w = weights / weights.norm(dim=-2, keepdim=True)
    matrix = w.mean(axis=(0, 1)).neg().double().cpu().numpy()

    alignment = dtw(matrix)

    jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
    jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND

    if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
        # These languages don't typically use spaces, so it is difficult to split words
        # without morpheme analysis. Here, we instead split words at any
        # position where the tokens are decoded as valid unicode points
        split_tokens = tokenizer.split_tokens_on_unicode
    else:
        split_tokens = tokenizer.split_tokens_on_spaces

    words, word_tokens = split_tokens(tokens[1:].tolist())

    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
    token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
    start_times = time_offset + jump_times[word_boundaries[:-1]]
    end_times = time_offset + jump_times[word_boundaries[1:]]

    for segment in segments:
        segment["words"] = []

    for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
        if word.startswith("<|") or word.strip() in ".,!?、。":
            continue

        segment = segments[token_sources[word_boundaries[i]]]
        segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))

Reviewer:

Would it be possible to add a confidence score based on the average log probability of each word? This could be a useful feature, available with very little additional computation.

Collaborator Author (jongwook):

Great point! Added in 5fa4356

Reviewer:

Awesome 👍
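
For reference, a minimal sketch of the idea (hypothetical helper; the actual change landed in 5fa4356): average a word's token log-probabilities and exponentiate to get a per-word confidence.

import numpy as np

def word_probability(token_logprobs):
    # geometric-mean probability over the word's tokens
    return float(np.exp(np.mean(token_logprobs)))

word_probability([-0.1, -0.3, -0.2])  # ~0.82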


    # adjust the segment-level timestamps based on the word-level timestamps
    for segment in segments:
        if len(segment["words"]) > 0:
            segment["start"] = segment["words"][0]["start"]
            segment["end"] = segment["words"][-1]["end"]
112 changes: 15 additions & 97 deletions whisper/transcribe.py
@@ -2,110 +2,24 @@
import os
import sys
import warnings
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+from typing import Optional, Tuple, Union, TYPE_CHECKING

import numpy as np
import torch
import tqdm

-from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
+from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, \
+    pad_or_trim
from .decoding import DecodingOptions, DecodingResult
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer
-from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
+from .timing import add_word_timestamps
+from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
+from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, \
+    write_vtt, write_srt

if TYPE_CHECKING:
    from .model import Whisper


-def add_word_timestamps(
-    model: "Whisper",
-    tokenizer: Tokenizer,
-    mel: torch.Tensor,
-    num_frames: int,
-    segments: List[dict],
-    *,
-    medfilt_width: int = 7,
-    qk_scale: float = 1.0,
-):
-    if len(segments) == 0:
-        return
-
-    from dtw import dtw
-    from scipy.ndimage import median_filter
-
-    # install hooks on the cross attention layers to retrieve the attention weights
-    QKs = [None] * model.dims.n_text_layer
-    hooks = [
-        block.cross_attn.register_forward_hook(
-            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
-        )
-        for i, block in enumerate(model.decoder.blocks)
-    ]
-
-    tokens = torch.tensor(
-        [
-            *tokenizer.sot_sequence,
-            tokenizer.timestamp_begin,
-            *[t for segment in segments for t in segment["tokens"]],
-            tokenizer.timestamp_begin + mel.shape[-1] // 2,
-            tokenizer.eot,
-        ]
-    ).to(model.device)
-
-    with torch.no_grad():
-        model(mel.unsqueeze(0), tokens.unsqueeze(0))
-
-    for hook in hooks:
-        hook.remove()
-
-    weights = torch.cat(QKs)  # layers * heads * tokens * frames
-    weights = weights[:, :, :, : num_frames // 2].cpu()
-    weights = median_filter(weights, (1, 1, 1, medfilt_width))
-    weights = torch.tensor(weights * qk_scale).softmax(dim=-1)
-
-    w = weights / weights.norm(dim=-2, keepdim=True)
-    matrix = w.mean(axis=(0, 1))
-
-    alignment = dtw(-matrix.double().numpy())
-
-    jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
-    jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND
-
-    if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
-        # These languages don't typically use spaces, so it is difficult to split words
-        # without morpheme analysis. Here, we instead split words at any
-        # position where the tokens are decoded as valid unicode points
-        split_tokens = tokenizer.split_tokens_on_unicode
-    else:
-        split_tokens = tokenizer.split_tokens_on_spaces
-
-    words, word_tokens = split_tokens(tokens[1:].tolist())
-
-    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
-    token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources)
-
-    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
-    start_times = time_offset + jump_times[word_boundaries[:-1]]
-    end_times = time_offset + jump_times[word_boundaries[1:]]
-
-    for segment in segments:
-        segment["words"] = []
-
-    for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
-        if word.startswith("<|") or word.strip() in ".,!?、。":
-            continue
-
-        segment = segments[token_sources[word_boundaries[i]]]
-        segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))
-
-    # adjust the segment-level timestamps based on the word-level timestamps
-    for segment in segments:
-        if len(segment["words"]) > 0:
-            segment["start"] = segment["words"][0]["start"]
-            segment["end"] = segment["words"][-1]["end"]


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
@@ -116,7 +30,7 @@ def transcribe(
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
-    word_level_timestamps: bool = False,
+    word_timestamps: bool = False,
    **decode_options,
):
    """
@@ -153,6 +67,10 @@
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

+    word_timestamps: bool
+        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
+        and include the timestamps for each word in each segment.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

@@ -332,7 +250,7 @@ def add_segment(
            # do not feed the prompt tokens if a high temperature was used
            prompt_reset_since = len(all_tokens)

-        if word_level_timestamps:
+        if word_timestamps:
            current_segments = all_segments[last_segment_index:]
            add_word_timestamps(
                model,
model,
@@ -342,7 +260,7 @@
                segments=current_segments,
            )
            word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
-            if len(word_end_timestamps) > 0:
+            if len(consecutive) > 0 and len(word_end_timestamps) > 0:
                seek_shift = (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                seek = previous_seek + round(seek_shift)

@@ -391,7 +309,7 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_level_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

args = parser.parse_args().__dict__
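
For reference, a minimal usage sketch of the new option (assuming this branch is installed; field names follow the code above):

import whisper

model = whisper.load_model("base")
result = model.transcribe("audio.mp3", word_timestamps=True)  # experimental, per the last commit
for segment in result["segments"]:
    for word in segment["words"]:
        print(f'{word["start"]:6.2f} -> {word["end"]:6.2f} {word["word"]}')

Or from the command line: `whisper audio.mp3 --word_timestamps True`.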