word-level timestamps in transcribe() #869

Merged
merged 26 commits on Mar 6, 2023
Changes from 1 commit

Commits (26)
8f9357f
word-level timestamps in `transcribe()`
jongwook Jan 20, 2023
46ea501
moving to `timing.py`
jongwook Jan 21, 2023
cfd2b81
Merge branch 'main' into word-level-timestamps
jongwook Jan 21, 2023
742d2f4
numba implementation for dtw, replacing dtw-python
jongwook Jan 22, 2023
fb12414
Merge branch 'main' into word-level-timestamps
jongwook Jan 22, 2023
80331c0
triton implementation for dtw
jongwook Jan 23, 2023
1d2ed66
add test for dtw implementations
jongwook Jan 23, 2023
b61e8f4
triton implementation of median_filter
jongwook Jan 24, 2023
54f2901
a simple word-level timestamps test
jongwook Jan 24, 2023
8ce6277
add scipy as dev dependency
jongwook Jan 24, 2023
812f446
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
cd5191f
installs an older version of Triton if CUDA < 11.4
jongwook Jan 24, 2023
f64d8bc
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
89133bd
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
d4f9399
fix broken merge
jongwook Jan 24, 2023
040aa04
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
8e2756b
loosen nvcc version match regex
jongwook Jan 25, 2023
6c431c4
find_alignment() function
jongwook Jan 25, 2023
ff6cbfd
Merge branch 'main' into word-level-timestamps
jongwook Feb 2, 2023
5fa4356
miscellaneous improvements
jongwook Feb 2, 2023
48537aa
skip median filtering when the input is too small
jongwook Feb 2, 2023
8eb29c3
Expose punctuation options in cli and transcribe() (#973)
ryanheise Feb 16, 2023
6ed4c11
Merge branch 'main' into word-level-timestamps
jongwook Mar 6, 2023
31cd418
fix merge error
jongwook Mar 6, 2023
145f325
fix merge error 2
jongwook Mar 6, 2023
2b079c4
annotating that word_timestamps is experimental
jongwook Mar 6, 2023
miscellaneous improvements
jongwook committed Feb 2, 2023
commit 5fa43566f00a3e337f7fb481a2b962118453a96b
4 changes: 4 additions & 0 deletions tests/conftest.py
@@ -4,6 +4,10 @@
import pytest


def pytest_configure(config):
config.addinivalue_line("markers", "requires_cuda")


@pytest.fixture
def random():
rand.seed(42)
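The conftest.py change registers a `requires_cuda` marker so GPU-only tests (such as those covering the Triton kernels added in this PR) can be tagged and deselected on CPU-only machines. A minimal sketch of how such a marker is typically consumed; the test body and skip logic here are illustrative, not part of this diff:

import pytest
import torch

# Hypothetical GPU-only test using the marker registered in conftest.py.
# Deselect on CPU-only machines with: pytest -m "not requires_cuda"
@pytest.mark.requires_cuda
def test_kernel_on_gpu():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    x = torch.rand(16, 128, device="cuda")
    assert x.is_cuda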
9 changes: 7 additions & 2 deletions tests/test_timing.py
@@ -10,7 +10,7 @@
(10, 20), (32, 16), (123, 1500), (234, 189),
]
shapes = [
(4, 5, 20, 345), (6, 12, 240, 512),
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
]


@@ -65,7 +65,12 @@ def test_median_filter(shape):

for filter_width in [3, 5, 7, 13]:
filtered = median_filter(x, filter_width)
scipy_filtered = scipy.ndimage.median_filter(x, (1, 1, 1, filter_width), mode="nearest")

# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
pad_width = filter_width // 2
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]

assert np.allclose(filtered, scipy_filtered)

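The np.pad dance in the updated test exists because scipy.ndimage.median_filter applies its own edge modes, none of which exactly matches the reflect padding that whisper's median_filter now uses. A self-contained sketch of the same equivalence check, with an arbitrary shape and filter width:

import numpy as np
import scipy.ndimage
import torch
import torch.nn.functional as F

x = np.random.randn(4, 5, 345).astype(np.float32)
filter_width = 7
pad = filter_width // 2

# reflect-pad with numpy first, so scipy never has to invent edge values itself
padded = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad, pad)], mode="reflect")
reference = scipy.ndimage.median_filter(padded, [1] * (x.ndim - 1) + [filter_width])
reference = reference[..., pad:-pad]

# the torch equivalent: reflect-pad, then take the middle of each sorted window
t = F.pad(torch.from_numpy(x), (pad, pad), mode="reflect")
result = t.unfold(-1, filter_width, 1).sort()[0][..., pad]

assert np.allclose(result.numpy(), reference)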
7 changes: 4 additions & 3 deletions tests/test_transcribe.py
@@ -25,9 +25,10 @@ def test_transcribe(model_name: str):
for segment in result["segments"]:
for timing in segment["words"]:
assert timing["start"] < timing["end"]
if timing["word"].strip() == "Americans":
assert timing["start"] <= 1.75
assert timing["end"] >= 2.05
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
print(timing)
timing_checked = True

assert timing_checked
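For context, the behavior this test exercises is exposed through a new transcribe() flag (annotated as experimental in a later commit of this PR). A usage sketch; the audio path is a placeholder:

import whisper

model = whisper.load_model("tiny.en")
result = model.transcribe("audio.wav", word_timestamps=True)

for segment in result["segments"]:
    for word in segment["words"]:
        # each word entry carries text, start/end in seconds, and a probability
        print(f"{word['start']:6.2f} -> {word['end']:6.2f} {word['word']} (p={word['probability']:.2f})")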
22 changes: 22 additions & 0 deletions whisper/__init__.py
@@ -29,6 +29,23 @@
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
}



def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
@@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow

if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

@@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])

if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return model.to(device)
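Each _ALIGNMENT_HEADS entry is a compact dump of a per-(layer, head) boolean mask: a base85 string wrapping a gzip stream of a flat boolean array, exactly as set_alignment_heads() in model.py below unpacks it. A sketch of decoding one by hand, assuming tiny.en's 4 text layers × 6 heads:

import base64
import gzip
import numpy as np

dump = b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00"  # the "tiny.en" entry above
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool)
mask = array.reshape(4, 6)  # (n_text_layer, n_text_head) for tiny.en
print(f"{mask.sum()} alignment heads out of {mask.size}")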
12 changes: 11 additions & 1 deletion whisper/model.py
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional

import gzip
import base64
import numpy as np
import torch
import torch.nn.functional as F
@@ -213,6 +214,15 @@ def __init__(self, dims: ModelDimensions):
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
all_heads[self.dims.n_text_layer // 2:] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
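Producing a dump like the _ALIGNMENT_HEADS entries is presumably just the inverse pipeline (gzip, then base85); this encoder is inferred from the decoder in set_alignment_heads() above, not code from the repository:

import base64
import gzip
import numpy as np

def encode_alignment_heads(mask: np.ndarray) -> bytes:
    # inverse of set_alignment_heads(): flatten -> gzip -> base85
    return base64.b85encode(gzip.compress(mask.astype(bool).tobytes()))

# round trip against the default "last half of the layers" mask from Whisper.__init__
n_layer, n_head = 4, 6
mask = np.zeros((n_layer, n_head), dtype=bool)
mask[n_layer // 2:] = True
dump = encode_alignment_heads(mask)
decoded = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool)
assert (decoded.reshape(n_layer, n_head) == mask).all()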
179 changes: 128 additions & 51 deletions whisper/timing.py
@@ -1,5 +1,9 @@
from typing import List, TYPE_CHECKING
import time
import subprocess
import warnings
from dataclasses import dataclass
from typing import List, TYPE_CHECKING

import numba
import numpy as np
import torch
@@ -14,17 +18,31 @@

def median_filter(x: torch.Tensor, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors"
if (ndim := x.ndim) <= 2: # `F.pad` does not support 1D or 2D inputs for reflect padding
x = x[None, None, :]

assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"

x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode='replicate')
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
from .triton_ops import median_filter_cuda
return median_filter_cuda(x, filter_width)
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower median kernel implementation..."
)

if result is None:
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
return x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if ndim <= 2:
result = result[0, 0]

return result

@numba.jit
def backtrace(trace: np.ndarray):
@@ -108,17 +126,24 @@ def dtw_cuda(x, BLOCK_SIZE=1024):

def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
return dtw_cuda(x)
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)

return dtw_cpu(x.double().cpu().numpy())


@dataclass
class Alignment:
words: List[str]
word_tokens: List[List[int]]
start_times: np.ndarray
end_times: np.ndarray
class WordTiming:
word: str
tokens: List[int]
start: float
end: float
probability: float


def find_alignment(
@@ -128,16 +153,14 @@ def find_alignment(
mel: torch.Tensor,
num_frames: int,
*,
max_qk_layers: int = 6,
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> Alignment:
) -> List[WordTiming]:
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.timestamp_begin,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.timestamp_begin + num_frames // 2,
tokenizer.eot,
]
).to(model.device)
@@ -146,78 +169,132 @@ def find_alignment(
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]

with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()

for hook in hooks:
hook.remove()

weights = torch.cat(QKs[-max_qk_layers:]) # layers * heads * tokens * frames
weights = weights[:, :, :, : num_frames // 2]
weights = median_filter(weights, medfilt_width)
# heads * tokens * frames
weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
@saunair commented on Feb 23, 2023:
Question: Why get weights only for the first num_frames/2 frames?

weights = (weights * qk_scale).softmax(dim=-1)
weights = weights / weights.norm(dim=-2, keepdim=True)
matrix = weights.mean(axis=(0, 1)).neg()
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)

matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence):-1]
text_indices, time_indices = dtw(-matrix)

text_indices, time_indices = dtw(matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND

if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
words, word_tokens = tokenizer.split_tokens_on_unicode(tokens[1:].tolist())
else:
words, word_tokens = tokenizer.split_tokens_on_spaces(tokens[1:].tolist())

word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]

# hack: ensure the first and second word is not longer than twice the median word duration.
# a better segmentation algorithm based on VAD should be able to replace this.
word_durations = end_times - start_times
word_durations = word_durations[word_durations.nonzero()]
if len(word_durations) > 0:
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
start_times[0] = max(0, end_times[0] - max_duration)

return [
WordTiming(word, tokens, start, end, probability)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]


return Alignment(words, word_tokens, start_times, end_times)
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1
while i >= 0:
previous = alignment[i]
following = alignment[j]
if previous.word.startswith(" ") and previous.word.strip() in prepended:
# prepend it to the following word
following.word = previous.word + following.word
following.tokens = previous.tokens + following.tokens
previous.word = ""
previous.tokens = []
else:
j = i
i -= 1

# merge appended punctuations
i = 0
j = 1
while j < len(alignment):
previous = alignment[i]
following = alignment[j]
if not previous.word.endswith(" ") and following.word in appended:
# append it to the previous word
previous.word = previous.word + following.word
previous.tokens = previous.tokens + following.tokens
following.word = ""
following.tokens = []
else:
i = j
j += 1


def add_word_timestamps(
*,
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"\'“¿([{-",
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
**hyperparams,
):
if len(segments) == 0:
return

text_tokens = [t for segment in segments for t in segment["tokens"]]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)

time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
alignment.start_times = time_offset + alignment.start_times
alignment.end_times = time_offset + alignment.end_times

token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
token_sources: List[int] = [None] * len(tokenizer.sot_sequence) + list(token_sources)

for segment in segments:
segment["words"] = []

word_boundaries = np.pad(np.cumsum([len(t) for t in alignment.word_tokens]), (1, 0))
for i, (word, start, end) in enumerate(zip(alignment.words, alignment.start_times, alignment.end_times)):
if word.startswith("<|") or word.strip() in ".,!?、。": # TODO: expand
continue

segment = segments[token_sources[word_boundaries[i]]]
segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
for i, timing in enumerate(alignment):
if timing.word:
segment = segments[token_sources[word_boundaries[i]]]
start = round(time_offset + timing.start, 2)
end = round(time_offset + timing.end, 2)
segment["words"].append(
dict(word=timing.word, start=start, end=end, probability=timing.probability)
)

# adjust the segment-level timestamps based on the word-level timestamps
for segment in segments:
if len(segment["words"]) > 0:
segment["start"] = segment["words"][0]["start"]
segment["end"] = segment["words"][-1]["end"]
if len(words := segment["words"]) > 0:
# adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"]