word-level timestamps in transcribe() #869

Merged · 26 commits · Mar 6, 2023
Changes from 1 commit

Commits (26)
8f9357f
word-level timestamps in `transcribe()`
jongwook Jan 20, 2023
46ea501
moving to `timing.py`
jongwook Jan 21, 2023
cfd2b81
Merge branch 'main' into word-level-timestamps
jongwook Jan 21, 2023
742d2f4
numba implementation for dtw, replacing dtw-python
jongwook Jan 22, 2023
fb12414
Merge branch 'main' into word-level-timestamps
jongwook Jan 22, 2023
80331c0
triton implementation for dtw
jongwook Jan 23, 2023
1d2ed66
add test for dtw implementations
jongwook Jan 23, 2023
b61e8f4
triton implementation of median_filter
jongwook Jan 24, 2023
54f2901
a simple word-level timestamps test
jongwook Jan 24, 2023
8ce6277
add scipy as dev dependency
jongwook Jan 24, 2023
812f446
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
cd5191f
installs an older version of Triton if CUDA < 11.4
jongwook Jan 24, 2023
f64d8bc
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
89133bd
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
d4f9399
fix broken merge
jongwook Jan 24, 2023
040aa04
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
8e2756b
loosen nvcc version match regex
jongwook Jan 25, 2023
6c431c4
find_alignment() function
jongwook Jan 25, 2023
ff6cbfd
Merge branch 'main' into word-level-timestamps
jongwook Feb 2, 2023
5fa4356
miscellaneous improvements
jongwook Feb 2, 2023
48537aa
skip median filtering when the input is too small
jongwook Feb 2, 2023
8eb29c3
Expose punctuation options in cli and transcribe() (#973)
ryanheise Feb 16, 2023
6ed4c11
Merge branch 'main' into word-level-timestamps
jongwook Mar 6, 2023
31cd418
fix merge error
jongwook Mar 6, 2023
145f325
fix merge error 2
jongwook Mar 6, 2023
2b079c4
annotating that word_timestamps is experimental
jongwook Mar 6, 2023
triton implementation for dtw
jongwook committed Jan 23, 2023
commit 80331c0c67535193ce7b4d6907c726b6456b66c1
7 changes: 6 additions & 1 deletion setup.py
@@ -1,4 +1,5 @@
import os
+import sys

import pkg_resources
from setuptools import setup, find_packages
@@ -9,6 +10,10 @@ def read_version(fname="whisper/version.py"):
    return locals()["__version__"]


+requirements = []
+if sys.platform.startswith("linux"):
+    requirements.append("triton>=2.0.0.dev20221202")
+
setup(
    name="openai-whisper",
    py_modules=["whisper"],
@@ -22,7 +27,7 @@ def read_version(fname="whisper/version.py"):
    url="https://github.com/openai/whisper",
    license="MIT",
    packages=find_packages(exclude=["tests*"]),
-    install_requires=[
+    install_requires=requirements + [
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
87 changes: 67 additions & 20 deletions whisper/timing.py
@@ -22,13 +22,37 @@ def median_filter(x: torch.Tensor, filter_width: int):
    return slices.median(dim=-1).values


+@numba.jit
+def backtrace(trace: np.ndarray):
+    i = trace.shape[0] - 1
+    j = trace.shape[1] - 1
+    trace[0, :] = 2
+    trace[:, 0] = 1
+
+    result = []
+    while i > 0 or j > 0:
+        result.append((i - 1, j - 1))
+
+        if trace[i, j] == 0:
+            i -= 1
+            j -= 1
+        elif trace[i, j] == 1:
+            i -= 1
+        elif trace[i, j] == 2:
+            j -= 1
+        else:
+            raise ValueError("Unexpected trace[i, j]")
+
+    result = np.array(result)
+    return result[::-1, :].T
+
+
@numba.jit(nopython=True, parallel=True)
-def dtw(x: np.ndarray):
+def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

-    i, j = 0, 0
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
@@ -46,22 +70,45 @@ def dtw(x: np.ndarray):
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

-    result = []
-    while i > 0 and j > 0:
-        result.append((i - 1, j - 1))
-
-        if trace[i, j] == 0:
-            i -= 1
-            j -= 1
-        elif trace[i, j] == 1:
-            i -= 1
-        elif trace[i, j] == 2:
-            j -= 1
-        else:
-            raise ValueError("Unexpected P[i, j]")
-
-    result = np.array(result)
-    return result[::-1, :].T
+    return backtrace(trace)
+
+
+def dtw_cuda(x, BLOCK_SIZE=1024):
+    from .triton_ops import dtw_kernel
+
+    M, N = x.shape
+    # assert M < N, f"{M=} should be smaller than {N=}"
+    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
+
+    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    x_skew = x_skew.T.contiguous()
+    cost = torch.ones(N + M + 2, M + 2) * np.inf
+    cost[0, 0] = 0
+    cost = cost.cuda()
+    trace = torch.zeros_like(cost, dtype=torch.int32)
+
+    dtw_kernel[(1,)](
+        cost,
+        trace,
+        x_skew,
+        x_skew.stride(0),
+        cost.stride(0),
+        trace.stride(0),
+        N,
+        M,
+        BLOCK_SIZE=BLOCK_SIZE
+    )
+
+    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
+
+    return backtrace(trace.cpu().numpy())
+
+
+def dtw(x: torch.Tensor) -> np.ndarray:
+    if x.is_cuda:
+        return dtw_cuda(x)
+
+    return dtw_cpu(x.double().cpu().numpy())
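
# --- Illustrative usage (editor's sketch, not part of this commit) ---
# dtw() consumes a cost matrix, e.g. the negated token-vs-frame attention
# matrix built in add_word_timestamps() below, and returns the cheapest
# monotonic alignment path as a (2, path_length) array:
#
#     matrix = torch.rand(10, 100)                     # (tokens, frames)
#     text_indices, time_indices = dtw(matrix)         # numba CPU path
#     text_indices, time_indices = dtw(matrix.cuda())  # Triton CUDA path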


def add_word_timestamps(
@@ -102,13 +149,13 @@ def add_word_timestamps(
    for hook in hooks:
        hook.remove()

-    weights = torch.cat(QKs)  # layers * heads * tokens * frames
+    weights = torch.cat(QKs[-6:])  # layers * heads * tokens * frames

Review comment:
Why consider only (at most) the last 6 layers?

jongwook (Collaborator, Author):
This was because the attention weights in the later layers were more indicative of the time alignment. I've updated this part, and now it uses a mask to select which layers and heads to find the alignment.

Reply:
Ok, interesting. I have to check this masking trick. I don't understand why later layers are more indicative. Is it an intuition that I am missing, or some empirical results you got from experiments?

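A minimal sketch of that masking idea (editor's illustration only; `alignment_heads`, `n_layers`, and `n_heads` are assumed names, not the code in this PR):

    # Keep a boolean (layers, heads) mask and stack only the selected
    # cross-attention matrices, instead of slicing a fixed number of layers.
    # Assumes each QKs[l] has shape (heads, tokens, frames).
    alignment_heads = torch.zeros(n_layers, n_heads, dtype=torch.bool)
    alignment_heads[-6:] = True  # e.g. reproduce the "last 6 layers" choice
    weights = torch.stack(
        [QKs[l][h] for l, h in alignment_heads.nonzero().tolist()]
    )  # (selected_heads, tokens, frames)
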
    weights = weights[:, :, :, : num_frames // 2]
    weights = median_filter(weights, medfilt_width)
    weights = (weights * qk_scale).softmax(dim=-1)
-    weights = weights / weights.norm(dim=-2, keepdim=True)
-    matrix = weights.mean(axis=(0, 1)).neg()

+    w = weights / weights.norm(dim=-2, keepdim=True)
+    matrix = w.mean(axis=(0, 1)).neg().cpu().numpy()
    text_indices, time_indices = dtw(matrix)

    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
@@ -136,7 +183,7 @@ def add_word_timestamps(
segment["words"] = []

for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
if word.startswith("<|") or word.strip() in ".,!?、。":
if word.startswith("<|") or word.strip() in ".,!?、。": # TODO: expand
continue

segment = segments[token_sources[word_boundaries[i]]]
Expand Down
33 changes: 33 additions & 0 deletions whisper/triton_ops.py
@@ -0,0 +1,33 @@
try:
    import triton
    import triton.language as tl
except ImportError:
    raise RuntimeError("triton import failed; try `pip install triton`")


@triton.jit
def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
Contributor comment:
Triton JIT requires gcc/clang, python-dev, and cuda-dev at runtime. Please consider some lighter-weight alternatives.

Reply:
+1.

jongwook (Collaborator, Author):
Thanks for the heads-up; I made it fall back to the PyTorch and numba implementations if Triton fails with RuntimeError or subprocess.CalledProcessError. I haven't tested this in a non-dev environment, so please feel free to ping me if the fallback does not work for any reason.

Reply:
+1. Thank you.
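
A sketch of the fallback shape described above (editor's illustration; the actual fallback lives in a later commit of this PR):

    import subprocess

    def dtw(x):
        if x.is_cuda:
            try:
                return dtw_cuda(x)
            except (RuntimeError, subprocess.CalledProcessError):
                # Triton needs a compiler toolchain at runtime; degrade
                # gracefully to the numba/CPU implementation instead.
                pass
        return dtw_cpu(x.double().cpu().numpy())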

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M

    for k in range(1, N + M + 1):  # k = i + j
        tl.debug_barrier()

        p0 = cost + (k - 1) * cost_stride
        p1 = cost + k * cost_stride
        p2 = cost + k * cost_stride + 1

        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)

        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)

        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
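
The kernel advances one anti-diagonal k = i + j per iteration: every cell on a diagonal depends only on the two previous diagonals, so a whole diagonal can be updated in parallel. A small NumPy sketch of the skewed layout that dtw_cuda feeds the kernel (editor's illustration, not part of the commit):

    import numpy as np

    M, N = 3, 5
    x = np.arange(M * N, dtype=np.float32).reshape(M, N)
    # Same transform as dtw_cuda: pad each row, then re-wrap so that row i
    # is shifted right by i places.
    x_skew = (
        np.pad(x, ((0, 0), (0, M + 1)), constant_values=np.inf)
        .flatten()[: M * (N + M)]
        .reshape(M, N + M)
    )
    # Now x_skew[i, i + j] == x[i, j]: column k of x_skew is the
    # anti-diagonal i + j == k of x, so after the transpose in dtw_cuda
    # each kernel iteration k reads one contiguous row.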