diff --git a/whisper/decoding.py b/whisper/decoding.py index 891aaa47..eaedf70d 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -78,7 +78,7 @@ class DecodingOptions: sample_len: Optional[int] = None # maximum number of tokens to sample best_of: Optional[int] = None # number of independent samples to collect, when t > 0 beam_size: Optional[int] = None # number of beams in beam search, when t == 0 - patience: float = 0.0 # patience in beam search (https://arxiv.org/abs/2204.05424) + patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) # options for ranking generations (either beams or best-of-N samples) length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm @@ -275,14 +275,16 @@ def finalize(self, tokens: Tensor, sum_logprobs: Tensor): class BeamSearchDecoder(TokenDecoder): - def __init__(self, beam_size: int, eot: int, inference: Inference, patience: float = 0.0): + def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None): self.beam_size = beam_size self.eot = eot self.inference = inference - self.patience = patience - self.max_candidates: int = round(beam_size * (1.0 + patience)) + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) self.finished_sequences = None + assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" + def reset(self): self.finished_sequences = None @@ -496,8 +498,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions: if options.temperature == 0: if options.best_of is not None: raise ValueError("best_of with greedy sampling (T=0) is not compatible") - if options.patience != 0.0 and options.beam_size is None: - raise ValueError("nonzero patience requires beam_size to be given") + if options.patience is not None and options.beam_size is None: + raise ValueError("patience requires beam_size to be given") if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): raise ValueError("length_penalty (alpha) should be a value between 0 and 1") diff --git a/whisper/transcribe.py b/whisper/transcribe.py index f915f79f..195ea2ee 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -263,7 +263,7 @@ def cli(): parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") - parser.add_argument("--patience", type=float, default=0.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (0.0) is equivalent to not using patience") + parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")