Skip to content

Commit

Permalink
numba implementation for dtw, replacing dtw-python
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed Jan 22, 2023
1 parent cfd2b81 commit e5e45ec
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
numba
numpy
torch
tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
dtw-python==1.3.0
54 changes: 47 additions & 7 deletions whisper/timing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, TYPE_CHECKING

import numba
import numpy as np
import torch
import torch.nn.functional as F
Expand All @@ -21,6 +22,48 @@ def median_filter(x: torch.Tensor, filter_width: int):
return slices.median(dim=-1).values


@numba.jit(nopython=True, parallel=True)
def dtw(x: np.ndarray) -> np.ndarray:
    """Compute the minimum-cost dynamic-time-warping path through ``x``.

    The path starts at ``x[0, 0]``, ends at ``x[N - 1, M - 1]`` and moves by
    one of three steps at a time: diagonal (both indices advance), vertical
    (only the row index advances) or horizontal (only the column index
    advances).

    Parameters
    ----------
    x : np.ndarray
        2-D cost matrix of shape ``(N, M)``. Both dimensions are expected to
        be at least 1; empty inputs are not handled specially.

    Returns
    -------
    np.ndarray
        Array of shape ``(2, L)``: the first row holds the row indices and
        the second row the column indices of the path, ordered from the
        start of the path to its end.

    Raises
    ------
    ValueError
        If the backtrace encounters a cell with no recorded step. This is a
        defensive check and is not expected to trigger for valid input.
    """
    N, M = x.shape
    # Row 0 and column 0 are an infinite-cost border so that every path is
    # forced through the origin cell (0, 0).
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    # trace[i, j] records which predecessor was chosen for cell (i, j):
    # 0 = diagonal, 1 = vertical (i - 1), 2 = horizontal (j - 1), -1 = unset.
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            r0 = cost[i - 1, j - 1]
            r1 = cost[i - 1, j]
            r2 = cost[i, j - 1]

            if r0 < r1 and r0 < r2:
                r, p = r0, 0
            elif r1 < r2:
                # Reaching here with r1 < r2 implies r1 <= r0, so r1 is a
                # minimum. Testing only against r2 keeps a tie r0 == r1 < r2
                # from falling through to the (larger) r2 branch.
                r, p = r1, 1
            else:
                # Here r2 <= r1 and r2 <= r0, so r2 is a minimum.
                r, p = r2, 2

            cost[i, j] = x[i - 1, j - 1] + r
            trace[i, j] = p

    # Walk back from the bottom-right cell; set explicitly rather than
    # relying on the values the loop variables happen to be left with.
    i, j = N, M
    result = []
    while i > 0 and j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    # The path was collected end-to-start; reverse it and split it into
    # separate row-index / column-index rows.
    return result[::-1, :].T


def add_word_timestamps(
model: "Whisper",
tokenizer: Tokenizer,
Expand All @@ -34,8 +77,6 @@ def add_word_timestamps(
if len(segments) == 0:
return

from dtw import dtw

# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
Expand Down Expand Up @@ -67,12 +108,11 @@ def add_word_timestamps(
weights = (weights * qk_scale).softmax(dim=-1)

w = weights / weights.norm(dim=-2, keepdim=True)
matrix = w.mean(axis=(0, 1)).neg().double().cpu().numpy()

alignment = dtw(matrix)
matrix = w.mean(axis=(0, 1)).neg().cpu().numpy()
text_indices, time_indices = dtw(matrix)

jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND

if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
Expand Down

0 comments on commit e5e45ec

Please sign in to comment.