From 0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f Mon Sep 17 00:00:00 2001 From: Jibin Mathew Date: Fri, 30 Sep 2022 14:45:51 -0700 Subject: [PATCH] Add model_dir to arguments (#202) * Add model_dir to arguments * minor formatting change Co-authored-by: Jong Wook Kim --- whisper/transcribe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index f97029989..7b6105b1a 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -249,6 +249,7 @@ def cli(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") + parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") @@ -274,6 +275,7 @@ def cli(): args = parser.parse_args().__dict__ model_name: str = args.pop("model") + model_dir: str = args.pop("model_dir") output_dir: str = args.pop("output_dir") device: str = args.pop("device") os.makedirs(output_dir, exist_ok=True) @@ -290,7 +292,7 @@ def cli(): temperature = [temperature] from . import load_model - model = load_model(model_name, device=device) + model = load_model(model_name, device=device, download_root=model_dir) for audio_path in args.pop("audio"): result = transcribe(model, audio_path, temperature=temperature, **args)