Skip to content

Commit

Permalink
patience definition to match the paper
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed Sep 28, 2022
1 parent b4308c4 commit 62fe7f1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 8 additions & 6 deletions whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class DecodingOptions:
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: float = 0.0 # patience in beam search (https://arxiv.org/abs/2204.05424)
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)

# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
Expand Down Expand Up @@ -275,14 +275,16 @@ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):


class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: float = 0.0):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience
self.max_candidates: int = round(beam_size * (1.0 + patience))
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None

assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"

def reset(self):
self.finished_sequences = None

Expand Down Expand Up @@ -496,8 +498,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience != 0.0 and options.beam_size is None:
raise ValueError("nonzero patience requires beam_size to be given")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

Expand Down
2 changes: 1 addition & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def cli():
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=0.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (0.0) is equivalent to not using patience")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")

parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
Expand Down

0 comments on commit 62fe7f1

Please sign in to comment.