Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

word-level timestamps in transcribe() #869

Merged
merged 26 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8f9357f
word-level timestamps in `transcribe()`
jongwook Jan 20, 2023
46ea501
moving to `timing.py`
jongwook Jan 21, 2023
cfd2b81
Merge branch 'main' into word-level-timestamps
jongwook Jan 21, 2023
742d2f4
numba implementation for dtw, replacing dtw-python
jongwook Jan 22, 2023
fb12414
Merge branch 'main' into word-level-timestamps
jongwook Jan 22, 2023
80331c0
triton implementation for dtw
jongwook Jan 23, 2023
1d2ed66
add test for dtw implementations
jongwook Jan 23, 2023
b61e8f4
triton implementation of median_filter
jongwook Jan 24, 2023
54f2901
a simple word-level timestamps test
jongwook Jan 24, 2023
8ce6277
add scipy as dev dependency
jongwook Jan 24, 2023
812f446
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
cd5191f
installs an older version of Triton if CUDA < 11.4
jongwook Jan 24, 2023
f64d8bc
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
89133bd
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
d4f9399
fix broken merge
jongwook Jan 24, 2023
040aa04
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
8e2756b
loosen nvcc version match regex
jongwook Jan 25, 2023
6c431c4
find_alignment() function
jongwook Jan 25, 2023
ff6cbfd
Merge branch 'main' into word-level-timestamps
jongwook Feb 2, 2023
5fa4356
miscellaneous improvements
jongwook Feb 2, 2023
48537aa
skip median filtering when the input is too small
jongwook Feb 2, 2023
8eb29c3
Expose punctuation options in cli and transcribe() (#973)
ryanheise Feb 16, 2023
6ed4c11
Merge branch 'main' into word-level-timestamps
jongwook Mar 6, 2023
31cd418
fix merge error
jongwook Mar 6, 2023
145f325
fix merge error 2
jongwook Mar 6, 2023
2b079c4
annotating that word_timestamps is experimental
jongwook Mar 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
dtw-python==1.3.0
4 changes: 4 additions & 0 deletions whisper/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each)
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each)


def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Expand Down
34 changes: 34 additions & 0 deletions whisper/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import string
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -269,6 +270,39 @@ def _get_single_token_id(self, text) -> int:
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]

def split_tokens_on_unicode(self, tokens: List[int]):
words = []
word_tokens = []
current_tokens = []

for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if "\ufffd" not in decoded:
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []

return words, word_tokens

def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []

for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)

return words, word_tokens


@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
Expand Down
166 changes: 139 additions & 27 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,104 @@
import torch
import tqdm

from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt

if TYPE_CHECKING:
from .model import Whisper


def add_word_timestamps(
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
segments: List[dict],
*,
medfilt_width: int = 7,
qk_scale: float = 1.0,
):
if len(segments) == 0:
return

from dtw import dtw
from scipy.ndimage import median_filter

# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
)
for i, block in enumerate(model.decoder.blocks)
]

tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.timestamp_begin,
*[t for segment in segments for t in segment["tokens"]],
tokenizer.timestamp_begin + mel.shape[-1] // 2,
tokenizer.eot,
]
).to(model.device)

with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))

for hook in hooks:
hook.remove()

weights = torch.cat(QKs) # layers * heads * tokens * frames
weights = weights[:, :, :, : num_frames // 2].cpu()
weights = median_filter(weights, (1, 1, 1, medfilt_width))
weights = torch.tensor(weights * qk_scale).softmax(dim=-1)

w = weights / weights.norm(dim=-2, keepdim=True)
matrix = w.mean(axis=(0, 1))

alignment = dtw(-matrix.double().numpy())

jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND

if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
split_tokens = tokenizer.split_tokens_on_unicode
else:
split_tokens = tokenizer.split_tokens_on_spaces

words, word_tokens = split_tokens(tokens[1:].tolist())

token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
token_sources = [None] * len(tokenizer.sot_sequence) + list(token_sources)

time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
start_times = time_offset + jump_times[word_boundaries[:-1]]
end_times = time_offset + jump_times[word_boundaries[1:]]

for segment in segments:
segment["words"] = []

for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
if word.startswith("<|") or word.strip() in ".,!?、。":
Copy link
Contributor

@ryanheise ryanheise Jan 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would actually be convenient to actually re-insert the punctuation tokens so that concatenating all the words is the same as concatenating all the tokens. That would just make processing easier on the consumer end. For reference, Amazon Transcribe includes timestamped punctuation tokens in the results.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice there is now a TODO comment leaning to the first approach:

if word.startswith("<|") or word.strip() in ".,!?、。":  # TODO: expand

I'm not sure if you've already committed to that approach, but I would vote for not removing the punctuation, so that whether a consumer wants to traverse the entire content by token, by word, or by segment, they can and do it in either of these 3 ways and in all the content is there (the concatenation of each result is identical). Otherwise if I consume the results by word, I would need to simultaneously look up one of the other two results to cross reference the, and look for the bits that were omitted from the sequence. Here is how Amazon Transcribe does it, for example.

On the other hand, if that is not persuasive, you might consider instead making it an option whether or not to strip out the punctuation.

(I note also that if you just left the punctuation in, the consumer would still have the ability to filter them out.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions; I've updated so that the punctuation characters are attached to the adjacent words, while keeping the word's timestamps.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that looks good. I think if the prepend_punctuations and append_punctuations parameters were propagated in transcribe() and cli() that would be quite helpful, since then I could set them to empty strings to emulate the Amazon Transcribe behaviour.

continue

segment = segments[token_sources[word_boundaries[i]]]
segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))

# adjust the segment-level timestamps based on the word-level timestamps
for segment in segments:
if len(segment["words"]) > 0:
segment["start"] = segment["words"][0]["start"]
segment["end"] = segment["words"][-1]["end"]


def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
Expand All @@ -27,6 +116,7 @@ def transcribe(
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
word_level_timestamps: bool = False,
**decode_options,
):
"""
Expand Down Expand Up @@ -90,14 +180,14 @@ def transcribe(
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

language = decode_options["language"]
task = decode_options.get("task", "transcribe")
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
Expand Down Expand Up @@ -147,7 +237,8 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
):
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
text_tokens = [token for token in text_tokens.tolist() if token < tokenizer.eot]
text = tokenizer.decode(text_tokens)
if len(text.strip()) == 0: # skip empty text output
return

Expand All @@ -158,32 +249,28 @@ def add_segment(
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"tokens": text_tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n"
# compared to just `print(line)`, this replaces any character not representable using
# the system default encoding with an '?', avoiding UnicodeEncodeError.
sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace"))
sys.stdout.flush()

# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
previous_seek = seek

with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek:]
segment_size = min(mel_segment.shape[-1], N_FRAMES)
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)

if no_speech_threshold is not None:
Expand All @@ -194,9 +281,10 @@ def add_segment(
should_skip = False

if should_skip:
seek += segment.shape[-1] # fast-forward to the next segment boundary
seek += segment_size # fast-forward to the next segment boundary
continue

last_segment_index = len(all_segments)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
Expand All @@ -210,8 +298,8 @@ def add_segment(
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
start=time_offset + start_timestamp_position * time_precision,
end=time_offset + end_timestamp_position * time_precision,
text_tokens=sliced_tokens[1:-1],
result=result,
)
Expand All @@ -231,22 +319,45 @@ def add_segment(
duration = last_timestamp_position * time_precision

add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
start=time_offset,
end=time_offset + duration,
text_tokens=tokens,
result=result,
)

seek += segment.shape[-1]
seek += segment_size
all_tokens.extend(tokens.tolist())

if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)

if word_level_timestamps:
current_segments = all_segments[last_segment_index:]
add_word_timestamps(
model,
tokenizer,
mel=mel_segment,
num_frames=segment_size,
segments=current_segments,
)
word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
if len(word_end_timestamps) > 0:
seek_shift = (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
seek = previous_seek + round(seek_shift)

if verbose:
for segment in all_segments[last_segment_index:]:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n"
# compared to just `print(line)`, this replaces any character not representable using
# the system default encoding with an '?', avoiding UnicodeEncodeError.
sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace"))
sys.stdout.flush()

# update progress bar
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek
pbar.update(min(num_frames, seek) - previous_seek)
previous_seek = seek

return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)

Expand Down Expand Up @@ -280,6 +391,7 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_level_timestamps", type=str2bool, default=False, help="Extract word-level timestamps and refine the results based on them")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

args = parser.parse_args().__dict__
Expand Down