From c4b50c0824b780a843a9043899af0c62385dbe65 Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Wed, 8 Mar 2023 18:46:38 -0500
Subject: [PATCH] kwargs in decode() for convenience (#1061)

* kwargs in decode() for convenience

* formatting fix
---
 whisper/decoding.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/whisper/decoding.py b/whisper/decoding.py
index ff9261e0..81cd8452 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -778,7 +778,10 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
 
 @torch.no_grad()
 def decode(
-    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
+    model: "Whisper",
+    mel: Tensor,
+    options: DecodingOptions = DecodingOptions(),
+    **kwargs,
 ) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -802,6 +805,9 @@ def decode(
     if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
+    if kwargs:
+        options = replace(options, **kwargs)
+
     result = DecodingTask(model, options).run(mel)
 
     return result[0] if single else result