Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model_dir to arguments #202

Merged
merged 2 commits into from
Sep 30, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add model_dir to arguments
  • Loading branch information
jibinmathew69 committed Sep 30, 2022
commit 9247fb9e983731320e84187b876e198d748850cc
5 changes: 4 additions & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def cli():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type="str", default=None, help="overide default model directory with new directory")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
Expand All @@ -272,8 +273,10 @@ def cli():
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")


args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir:str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
Expand All @@ -290,7 +293,7 @@ def cli():
temperature = [temperature]

from . import load_model
model = load_model(model_name, device=device)
model = load_model(model_name, device=device, download_root=model_dir)

for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
Expand Down