Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added --output option #333

Merged
merged 5 commits into from
Jan 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
refactor output format handling
  • Loading branch information
jongwook committed Jan 22, 2023
commit dfea59cba7b0406a05fbc15b2e5f80e7d8cd2f80
29 changes: 7 additions & 22 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, get_writer

if TYPE_CHECKING:
from .model import Whisper
Expand Down Expand Up @@ -260,9 +260,9 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

parser.add_argument("--output", type=str, default="all", choices=["none", "txt", "vtt", "srt", "all"], help="output files to generate, all(generates txt, vtt and srt), txt(generates only txt), vtt(generates txt and vtt), srt(generates txt and srt) ")

parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

Expand All @@ -283,13 +283,12 @@ def cli():
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")


args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
output_files: str = args.pop("output")
os.makedirs(output_dir, exist_ok=True)

if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
Expand All @@ -311,25 +310,11 @@ def cli():
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)

writer = get_writer(output_format, output_dir)

for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)

audio_basename = os.path.basename(audio_path)

# save TXT
if output_files != "none":
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result["segments"], file=txt)

# save VTT
if output_files in ["vtt","all"]:
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt)

# save SRT
if output_files in ["srt","all"]:
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
writer(result, audio_path)


if __name__ == '__main__':
Expand Down
125 changes: 83 additions & 42 deletions whisper/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import zlib
from typing import Iterator, TextIO
from typing import Callable, TextIO


def exact_div(x, y):
Expand Down Expand Up @@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"


def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write a transcript as plain text, one segment per line.

    Args:
        transcript: Iterable of segment dicts; each must provide a ``"text"`` key.
        file: Writable text stream that receives the output.
    """
    for entry in transcript:
        # Surrounding whitespace of each segment is removed before writing.
        line = entry["text"].strip()
        # The stream is flushed after every segment.
        print(line, file=file, flush=True)


def write_vtt(transcript: Iterator[dict], file: TextIO):
    """Write a transcript to ``file`` in WebVTT format.

    Args:
        transcript: Iterable of segment dicts providing ``"start"``, ``"end"``
            and ``"text"`` keys.
        file: Writable text stream that receives the output.
    """
    # File header, followed by an empty line.
    print("WEBVTT\n", file=file)
    for entry in transcript:
        begin = format_timestamp(entry["start"])
        finish = format_timestamp(entry["end"])
        # "-->" separates the two timestamps of a cue, so any occurrence
        # inside the segment text is rewritten to "->".
        caption = entry["text"].strip().replace("-->", "->")
        print(f"{begin} --> {finish}\n{caption}\n", file=file, flush=True)


def write_srt(transcript: Iterator[dict], file: TextIO):
    """Write a transcript to ``file`` in SRT format.

    Each segment becomes a numbered entry (counting from 1) consisting of the
    index, a ``start --> end`` line whose timestamps always include hours and
    use ``,`` as the decimal marker, and the stripped segment text.

    Args:
        transcript: Iterable of segment dicts providing ``"start"``, ``"end"``
            and ``"text"`` keys.
        file: Writable text stream that receives the output.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for index, entry in enumerate(transcript, start=1):
        begin = format_timestamp(entry["start"], always_include_hours=True, decimal_marker=",")
        finish = format_timestamp(entry["end"], always_include_hours=True, decimal_marker=",")
        # "-->" separates the two timestamps, so any occurrence inside the
        # segment text is rewritten to "->".
        caption = entry["text"].strip().replace("-->", "->")
        print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)
class ResultWriter:
    """Base class for callables that save a transcription result to a file.

    Subclasses set ``extension`` and implement ``write_result``.
    """

    # File extension (without the leading dot) appended to the output name.
    extension: str

    def __init__(self, output_dir: str):
        """Remember the directory in which output files are created."""
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str):
        """Write ``result`` to ``<output_dir>/<basename of audio_path>.<extension>``.

        The file is opened for writing as UTF-8 and handed to ``write_result``.
        """
        filename = ".".join((os.path.basename(audio_path), self.extension))
        destination = os.path.join(self.output_dir, filename)

        with open(destination, "w", encoding="utf-8") as handle:
            self.write_result(result, file=handle)

    def write_result(self, result: dict, file: TextIO):
        """Serialize ``result`` into ``file``; must be overridden by subclasses."""
        raise NotImplementedError


class WriteTXT(ResultWriter):
    """Writer that saves the result's segments as plain text, one per line."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO):
        """Print the stripped text of every entry in ``result["segments"]``."""
        for entry in result["segments"]:
            line = entry["text"].strip()
            # The stream is flushed after every segment.
            print(line, file=file, flush=True)


class WriteVTT(ResultWriter):
    """Writer that saves the result's segments in WebVTT format."""

    extension: str = "vtt"

    def write_result(self, result: dict, file: TextIO):
        """Write the ``WEBVTT`` header and one cue per entry in ``result["segments"]``."""
        # File header, followed by an empty line.
        print("WEBVTT\n", file=file)
        for entry in result["segments"]:
            begin = format_timestamp(entry["start"])
            finish = format_timestamp(entry["end"])
            # "-->" separates the two timestamps of a cue, so any occurrence
            # inside the segment text is rewritten to "->".
            caption = entry["text"].strip().replace("-->", "->")
            print(f"{begin} --> {finish}\n{caption}\n", file=file, flush=True)


class WriteSRT(ResultWriter):
    """Writer that saves the result's segments in SRT format."""

    extension: str = "srt"

    def write_result(self, result: dict, file: TextIO):
        """Write one numbered entry (counting from 1) per item in ``result["segments"]``.

        Timestamps always include hours and use ``,`` as the decimal marker.
        """
        for index, entry in enumerate(result["segments"], start=1):
            begin = format_timestamp(entry["start"], always_include_hours=True, decimal_marker=",")
            finish = format_timestamp(entry["end"], always_include_hours=True, decimal_marker=",")
            # "-->" separates the two timestamps, so any occurrence inside the
            # segment text is rewritten to "->".
            caption = entry["text"].strip().replace("-->", "->")
            print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)


class WriteJSON(ResultWriter):
    """Writer that saves the entire result dict as a JSON document."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO):
        """Serialize ``result`` to ``file`` with ``json.dump`` using its default settings."""
        json.dump(obj=result, fp=file)


def get_writer(output_format: str, output_dir: str) -> Callable[[dict, str], None]:
    """Return a callable that saves a transcription result in ``output_format``.

    Args:
        output_format: ``"txt"``, ``"vtt"``, ``"srt"``, ``"json"``, or ``"all"``
            to produce every one of those formats.
        output_dir: Directory handed to each writer's constructor.

    Returns:
        A callable invoked as ``writer(result, audio_path)``, where
        ``audio_path`` is the path of the transcribed audio file (a string,
        not an open file object). For ``"all"`` the callable runs every
        registered writer in turn.

    Raises:
        KeyError: If ``output_format`` is neither ``"all"`` nor a registered
            format name.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        # The second argument is forwarded to each writer's __call__, which
        # expects the audio path; it is named and typed accordingly so the
        # "all" writer matches the signature of the single-format writers.
        def write_all(result: dict, audio_path: str):
            for writer in all_writers:
                writer(result, audio_path)

        return write_all

    return writers[output_format](output_dir)