Add language_bias parameter to detect_language #2004

Open · wants to merge 1 commit into main
whisper/decoding.py (9 additions, 1 deletion)
@@ -17,7 +17,7 @@
 
 @torch.no_grad()
 def detect_language(
-    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, language_bias: Optional[dict[str,float]] = None
 ) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -56,6 +56,14 @@ def detect_language(
     x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
     logits = model.logits(x, mel)[:, 0]
 
+    # apply language_bias to logits
+    if language_bias:
+        biases = torch.zeros(logits.size(1), device=logits.device)
+        for lang, bias in language_bias.items():
+            token = tokenizer.to_language_token(lang)
+            biases[token] = bias
+        logits += biases
+
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
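For reviewers, a minimal sketch of how the new argument could be exercised directly through model.detect_language once this patch is applied. The model size, audio path, and bias magnitudes are placeholders; everything else is standard whisper usage (load_model, load_audio, pad_or_trim, log_mel_spectrogram). Because the bias is added to each listed language token's raw logit before the softmax, a bias of ln 2 ≈ 0.69 roughly doubles that language's odds relative to an unbiased run.

import whisper

model = whisper.load_model("base")                              # placeholder model size
audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))   # placeholder path
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)

# Nudge detection toward Norwegian and away from English ("no"/"en" are codes
# from whisper.tokenizer.LANGUAGES); the values here are illustrative only.
_, probs = model.detect_language(mel, language_bias={"no": 2.0, "en": -1.0})
print(max(probs, key=probs.get))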
whisper/transcribe.py (6 additions, 1 deletion)
@@ -51,6 +51,7 @@ def transcribe(
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
     clip_timestamps: Union[str, List[float]] = "0",
     hallucination_silence_threshold: Optional[float] = None,
+    language_bias: Optional[dict[str,float]] = None,
     **decode_options,
 ):
     """
@@ -113,6 +114,10 @@
         When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
         when a possible hallucination is detected
 
+    language_bias: Optional[dict[str,float]] = None
+        A dictionary of language codes to positive or negative float values. These values will be
+        applied to the language detection logits before choosing the language.
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -143,7 +148,7 @@ def transcribe(
                 "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
             )
             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(mel_segment)
+            _, probs = model.detect_language(mel_segment, language_bias=language_bias)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
                 print(
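And a sketch of the end-to-end path this PR wires up through transcribe(): when no explicit language is given, the bias is forwarded to detect_language; when --language / decode_options["language"] is set, detection is skipped and the bias has no effect. The model size, audio path, and bias values below are again placeholders.

import whisper

model = whisper.load_model("base")  # placeholder model size

# Bias detection toward Spanish and away from Galician for short clips where
# the two are easily confused; the magnitudes are illustrative only.
result = model.transcribe("sample.wav", language_bias={"es": 2.0, "gl": -2.0})
print(result["language"], result["text"])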