Commit

Add whispercpp support
HHousen committed Jul 20, 2023
1 parent 8542980 commit 58d8640
Showing 2 changed files with 64 additions and 3 deletions.
7 changes: 5 additions & 2 deletions lecture2notes/dataset/transcripts_wer.py
@@ -9,7 +9,7 @@
 from timeit import default_timer as timer
 
 import jiwer
-import youtube_dl
+import yt_dlp as youtube_dl
 from tqdm import tqdm
 
 from ..end_to_end.transcribe import transcribe_main as transcribe
@@ -108,6 +108,9 @@
 elif ARGS.method == "vosk":
     model = transcribe.load_vosk_model(ARGS.model_dir)
     desired_sample_rate = 16000
+elif ARGS.method == "whispercpp":
+    model = transcribe.load_whispercpp_model(ARGS.model_dir)
+    desired_sample_rate = 16000
 
 for transcript in tqdm(transcripts, desc="Transcribing"):
     video_id = transcript.split(".")[0]
@@ -192,7 +195,7 @@
     transcript_path = str(transcript_ml_path)[: -(4 + len(ARGS.suffix))] + ".vtt"
     transcript_ground_truth = transcribe.caption_file_to_string(
         transcript_path, remove_speakers=True
-    )
+    )[0]
 
     with open(transcript_ml_path, "r") as file:
         transcript_prediction = file.read()
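Note on the first hunk above: yt-dlp is an actively maintained fork of youtube-dl exposing the same YoutubeDL entry point, so aliasing the import leaves the rest of the module unchanged. A minimal sketch of the pattern (the options and URL below are placeholders, not taken from this repository):

    # yt-dlp exposes the same YoutubeDL API as youtube-dl, so existing
    # call sites keep working after the aliased import.
    import yt_dlp as youtube_dl

    ydl_opts = {"format": "bestaudio/best"}  # hypothetical options
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download(["https://www.youtube.com/watch?v=VIDEO_ID"])  # placeholder URL

The ")[0]" change in the last hunk tracks what appears to be a new tuple return value from caption_file_to_string, whose first element is the transcript text.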
60 changes: 59 additions & 1 deletion lecture2notes/end_to_end/transcribe/transcribe_main.py
@@ -2,6 +2,7 @@
 import glob
 import json
 import logging
+import ffmpeg
 import os
 import wave
 from pathlib import Path
@@ -34,7 +35,7 @@ def extract_audio(video_path, output_path):
         + str(output_path)
     )
     command = (
-        "ffmpeg -i " + str(video_path) + " -f wav -ab 192000 -vn " + str(output_path)
+        "ffmpeg -y -i " + str(video_path) + " -f wav -ab 192000 -vn " + str(output_path)
     )
     os.system(command)
     return output_path
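A side note on the one-line change above: ffmpeg's -y flag overwrites an existing output file instead of prompting, so the os.system() call cannot stall waiting for interactive confirmation. A standalone sketch (paths are placeholders):

    import os

    # Without -y, a pre-existing output.wav would make ffmpeg prompt
    # "Overwrite? [y/N]" and block the script.
    os.system("ffmpeg -y -i input.mp4 -f wav -ab 192000 -vn output.wav")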
@@ -58,6 +59,8 @@ def transcribe_audio(audio_path, method="sphinx", **kwargs):
         return transcribe_audio_deepspeech(audio_path, **kwargs)
     if method == "wav2vec":
         return transcribe_audio_wav2vec(audio_path, **kwargs)
+    if method == "whispercpp":
+        return transcribe_audio_whispercpp(audio_path, **kwargs)
     return transcribe_audio_generic(audio_path, method, **kwargs), None
 
 
@@ -231,6 +234,59 @@ def transcribe_audio_wav2vec(
     return " ".join(final_transcript).strip(), None
 
 
+def transcribe_with_time(
+    self, data, num_proc: int = 1, strict: bool = False
+):
+    if strict:
+        assert (
+            self.context.is_initialized
+        ), "strict=True and context is not initialized. Make sure to call 'context.init_state()' before."
+    else:
+        if not self.context.is_initialized and not self._context_initialized:
+            self.context.init_state()
+            self._context_initialized = True
+
+    self.context.full_parallel(self.params, data, num_proc)
+    return [
+        {
+            "start": self.context.full_get_segment_start(i) / 100,
+            "end": self.context.full_get_segment_end(i) / 100,
+            "word": self.context.full_get_segment_text(i),
+        }
+        for i in range(self.context.full_n_segments())
+    ]
+
+
+def load_whispercpp_model(model_name_or_path="small.en"):
+    from whispercpp import Whisper
+
+    Whisper.transcribe_with_time = transcribe_with_time
+    model = Whisper.from_pretrained(model_name_or_path)
+    model.params.with_token_timestamps(True)
+    model.params.with_print_progress(True)
+    return model
+
+
+def transcribe_audio_whispercpp(audio_path, model=None):
+    if model is None:
+        model = load_whispercpp_model("small.en")
+    elif isinstance(model, str):
+        model = load_whispercpp_model(model)
+
+    try:
+        y, _ = (
+            ffmpeg.input(str(audio_path), threads=0)
+            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
+            .run(
+                cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True
+            )
+        )
+    except ffmpeg.Error as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+    arr = np.frombuffer(y, np.int16).flatten().astype(np.float32) / 32768.0
+    results = model.transcribe_with_time(arr)
+    return "".join([x["word"] for x in results]), json.dumps(results)
+
+
 def read_wave(path, desired_sample_rate=None, force=False):
     """Reads a ".wav" file and converts to ``desired_sample_rate`` with one channel.
@@ -664,6 +720,8 @@ def process_segments(
         model = load_vosk_model(model)
     elif model == "wav2vec":
         model = load_wav2vec_model()
+    elif model == "whispercpp":
+        model = load_whispercpp_model()
 
     create_json = True
     full_transcript = ""
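For reference, a minimal usage sketch of the new backend defined above. It assumes ffmpeg is available on PATH, that whispercpp's Whisper.from_pretrained can fetch the "small.en" weights, and that "lecture.wav" is a placeholder path:

    from lecture2notes.end_to_end.transcribe import transcribe_main as transcribe

    # load_whispercpp_model monkey-patches Whisper with transcribe_with_time
    # and enables token timestamps, as in the diff above.
    model = transcribe.load_whispercpp_model("small.en")

    # Returns the joined transcript plus a JSON string of
    # {"start", "end", "word"} segments. whisper.cpp reports segment
    # times in 10 ms units, hence the division by 100 to get seconds;
    # despite the key name, each "word" holds a whole segment's text.
    text, segments_json = transcribe.transcribe_audio_whispercpp("lecture.wav", model=model)
    print(text)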
