attempt to fix the repetition/hallucination issue identified in #1046 (#1052)

Merged 5 commits on Mar 8, 2023
Changes from all commits
23 changes: 17 additions & 6 deletions whisper/audio.py
@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import Union
+from typing import Optional, Union
 
 import ffmpeg
 import numpy as np
@@ -15,10 +15,8 @@
 N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(
-    N_SAMPLES, HOP_LENGTH
-)  # 3000: number of frames in a mel spectrogram input
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 
 
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = N_MELS,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
 ):
     """
     Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
     n_mels: int
         The number of Mel-frequency filters, only 80 is supported
 
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
     Returns
     -------
     torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@
             audio = load_audio(audio)
         audio = torch.from_numpy(audio)
 
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
     stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
     magnitudes = stft[..., :-1].abs() ** 2
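For orientation only, a minimal sketch of how the extended log_mel_spectrogram signature could be called from user code. It is not part of the diff; the file name and variable names are invented for the example:

import torch
from whisper.audio import N_FRAMES, N_SAMPLES, log_mel_spectrogram

audio_path = "speech.wav"  # hypothetical input file
# Build the spectrogram directly on the GPU when one is available, and append
# 30 seconds of zero samples so the last window can always be sliced full-length.
device = "cuda" if torch.cuda.is_available() else "cpu"
mel = log_mel_spectrogram(audio_path, padding=N_SAMPLES, device=device)
# The final N_FRAMES columns correspond to the padded silence.
content_frames = mel.shape[-1] - N_FRAMES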
42 changes: 21 additions & 21 deletions whisper/transcribe.py
@@ -11,6 +11,7 @@
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
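To make the padding arithmetic concrete, a small worked example; the 45-second clip length is hypothetical and chosen only for illustration:

from whisper.audio import HOP_LENGTH, N_FRAMES, N_SAMPLES, SAMPLE_RATE

clip_seconds = 45                          # illustrative input length
clip_samples = clip_seconds * SAMPLE_RATE  # 720000 samples of real audio
padded_samples = clip_samples + N_SAMPLES  # 1200000 samples after 30 s of silence
mel_frames = padded_samples // HOP_LENGTH  # 7500 frames in the padded spectrogram
content_frames = mel_frames - N_FRAMES     # 4500 frames that contain actual audio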
@@ -212,14 +215,13 @@ def new_segment(
         }
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
 
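A short hypothetical walk-through of the new slicing on the final window, continuing the 45-second example above (all numbers are illustrative):

content_frames = 4500                            # 45 s of real audio, as above
seek = 3000                                      # the first 30 s have been decoded
# mel[:, 3000:6000] is still a full 3000-frame window thanks to the padded silence,
# while segment_size records that only the first 15 s of that window are real audio.
segment_size = min(3000, content_frames - seek)  # 1500 frames = 15 seconds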
@@ -246,20 +248,18 @@ def new_segment(
             current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
 
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
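The reworked timestamp handling is easier to follow on a toy token sequence. The sketch below is not from the PR, and the timestamp_begin value is only a stand-in for what the tokenizer actually provides:

import torch

timestamp_begin = 50364  # illustrative; the real value comes from the tokenizer
# <|0.00|> hello world <|2.00|> <|2.00|> again <|3.50|>
tokens = torch.tensor([timestamp_begin, 11, 12, timestamp_begin + 100,
                       timestamp_begin + 100, 13, timestamp_begin + 175])

timestamp_tokens = tokens.ge(timestamp_begin)
# ends with [False, True]: a lone timestamp closes the output
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)            # index right after the double timestamp -> tensor([4])
slices = consecutive.tolist()  # [4]
if single_timestamp_ending:
    slices.append(len(tokens))  # [4, 7]: also emit the trailing segment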
@@ -278,7 +278,7 @@ def new_segment(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def new_segment(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
@@ -356,7 +356,7 @@
             )
 
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),