Skip to content

Commit

Permalink
drop python 3.7 support (#889)
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed Jan 24, 2023
1 parent 55f690a commit a6b36ed
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 49 deletions.
19 changes: 6 additions & 13 deletions whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,10 @@ def __init__(self, temperature: float, eot: int):
self.eot = eot

def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
next_tokens = Categorical(logits=logits / self.temperature).sample()

logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
Expand Down Expand Up @@ -511,10 +510,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:

def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt

if prefix:
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
)
Expand All @@ -523,7 +520,7 @@ def _get_initial_tokens(self) -> Tuple[int]:
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens

if prompt:
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
)
Expand Down Expand Up @@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)

result = DecodingTask(model, options).run(mel)

if single:
result = result[0]

return result
return result[0] if single else result
38 changes: 13 additions & 25 deletions whisper/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from functools import lru_cache
from functools import lru_cache, cached_property
from typing import List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -156,43 +156,35 @@ def decode_with_timestamps(self, tokens) -> str:
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)

@property
@lru_cache()
@cached_property
def eot(self) -> int:
return self.tokenizer.eos_token_id

@property
@lru_cache()
@cached_property
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")

@property
@lru_cache()
@cached_property
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")

@property
@lru_cache()
@cached_property
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")

@property
@lru_cache()
@cached_property
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")

@property
@lru_cache()
@cached_property
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")

@property
@lru_cache()
@cached_property
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1

@property
@lru_cache()
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
Expand All @@ -210,8 +202,7 @@ def language_token(self) -> int:

raise KeyError(f"Language {self.language} not found in tokenizer.")

@property
@lru_cache()
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
Expand All @@ -222,18 +213,15 @@ def all_language_tokens(self) -> Tuple[int]:
result.append(token_id)
return tuple(result)

@property
@lru_cache()
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)

@property
@lru_cache()
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])

@property
@lru_cache()
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
Expand Down
25 changes: 14 additions & 11 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def transcribe(
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
**decode_options,
):
"""
Expand Down Expand Up @@ -138,10 +139,11 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
all_segments = []
prompt_reset_since = 0

initial_prompt = decode_options.pop("initial_prompt", None) or []
if initial_prompt:
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt)
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []

def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
Expand Down Expand Up @@ -243,7 +245,11 @@ def add_segment(
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek

return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
segments=all_segments,
language=language
)


def cli():
Expand Down Expand Up @@ -292,21 +298,18 @@ def cli():
args["language"] = "en"

temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]

threads = args.pop("threads")
if threads > 0:
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)

from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)

writer = get_writer(output_format, output_dir)

for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
Expand Down

0 comments on commit a6b36ed

Please sign in to comment.