Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

word-level timestamps in transcribe() #869

Merged
merged 26 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8f9357f
word-level timestamps in `transcribe()`
jongwook Jan 20, 2023
46ea501
moving to `timing.py`
jongwook Jan 21, 2023
cfd2b81
Merge branch 'main' into word-level-timestamps
jongwook Jan 21, 2023
742d2f4
numba implementation for dtw, replacing dtw-python
jongwook Jan 22, 2023
fb12414
Merge branch 'main' into word-level-timestamps
jongwook Jan 22, 2023
80331c0
triton implementation for dtw
jongwook Jan 23, 2023
1d2ed66
add test for dtw implementations
jongwook Jan 23, 2023
b61e8f4
triton implementation of median_filter
jongwook Jan 24, 2023
54f2901
a simple word-level timestamps test
jongwook Jan 24, 2023
8ce6277
add scipy as dev dependency
jongwook Jan 24, 2023
812f446
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
cd5191f
installs an older version of Triton if CUDA < 11.4
jongwook Jan 24, 2023
f64d8bc
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
89133bd
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
d4f9399
fix broken merge
jongwook Jan 24, 2023
040aa04
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
8e2756b
loosen nvcc version match regex
jongwook Jan 25, 2023
6c431c4
find_alignment() function
jongwook Jan 25, 2023
ff6cbfd
Merge branch 'main' into word-level-timestamps
jongwook Feb 2, 2023
5fa4356
miscellaneous improvements
jongwook Feb 2, 2023
48537aa
skip median filtering when the input is too small
jongwook Feb 2, 2023
8eb29c3
Expose punctuation options in cli and transcribe() (#973)
ryanheise Feb 16, 2023
6ed4c11
Merge branch 'main' into word-level-timestamps
jongwook Mar 6, 2023
31cd418
fix merge error
jongwook Mar 6, 2023
145f325
fix merge error 2
jongwook Mar 6, 2023
2b079c4
annotating that word_timestamps is experimental
jongwook Mar 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
numba implementation for dtw, replacing dtw-python
  • Loading branch information
jongwook committed Jan 22, 2023
commit 742d2f4c88df1e2b64b2e68e742a6427a4000b4a
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
numba
numpy
torch
tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
dtw-python==1.3.0
54 changes: 47 additions & 7 deletions whisper/timing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, TYPE_CHECKING

import numba
import numpy as np
import torch
import torch.nn.functional as F
Expand All @@ -21,6 +22,48 @@ def median_filter(x: torch.Tensor, filter_width: int):
return slices.median(dim=-1).values


@numba.jit(nopython=True, parallel=True)
def dtw(x: np.ndarray):
N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)

i, j = 0, 0
cost[0, 0] = 0
for j in range(1, M + 1):
for i in range(1, N + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
Comment on lines +87 to +88

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe one of these two is not needed, as it doesn't really make sense to attribute a same timestamp to several tokens.
Well... not fully sure. Maybe it's useful when a lot of text has to be aligned with a small portion of audio (which can happen when Whisper "inner language model" is stuck).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it totally makes sense to force a token to have at least one timestamp, which is only about 20 milliseconds. I left this as-is, to handle some failure cases like repetition looping as you mentioned; in the post-processing zero-length segments are removed, and it was usually the case for the generation got stuck on repetition looping.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have seen words repeating with same end and start timestamps. Segments were fine. I think we have to move the segment code where you handle duplication ( repetition looping ) to words too.


if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2

cost[i, j] = x[i - 1, j - 1] + c
trace[i, j] = t

result = []
while i > 0 and j > 0:
result.append((i - 1, j - 1))

if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise ValueError("Unexpected P[i, j]")

result = np.array(result)
return result[::-1, :].T


def add_word_timestamps(
model: "Whisper",
tokenizer: Tokenizer,
Expand All @@ -34,8 +77,6 @@ def add_word_timestamps(
if len(segments) == 0:
return

from dtw import dtw

# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
Expand Down Expand Up @@ -67,12 +108,11 @@ def add_word_timestamps(
weights = (weights * qk_scale).softmax(dim=-1)

w = weights / weights.norm(dim=-2, keepdim=True)
matrix = w.mean(axis=(0, 1)).neg().double().cpu().numpy()

alignment = dtw(matrix)
matrix = w.mean(axis=(0, 1)).neg().cpu().numpy()
text_indices, time_indices = dtw(matrix)

jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND

if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
Expand Down