
Commit

MiscellaneousStuff committed Oct 27, 2022
2 parents 0758d43 + 9f70a35 commit e87e7c4
Showing 8 changed files with 92 additions and 67 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -40,7 +40,11 @@ choco install ffmpeg
scoop install ffmpeg
```

You may need [`rust`](https://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install the Rust development environment.
You may need [`rust`](https://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install the Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:

```bash
pip install setuptools-rust
```


## Available models and languages
@@ -125,6 +129,11 @@ result = whisper.decode(model, mel, options)
print(result.text)
```

## More examples

Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussions/categories/show-and-tell) category in Discussions for sharing more example usages of Whisper and third-party extensions such as web demos, integrations with other tools, ports for different platforms, etc.


## License

The code and the model weights of Whisper are released under the MIT License. See [LICENSE](LICENSE) for further details.
6 changes: 5 additions & 1 deletion setup.py
@@ -7,8 +7,12 @@
name="whisper",
py_modules=["whisper"],
version="1.0",
description="",
description="Robust Speech Recognition via Large-Scale Weak Supervision",
readme="README.md",
python_requires=">=3.7",
author="OpenAI",
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=[
str(r)
5 changes: 4 additions & 1 deletion whisper/__init__.py
@@ -90,7 +90,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
download_root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)

if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
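The new default makes the model cache respect `XDG_CACHE_HOME` when it is set, falling back to `~/.cache/whisper` otherwise. A minimal sketch of the resolution logic (the helper name is ours, for illustration only):

```python
import os

def resolve_download_root(download_root=None):
    # Mirrors the patched load_model() default: an explicit argument wins,
    # then $XDG_CACHE_HOME if set, then ~/.cache/whisper.
    if download_root is None:
        download_root = os.getenv(
            "XDG_CACHE_HOME",
            os.path.join(os.path.expanduser("~"), ".cache", "whisper"),
        )
    return download_root

print(resolve_download_root())  # e.g. /home/user/.cache/whisper when unset
```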
4 changes: 2 additions & 2 deletions whisper/audio.py
@@ -41,7 +41,7 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
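Passing `cmd=["ffmpeg", "-nostdin"]` keeps ffmpeg from reading the parent process's stdin, which previously disturbed interactive shells and scripts that pipe data into Python. A self-contained sketch of the patched decode call, assuming `ffmpeg-python` and an `ffmpeg` binary are installed:

```python
import ffmpeg  # ffmpeg-python
import numpy as np

SAMPLE_RATE = 16000  # Whisper's expected sample rate

def load_audio_sketch(path: str, sr: int = SAMPLE_RATE) -> np.ndarray:
    # "-nostdin" detaches ffmpeg from the terminal's stdin so batch jobs
    # and shell pipelines are not affected by the subprocess.
    out, _ = (
        ffmpeg.input(path, threads=0)
        .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
        .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
    )
    return np.frombuffer(out, np.int16).astype(np.float32) / 32768.0
```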
@@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length))
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
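The `device=array.device` argument matters because `torch.arange` allocates on the CPU by default, so trimming a CUDA tensor used to raise a device-mismatch error. A small illustration (shapes arbitrary; on CPU the extra argument is a harmless no-op):

```python
import torch

x = torch.randn(1, 80, 4000)  # stand-in mel spectrogram batch
# The index tensor must live on the same device as x; without
# device=x.device this fails for CUDA inputs.
idx = torch.arange(3000, device=x.device)
trimmed = x.index_select(dim=-1, index=idx)
print(trimmed.shape)  # torch.Size([1, 80, 3000])
```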
18 changes: 10 additions & 8 deletions whisper/decoding.py
@@ -78,7 +78,7 @@ class DecodingOptions:
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: float = 0.0 # patience in beam search (https://arxiv.org/abs/2204.05424)
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)

# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
Expand All @@ -94,7 +94,7 @@ class DecodingOptions:

# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 0.0 # the initial timestamp cannot be later than this
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this

# implementation details
fp16: bool = True # use fp16 for most of the calculation
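With these defaults, `patience=None` means conventional beam search, and the first predicted timestamp may now fall anywhere within the first second of a window. A quick check of the new fields:

```python
from whisper.decoding import DecodingOptions

opts = DecodingOptions(beam_size=5, patience=2.0)  # patience is now Optional[float]
print(opts.patience, opts.max_initial_timestamp)   # 2.0 1.0
```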
@@ -275,14 +275,16 @@ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):


class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: float = 0.0):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience
self.max_candidates: int = round(beam_size * (1.0 + patience))
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None

assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"

def reset(self):
self.finished_sequences = None

@@ -496,8 +498,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience != 0.0 and options.beam_size is None:
raise ValueError("nonzero patience requires beam_size to be given")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

@@ -613,7 +615,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
n_audio: int = mel.shape[0]

audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).expand(n_audio, -1)
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
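The switch from `expand` to `repeat` is needed because `expand` returns a stride-0 view in which every batch row aliases the same storage, while `repeat` materializes an independent copy per audio input. A minimal demonstration (token ids made up):

```python
import torch

base = torch.tensor([[50258, 50259, 50359]])  # illustrative initial tokens

expanded = base.expand(2, -1)  # view: both rows share one memory location
repeated = base.repeat(2, 1)   # copy: rows can diverge during decoding

repeated[0, 0] = 0
print(repeated[1, 0].item())   # 50258 — the other row is unaffected
# expanded[0, 0] = 0           # would raise: multiple elements share storage
```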
6 changes: 3 additions & 3 deletions whisper/model.py
@@ -72,15 +72,15 @@ def forward(
):
q = self.query(x)

if kv_cache is None or xa is None:
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache.get(self.key, self.key(xa))
v = kv_cache.get(self.value, self.value(xa))
k = kv_cache[self.key]
v = kv_cache[self.value]

wv = self.qkv_attention(q, k, v, mask)
return self.out(wv)
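The guard `self.key not in kv_cache` makes the caching contract explicit: cross-attention keys and values are projected from the encoder output once, stored under the projection modules themselves, and reused on every later decoding step. A stripped-down sketch of that contract (layer sizes and the helper are illustrative; in whisper the cache is filled by forward hooks):

```python
import torch
import torch.nn as nn

key_proj = nn.Linear(8, 8, bias=False)
value_proj = nn.Linear(8, 8)
kv_cache = {}  # keyed by the projection modules, as in whisper/model.py

def cross_attention_kv(xa: torch.Tensor):
    # First step: project the encoder output and memoize.
    # Later steps: hit the cache and skip both matmuls.
    if key_proj not in kv_cache:
        kv_cache[key_proj] = key_proj(xa)
        kv_cache[value_proj] = value_proj(xa)
    return kv_cache[key_proj], kv_cache[value_proj]

xa = torch.randn(1, 3, 8)  # stand-in encoder output
k1, _ = cross_attention_kv(xa)
k2, _ = cross_attention_kv(xa)
print(k1 is k2)  # True — reused, not recomputed
```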
97 changes: 52 additions & 45 deletions whisper/transcribe.py
@@ -84,49 +84,48 @@ def transcribe(
mel = log_mel_spectrogram(audio)

if decode_options.get("language", None) is None:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

mel = mel.unsqueeze(0)
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

language = decode_options["language"]
task = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
kwargs = {**decode_options}
t = temperatures[0]
if t == 0:
best_of = kwargs.pop("best_of", None)
else:
best_of = kwargs.get("best_of", None)

options = DecodingOptions(**kwargs, temperature=t)
results = model.decode(segment, options)

kwargs.pop("beam_size", None) # no beam search for t > 0
kwargs.pop("patience", None) # no patience for t > 0
kwargs["best_of"] = best_of # enable best_of for t > 0
for t in temperatures[1:]:
needs_fallback = [
compression_ratio_threshold is not None
and result.compression_ratio > compression_ratio_threshold
or logprob_threshold is not None
and result.avg_logprob < logprob_threshold
for result in results
]
if any(needs_fallback):
options = DecodingOptions(**kwargs, temperature=t)
retries = model.decode(segment[needs_fallback], options)
for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
results[original_index] = retries[retry_index]

return results
decode_result = None

for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)

options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)

needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
needs_fallback = True # average log probability is too low

if not needs_fallback:
break

return decode_result
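The rewritten helper walks the temperature schedule one value at a time and returns the first result whose compression ratio and average log-probability both pass their thresholds. A toy rendition of the control flow, with a faked decoder so the stopping point is visible:

```python
# Hedged sketch: fake_decode stands in for model.decode(segment, options).
temperatures = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
compression_ratio_threshold, logprob_threshold = 2.4, -1.0

def fake_decode(t):
    return {"compression_ratio": 3.0 - t, "avg_logprob": -0.5}

for t in temperatures:
    result = fake_decode(t)
    too_repetitive = result["compression_ratio"] > compression_ratio_threshold
    too_improbable = result["avg_logprob"] < logprob_threshold
    if not (too_repetitive or too_improbable):
        break  # first acceptable decode wins

print(t)  # 0.6 — the first temperature whose ratio drops to the threshold
```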

seek = 0
input_stride = exact_div(
@@ -175,11 +174,11 @@ def add_segment(
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

decode_options["prompt"] = all_tokens[prompt_reset_since:]
result = decode_with_fallback(segment)[0]
result: DecodingResult = decode_with_fallback(segment)
tokens = torch.tensor(result.tokens)

if no_speech_threshold is not None:
@@ -220,7 +219,7 @@ def add_segment(
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0:
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
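The added `timestamps[-1].item() != tokenizer.timestamp_begin` check guards against a window whose only timestamp token is `<|0.00|>`: seeking by it would advance zero frames and loop forever, so such windows now fall back to the full segment duration. A schematic version of the check (the token id is hypothetical):

```python
timestamp_begin = 50364   # hypothetical id of the <|0.00|> token
timestamps = [50364]      # decoded window containing only the zero timestamp

usable = len(timestamps) > 0 and timestamps[-1] != timestamp_begin
print(usable)  # False -> use segment_duration instead of seeking by 0 frames
```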
@@ -253,6 +252,7 @@ def cli():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@@ -263,8 +263,8 @@ def cli():
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=0.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (0.0) is equivalent to not using patience")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
@@ -275,15 +275,18 @@
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)

if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
args["language"] = "en"

temperature = args.pop("temperature")
@@ -293,8 +296,12 @@
else:
temperature = [temperature]

threads = args.pop("threads")
if threads > 0:
torch.set_num_threads(threads)

from . import load_model
model = load_model(model_name, device=device)
model = load_model(model_name, device=device, download_root=model_dir)

for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
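Together, the new `--model_dir` and `--threads` flags map onto `load_model(..., download_root=...)` and `torch.set_num_threads(...)`. A sketch of the equivalent library calls (path, thread count, and audio file are examples):

```python
import torch
import whisper

torch.set_num_threads(8)  # --threads 8
model = whisper.load_model(
    "small", device="cpu",
    download_root="/data/whisper-models",  # --model_dir /data/whisper-models
)
result = model.transcribe("audio.mp3")  # hypothetical input file
```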
12 changes: 6 additions & 6 deletions whisper/utils.py
@@ -27,7 +27,7 @@ def compression_ratio(text) -> float:
return len(text) / len(zlib.compress(text.encode("utf-8")))


def format_timestamp(seconds: float, always_include_hours: bool = False):
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)

@@ -40,8 +40,8 @@ def format_timestamp(seconds: float, always_include_hours: bool = False):
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000

hours_marker = f"{hours}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"


def write_txt(transcript: Iterator[dict], file: TextIO):
@@ -54,7 +54,7 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].replace('-->', '->')}\n",
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
@@ -79,8 +79,8 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
