Skip to content

Commit

Permalink
miscellaneous improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed Feb 2, 2023
1 parent ff6cbfd commit 5fa4356
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 116 deletions.
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import pytest


def pytest_configure(config):
config.addinivalue_line("markers", "requires_cuda")


@pytest.fixture
def random():
rand.seed(42)
Expand Down
9 changes: 7 additions & 2 deletions tests/test_timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
(10, 20), (32, 16), (123, 1500), (234, 189),
]
shapes = [
(4, 5, 20, 345), (6, 12, 240, 512),
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
]


Expand Down Expand Up @@ -65,7 +65,12 @@ def test_median_filter(shape):

for filter_width in [3, 5, 7, 13]:
filtered = median_filter(x, filter_width)
scipy_filtered = scipy.ndimage.median_filter(x, (1, 1, 1, filter_width), mode="nearest")

# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
pad_width = filter_width // 2
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]

assert np.allclose(filtered, scipy_filtered)

Expand Down
7 changes: 4 additions & 3 deletions tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ def test_transcribe(model_name: str):
for segment in result["segments"]:
for timing in segment["words"]:
assert timing["start"] < timing["end"]
if timing["word"].strip() == "Americans":
assert timing["start"] <= 1.75
assert timing["end"] >= 2.05
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
print(timing)
timing_checked = True

assert timing_checked
22 changes: 22 additions & 0 deletions whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
}



def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
Expand Down Expand Up @@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow

if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

Expand All @@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])

if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return model.to(device)
12 changes: 11 additions & 1 deletion whisper/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional

import gzip
import base64
import numpy as np
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -213,6 +214,15 @@ def __init__(self, dims: ModelDimensions):
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
all_heads[self.dims.n_text_layer // 2:] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
Expand Down
179 changes: 128 additions & 51 deletions whisper/timing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import List, TYPE_CHECKING
import time
import subprocess
import warnings
from dataclasses import dataclass
from typing import List, TYPE_CHECKING

import numba
import numpy as np
import torch
Expand All @@ -14,17 +18,31 @@

def median_filter(x: torch.Tensor, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors"
if (ndim := x.ndim) <= 2: # `F.pad` does not support 1D or 2D inputs for reflect padding
x = x[None, None, :]

assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"

x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode='replicate')
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
from .triton_ops import median_filter_cuda
return median_filter_cuda(x, filter_width)
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower median kernel implementation..."
)

if result is None:
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
return x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if ndim <= 2:
result = result[0, 0]

return result

@numba.jit
def backtrace(trace: np.ndarray):
Expand Down Expand Up @@ -108,17 +126,24 @@ def dtw_cuda(x, BLOCK_SIZE=1024):

def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
return dtw_cuda(x)
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)

return dtw_cpu(x.double().cpu().numpy())


@dataclass
class Alignment:
words: List[str]
word_tokens: List[List[int]]
start_times: np.ndarray
end_times: np.ndarray
class WordTiming:
word: str
tokens: List[int]
start: float
end: float
probability: float


def find_alignment(
Expand All @@ -128,16 +153,14 @@ def find_alignment(
mel: torch.Tensor,
num_frames: int,
*,
max_qk_layers: int = 6,
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> Alignment:
) -> List[WordTiming]:
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.timestamp_begin,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.timestamp_begin + num_frames // 2,
tokenizer.eot,
]
).to(model.device)
Expand All @@ -146,78 +169,132 @@ def find_alignment(
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]

with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()

for hook in hooks:
hook.remove()

weights = torch.cat(QKs[-max_qk_layers:]) # layers * heads * tokens * frames
weights = weights[:, :, :, : num_frames // 2]
weights = median_filter(weights, medfilt_width)
# heads * tokens * frames
weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
weights = weights / weights.norm(dim=-2, keepdim=True)
matrix = weights.mean(axis=(0, 1)).neg()
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)

matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence):-1]
text_indices, time_indices = dtw(-matrix)

text_indices, time_indices = dtw(matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND

if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
words, word_tokens = tokenizer.split_tokens_on_unicode(tokens[1:].tolist())
else:
words, word_tokens = tokenizer.split_tokens_on_spaces(tokens[1:].tolist())

word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]

# hack: ensure the first and second word is not longer than twice the median word duration.
# a better segmentation algorithm based on VAD should be able to replace this.
word_durations = end_times - start_times
word_durations = word_durations[word_durations.nonzero()]
if len(word_durations) > 0:
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
start_times[0] = max(0, end_times[0] - max_duration)

return [
WordTiming(word, tokens, start, end, probability)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]


return Alignment(words, word_tokens, start_times, end_times)
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1
while i >= 0:
previous = alignment[i]
following = alignment[j]
if previous.word.startswith(" ") and previous.word.strip() in prepended:
# prepend it to the following word
following.word = previous.word + following.word
following.tokens = previous.tokens + following.tokens
previous.word = ""
previous.tokens = []
else:
j = i
i -= 1

# merge appended punctuations
i = 0
j = 1
while j < len(alignment):
previous = alignment[i]
following = alignment[j]
if not previous.word.endswith(" ") and following.word in appended:
# append it to the previous word
previous.word = previous.word + following.word
previous.tokens = previous.tokens + following.tokens
following.word = ""
following.tokens = []
else:
i = j
j += 1


def add_word_timestamps(
*,
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"\'“¿([{-",
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
**hyperparams,
):
if len(segments) == 0:
return

text_tokens = [t for segment in segments for t in segment["tokens"]]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)

time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
alignment.start_times = time_offset + alignment.start_times
alignment.end_times = time_offset + alignment.end_times

token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
token_sources: List[int] = [None] * len(tokenizer.sot_sequence) + list(token_sources)

for segment in segments:
segment["words"] = []

word_boundaries = np.pad(np.cumsum([len(t) for t in alignment.word_tokens]), (1, 0))
for i, (word, start, end) in enumerate(zip(alignment.words, alignment.start_times, alignment.end_times)):
if word.startswith("<|") or word.strip() in ".,!?、。": # TODO: expand
continue

segment = segments[token_sources[word_boundaries[i]]]
segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
for i, timing in enumerate(alignment):
if timing.word:
segment = segments[token_sources[word_boundaries[i]]]
start = round(time_offset + timing.start, 2)
end = round(time_offset + timing.end, 2)
segment["words"].append(
dict(word=timing.word, start=start, end=end, probability=timing.probability)
)

# adjust the segment-level timestamps based on the word-level timestamps
for segment in segments:
if len(segment["words"]) > 0:
segment["start"] = segment["words"][0]["start"]
segment["end"] = segment["words"][-1]["end"]
if len(words := segment["words"]) > 0:
# adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"]
Loading

0 comments on commit 5fa4356

Please sign in to comment.