Commit 8f9357f: word-level timestamps in transcribe()
jongwook committed Jan 20, 2023
1 parent 12e1089 · commit 8f9357f
Showing 4 changed files with 178 additions and 27 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,3 +4,4 @@ tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
+dtw-python==1.3.0
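Before the new dependency shows up below, a quick sketch of the dtw-python API that add_word_timestamps() relies on (illustrative only, not part of this commit): dtw() accepts a precomputed cost matrix when called with a single argument, and the returned alignment object exposes the warping path through its index1s/index2s attributes.

import numpy as np
from dtw import dtw

# toy cost matrix: 3 tokens (rows) x 5 audio frames (columns)
cost = np.array([
    [0.1, 0.2, 0.9, 0.9, 0.9],
    [0.9, 0.9, 0.1, 0.2, 0.9],
    [0.9, 0.9, 0.9, 0.9, 0.1],
])

alignment = dtw(cost)
print(alignment.index1s)  # row (token) index at each step of the path
print(alignment.index2s)  # column (frame) index at each step of the path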
4 changes: 4 additions & 0 deletions whisper/audio.py
@@ -18,6 +18,10 @@
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input

+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+

def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
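For context, the arithmetic behind the new constants, assuming the values already defined in whisper/audio.py (SAMPLE_RATE = 16000, HOP_LENGTH = 160): the encoder's initial convolutions downsample by a factor of 2, so one audio token covers two 10 ms mel frames.

SAMPLE_RATE = 16000
HOP_LENGTH = 160

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2                    # 320 samples, i.e. 20 ms
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH           # 100 mel frames per second
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 50 audio tokens per second

assert FRAMES_PER_SECOND == 100 and TOKENS_PER_SECOND == 50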
34 changes: 34 additions & 0 deletions whisper/tokenizer.py
@@ -1,4 +1,5 @@
import os
+import string
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union
@@ -269,6 +270,39 @@ def _get_single_token_id(self, text) -> int:
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]

+    def split_tokens_on_unicode(self, tokens: List[int]):
+        words = []
+        word_tokens = []
+        current_tokens = []
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+            if "\ufffd" not in decoded:
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(self, tokens: List[int]):
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens


@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
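A sketch of how the two new splitting methods behave (illustrative, not part of the commit; exact token IDs depend on the vocabulary): split_tokens_on_unicode cuts the token stream wherever the accumulated tokens decode without the U+FFFD replacement character, and split_tokens_on_spaces merges those pieces into space-delimited words, keeping special tokens and bare punctuation as separate entries.

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
tokens = tokenizer.encode(" Hello, world!")

words, word_tokens = tokenizer.split_tokens_on_spaces(tokens)
# expected grouping: words == [' Hello', ',', ' world', '!']
# word_tokens holds the corresponding token-ID sublists, one per word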
166 changes: 139 additions & 27 deletions whisper/transcribe.py
@@ -8,15 +8,104 @@
import torch
import tqdm

-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
+from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
+from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt

if TYPE_CHECKING:
    from .model import Whisper


+def add_word_timestamps(
+    model: "Whisper",
+    tokenizer: Tokenizer,
+    mel: torch.Tensor,
+    num_frames: int,
+    segments: List[dict],
+    *,
+    medfilt_width: int = 7,
+    qk_scale: float = 1.0,
+):
+    if len(segments) == 0:
+        return
+
+    from dtw import dtw
+    from scipy.ndimage import median_filter
+
+    # install hooks on the cross attention layers to retrieve the attention weights
+    QKs = [None] * model.dims.n_text_layer
+    hooks = [
+        block.cross_attn.register_forward_hook(
+            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
+        )
+        for i, block in enumerate(model.decoder.blocks)
+    ]
+
+    tokens = torch.tensor(
+        [
+            *tokenizer.sot_sequence,
+            tokenizer.timestamp_begin,
+            *[t for segment in segments for t in segment["tokens"]],
+            tokenizer.timestamp_begin + mel.shape[-1] // 2,
+            tokenizer.eot,
+        ]
+    ).to(model.device)
+
+    with torch.no_grad():
+        model(mel.unsqueeze(0), tokens.unsqueeze(0))
+
+    for hook in hooks:
+        hook.remove()
+
+    weights = torch.cat(QKs)  # layers * heads * tokens * frames
+    weights = weights[:, :, :, : num_frames // 2].cpu()
+    weights = median_filter(weights, (1, 1, 1, medfilt_width))
+    weights = torch.tensor(weights * qk_scale).softmax(dim=-1)
+
+    w = weights / weights.norm(dim=-2, keepdim=True)
+    matrix = w.mean(axis=(0, 1))
+
+    alignment = dtw(-matrix.double().numpy())
+
+    jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
+    jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND
+
+    if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
+        # These languages don't typically use spaces, so it is difficult to split words
+        # without morpheme analysis. Here, we instead split words at any
+        # position where the tokens are decoded as valid Unicode code points
+        split_tokens = tokenizer.split_tokens_on_unicode
+    else:
+        split_tokens = tokenizer.split_tokens_on_spaces
+
+    words, word_tokens = split_tokens(tokens[1:].tolist())
+
+    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
+    token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources)
+
+    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
+    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
+    start_times = time_offset + jump_times[word_boundaries[:-1]]
+    end_times = time_offset + jump_times[word_boundaries[1:]]
+
+    for segment in segments:
+        segment["words"] = []
+
+    for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
+        if word.startswith("<|") or word.strip() in ".,!?、。":
+            continue
+
+        segment = segments[token_sources[word_boundaries[i]]]
+        segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))
+
+    # adjust the segment-level timestamps based on the word-level timestamps
+    for segment in segments:
+        if len(segment["words"]) > 0:
+            segment["start"] = segment["words"][0]["start"]
+            segment["end"] = segment["words"][-1]["end"]


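To make the alignment step above concrete, a toy example of the jump detection (illustrative, not part of the commit): a step in index1s marks the first path position of each new token, so indexing index2s at those positions yields the frame where each token starts; because the cross-attention frames were halved to the token rate, dividing by TOKENS_PER_SECOND converts them to seconds.

import numpy as np

index1s = np.array([0, 0, 1, 1, 1, 2])  # token index along the warping path
index2s = np.array([0, 1, 2, 3, 4, 5])  # frame index along the warping path

jumps = np.pad(np.diff(index1s), (1, 0), constant_values=1).astype(bool)
# jumps == [True, False, True, False, False, True]
start_frames = index2s[jumps]  # array([0, 2, 5]): first frame of each token
start_times = start_frames / 50  # with TOKENS_PER_SECOND == 50, times in seconds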
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
@@ -27,6 +116,7 @@ def transcribe(
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
+    word_level_timestamps: bool = False,
    **decode_options,
):
    """
@@ -90,14 +180,14 @@
        else:
            if verbose:
                print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
-            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(segment)
+            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

-    language = decode_options["language"]
-    task = decode_options.get("task", "transcribe")
+    language: str = decode_options["language"]
+    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
@@ -147,7 +237,8 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
    def add_segment(
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
-        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
+        text_tokens = [token for token in text_tokens.tolist() if token < tokenizer.eot]
+        text = tokenizer.decode(text_tokens)
        if len(text.strip()) == 0:  # skip empty text output
            return

@@ -158,32 +249,28 @@ def add_segment(
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"tokens": text_tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n"
# compared to just `print(line)`, this replaces any character not representable using
# the system default encoding with an '?', avoiding UnicodeEncodeError.
sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace"))
sys.stdout.flush()

# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
previous_seek = seek

    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
        while seek < num_frames:
-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            mel_segment = mel[:, seek:]
+            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)

if no_speech_threshold is not None:
@@ -194,9 +281,10 @@ def add_segment(
                    should_skip = False

                if should_skip:
-                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    seek += segment_size  # fast-forward to the next segment boundary
                    continue

+            last_segment_index = len(all_segments)
            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
@@ -210,8 +298,8 @@
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
-                        start=timestamp_offset + start_timestamp_position * time_precision,
-                        end=timestamp_offset + end_timestamp_position * time_precision,
+                        start=time_offset + start_timestamp_position * time_precision,
+                        end=time_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
@@ -231,22 +319,45 @@
                duration = last_timestamp_position * time_precision

                add_segment(
-                    start=timestamp_offset,
-                    end=timestamp_offset + duration,
+                    start=time_offset,
+                    end=time_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

-            seek += segment.shape[-1]
+            seek += segment_size
            all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

+            if word_level_timestamps:
+                current_segments = all_segments[last_segment_index:]
+                add_word_timestamps(
+                    model,
+                    tokenizer,
+                    mel=mel_segment,
+                    num_frames=segment_size,
+                    segments=current_segments,
+                )
+                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
+                if len(word_end_timestamps) > 0:
+                    seek_shift = (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
+                    seek = previous_seek + round(seek_shift)

+            if verbose:
+                for segment in all_segments[last_segment_index:]:
+                    start, end, text = segment["start"], segment["end"], segment["text"]
+                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n"
+                    # compared to just `print(line)`, this replaces any character not representable using
+                    # the system default encoding with a '?', avoiding UnicodeEncodeError.
+                    sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace"))
+                    sys.stdout.flush()

            # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek_value)
-            previous_seek_value = seek
+            pbar.update(min(num_frames, seek) - previous_seek)
+            previous_seek = seek

    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)

@@ -280,6 +391,7 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_level_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

args = parser.parse_args().__dict__
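For reference, exercising the new option end to end (a sketch; the audio path is hypothetical and assumes this commit is installed):

import whisper

model = whisper.load_model("base")
result = model.transcribe("audio.mp3", word_level_timestamps=True)

for segment in result["segments"]:
    for word in segment["words"]:
        print(f"{word['start']:6.2f} -> {word['end']:6.2f}  {word['word']}")

The same can be done from the command line with `whisper audio.mp3 --word_level_timestamps True` (the flag goes through str2bool, so it takes an explicit True/False value).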
