diff --git a/whisper/decoding.py b/whisper/decoding.py
index 49485d00..de86ce04 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -17,7 +17,7 @@
 @torch.no_grad()
 def detect_language(
-    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, language_bias: Optional[Dict[str, float]] = None
 ) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -56,6 +56,14 @@ def detect_language(
     x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
     logits = model.logits(x, mel)[:, 0]
 
+    # apply the caller-provided per-language additive bias before ranking languages
+    if language_bias:
+        biases = torch.zeros(logits.size(1), device=logits.device)
+        for lang, bias in language_bias.items():
+            token = tokenizer.to_language_token(lang)
+            biases[token] = bias
+        logits += biases
+
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 1c075a20..5e7a946f 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -1,5 +1,5 @@
 import argparse
 import os
 import traceback
 import warnings
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -51,6 +51,7 @@ def transcribe(
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
     clip_timestamps: Union[str, List[float]] = "0",
     hallucination_silence_threshold: Optional[float] = None,
+    language_bias: Optional[Dict[str, float]] = None,
     **decode_options,
 ):
     """
@@ -113,6 +114,10 @@ def transcribe(
         When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
         when a possible hallucination is detected
 
+    language_bias: Optional[Dict[str, float]]
+        Optional mapping from language code (e.g. "en") to a positive or negative float that is
+        added to the language-detection logits before the most probable language is chosen
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -143,7 +148,7 @@ def transcribe(
             "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
         )
         mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-        _, probs = model.detect_language(mel_segment)
+        _, probs = model.detect_language(mel_segment, language_bias=language_bias)
         decode_options["language"] = max(probs, key=probs.get)
         if verbose is not None:
             print(