Skip to content

Commit

Permalink
Merge pull request #17 from occ-ai/roy.adaptive_delay
Browse files Browse the repository at this point in the history
Fix the adaptive delay
  • Loading branch information
royshil committed May 16, 2024
2 parents 12dfa80 + a1a2a20 commit a8d173c
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 61 deletions.
2 changes: 1 addition & 1 deletion buildspec.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
}
},
"name": "obs-cleanstream",
"version": "0.0.6",
"version": "0.0.7",
"author": "Roy Shilkrot",
"website": "https://github.com/occ-ai/obs-cleanstream/",
"email": "[email protected]",
Expand Down
9 changes: 3 additions & 6 deletions src/cleanstream-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,10 @@ struct cleanstream_data {
uint32_t sample_rate; // input sample rate
// How many input frames (in input sample rate) are needed for the next whisper frame
size_t frames;
// How many ms/frames are needed to overlap with the next whisper frame
size_t overlap_frames;
size_t overlap_ms;
// How many frames were processed in the last whisper frame (this is dynamic)
size_t last_num_frames;
int current_result;
uint64_t current_result_end_timestamp;
uint64_t current_result_start_timestamp;
uint32_t delay_ms;

/* Silero VAD */
std::unique_ptr<VadIterator> vad;
Expand Down Expand Up @@ -76,7 +74,6 @@ struct cleanstream_data {
size_t audioFilePointer = 0;

float filler_p_threshold;
bool do_silence;
bool vad_enabled;
int log_level;
const char *detect_regex;
Expand Down
113 changes: 69 additions & 44 deletions src/cleanstream-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
#define BUFFER_SIZE_MSEC 1010
// at 16Khz, 1010 msec is 16160 frames
#define WHISPER_FRAME_SIZE 16160
// overlap in msec
#define OVERLAP_SIZE_MSEC 340
// initial delay in msec
#define INITIAL_DELAY_MSEC 500

#define VAD_THOLD 0.0001f
#define FREQ_THOLD 100.0f
Expand All @@ -58,44 +58,56 @@ struct obs_audio_data *cleanstream_filter_audio(void *data, struct obs_audio_dat
return audio;
}

std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock
size_t input_buffer_size = 0;
{
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock

if (audio != nullptr && audio->frames > 0) {
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
audio->frames * sizeof(float));
if (audio != nullptr && audio->frames > 0) {
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
audio->frames * sizeof(float));
}
// push audio packet info (timestamp/frame count) to info circlebuf
struct cleanstream_audio_info info = {0};
info.frames = audio->frames; // number of frames in this packet
info.timestamp = audio->timestamp; // timestamp of this packet
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
}
// push audio packet info (timestamp/frame count) to info circlebuf
struct cleanstream_audio_info info = {0};
info.frames = audio->frames; // number of frames in this packet
info.timestamp = audio->timestamp; // timestamp of this packet
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
input_buffer_size = gf->input_buffers[0].size;
}

// check the size of the input buffer - if it's more than 1500ms worth of audio, start playback
if (gf->input_buffers[0].size > 1500 * gf->sample_rate * sizeof(float) / 1000) {
// check the size of the input buffer - if it's more than <delay>ms worth of audio, start playback
if (input_buffer_size > gf->delay_ms * gf->sample_rate * sizeof(float) / 1000) {
// find needed number of frames from the incoming audio
size_t num_frames_needed = audio->frames;

std::vector<float> temporary_buffers[MAX_AUDIO_CHANNELS];
uint64_t timestamp = 0;

while (temporary_buffers[0].size() < num_frames_needed) {
struct cleanstream_audio_info info_out = {0};
{
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
// pop from input buffers to get audio packet info
circlebuf_pop_front(&gf->info_buffer, &info_out, sizeof(info_out));

// pop from input circlebuf to audio data
for (size_t i = 0; i < gf->channels; i++) {
// increase the size of the temporary buffer to hold the incoming audio in addition
// to the existing audio on the temporary buffer
temporary_buffers[i].resize(temporary_buffers[i].size() +
info_out.frames);
circlebuf_pop_front(&gf->input_buffers[i],
temporary_buffers[i].data() +
temporary_buffers[i].size() -
info_out.frames,
info_out.frames * sizeof(float));
while (temporary_buffers[0].size() < num_frames_needed) {
struct cleanstream_audio_info info_out = {0};
// pop from input buffers to get audio packet info
circlebuf_pop_front(&gf->info_buffer, &info_out, sizeof(info_out));
if (timestamp == 0) {
timestamp = info_out.timestamp;
}

// pop from input circlebuf to audio data
for (size_t i = 0; i < gf->channels; i++) {
// increase the size of the temporary buffer to hold the incoming audio in addition
// to the existing audio on the temporary buffer
temporary_buffers[i].resize(temporary_buffers[i].size() +
info_out.frames);
circlebuf_pop_front(&gf->input_buffers[i],
temporary_buffers[i].data() +
temporary_buffers[i].size() -
info_out.frames,
info_out.frames * sizeof(float));
}
}
}
const size_t num_frames = temporary_buffers[0].size();
Expand All @@ -105,7 +117,18 @@ struct obs_audio_data *cleanstream_filter_audio(void *data, struct obs_audio_dat
da_resize(gf->output_data, frames_size_bytes * gf->channels);
memset(gf->output_data.array, 0, frames_size_bytes * gf->channels);

if (gf->current_result == DetectionResult::DETECTION_RESULT_BEEP) {
int inference_result = DetectionResult::DETECTION_RESULT_UNKNOWN;
uint64_t inference_result_start_timestamp = 0;
uint64_t inference_result_end_timestamp = 0;
{
std::lock_guard<std::mutex> outbuf_lock(gf->whisper_outbuf_mutex);
inference_result = gf->current_result;
inference_result_start_timestamp = gf->current_result_start_timestamp;
inference_result_end_timestamp = gf->current_result_end_timestamp;
}

if (timestamp > inference_result_start_timestamp &&
timestamp < inference_result_end_timestamp) {
if (gf->replace_sound == REPLACE_SOUNDS_SILENCE) {
// set the audio to 0
for (size_t i = 0; i < gf->channels; i++) {
Expand Down Expand Up @@ -207,9 +230,12 @@ void cleanstream_update(void *data, obs_data_t *s)
gf->replace_sound = obs_data_get_int(s, "replace_sound");
gf->filler_p_threshold = (float)obs_data_get_double(s, "filler_p_threshold");
gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->do_silence = obs_data_get_bool(s, "do_silence");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
gf->delay_ms = BUFFER_SIZE_MSEC + INITIAL_DELAY_MSEC;
gf->current_result = DetectionResult::DETECTION_RESULT_UNKNOWN;
gf->current_result_start_timestamp = 0;
gf->current_result_end_timestamp = 0;

obs_log(gf->log_level, "update whisper model");
update_whisper_model(gf, s);
Expand Down Expand Up @@ -260,7 +286,10 @@ void *cleanstream_create(obs_data_t *settings, obs_source_t *filter)
gf->channels = audio_output_get_channels(obs_get_audio());
gf->sample_rate = audio_output_get_sample_rate(obs_get_audio());
gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)BUFFER_SIZE_MSEC));
gf->last_num_frames = 0;
gf->delay_ms = BUFFER_SIZE_MSEC + INITIAL_DELAY_MSEC;
gf->current_result = DetectionResult::DETECTION_RESULT_UNKNOWN;
gf->current_result_start_timestamp = 0;
gf->current_result_end_timestamp = 0;

for (size_t i = 0; i < MAX_AUDIO_CHANNELS; i++) {
circlebuf_init(&gf->input_buffers[i]);
Expand All @@ -283,10 +312,8 @@ void *cleanstream_create(obs_data_t *settings, obs_source_t *filter)
gf->whisper_model_path = std::string(""); // The update function will set the model path
gf->whisper_context = nullptr;

gf->overlap_ms = OVERLAP_SIZE_MSEC;
gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
obs_log(LOG_INFO, "CleanStream filter: channels %d, frames %d, sample_rate %d",
(int)gf->channels, (int)gf->frames, gf->sample_rate);
obs_log(LOG_INFO, "CleanStream filter: channels %d, sample_rate %d", (int)gf->channels,
gf->sample_rate);

struct resample_info src, dst;
src.samples_per_sec = gf->sample_rate;
Expand Down Expand Up @@ -356,7 +383,6 @@ void cleanstream_defaults(obs_data_t *s)
obs_data_set_default_int(s, "replace_sound", REPLACE_SOUNDS_SILENCE);
obs_data_set_default_bool(s, "advanced_settings", false);
obs_data_set_default_double(s, "filler_p_threshold", 0.75);
obs_data_set_default_bool(s, "do_silence", true);
obs_data_set_default_bool(s, "vad_enabled", true);
obs_data_set_default_int(s, "log_level", LOG_DEBUG);
obs_data_set_default_bool(s, "log_words", false);
Expand All @@ -365,10 +391,10 @@ void cleanstream_defaults(obs_data_t *s)

// Whisper parameters
obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
obs_data_set_default_string(s, "initial_prompt", "uhm, Uh, um, Uhh, um. um... uh. uh... ");
obs_data_set_default_string(s, "initial_prompt", "");
obs_data_set_default_int(s, "n_threads", 4);
obs_data_set_default_int(s, "n_max_text_ctx", 16384);
obs_data_set_default_bool(s, "no_context", true);
obs_data_set_default_bool(s, "no_context", false);
obs_data_set_default_bool(s, "single_segment", true);
obs_data_set_default_bool(s, "print_special", false);
obs_data_set_default_bool(s, "print_progress", false);
Expand All @@ -379,7 +405,7 @@ void cleanstream_defaults(obs_data_t *s)
obs_data_set_default_double(s, "thold_ptsum", 0.01);
obs_data_set_default_int(s, "max_len", 0);
obs_data_set_default_bool(s, "split_on_word", false);
obs_data_set_default_int(s, "max_tokens", 3);
obs_data_set_default_int(s, "max_tokens", 7);
obs_data_set_default_bool(s, "speed_up", false);
obs_data_set_default_bool(s, "suppress_blank", true);
obs_data_set_default_bool(s, "suppress_non_speech_tokens", true);
Expand Down Expand Up @@ -479,8 +505,8 @@ obs_properties_t *cleanstream_properties(void *data)
// If advanced settings is enabled, show the advanced settings group
const bool show_hide = obs_data_get_bool(settings, "advanced_settings");
for (const std::string &prop_name :
{"whisper_params_group", "log_words", "filler_p_threshold", "do_silence",
"vad_enabled", "log_level"}) {
{"whisper_params_group", "log_words", "filler_p_threshold", "vad_enabled",
"log_level"}) {
obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
show_hide);
}
Expand All @@ -489,7 +515,6 @@ obs_properties_t *cleanstream_properties(void *data)

obs_properties_add_float_slider(ppts, "filler_p_threshold", MT_("filler_p_threshold"), 0.0f,
1.0f, 0.05f);
obs_properties_add_bool(ppts, "do_silence", MT_("do_silence"));
obs_properties_add_bool(ppts, "vad_enabled", MT_("vad_enabled"));
obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
Expand Down
38 changes: 28 additions & 10 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,26 +305,31 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data,
long long process_audio_from_buffer(struct cleanstream_data *gf)
{
uint64_t start_timestamp = 0;
uint64_t end_timestamp = 0;

{
// scoped lock the buffer mutex
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);

// copy gf->frames from the end of the input buffer to the copy_buffers
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_peek_front(&gf->input_buffers[c], gf->copy_buffers[c],
gf->frames * sizeof(float));
circlebuf_peek_back(&gf->input_buffers[c], gf->copy_buffers[c],
gf->frames * sizeof(float));
}

// peek at the info_buffer to get the timestamp of the first info
// peek at the info_buffer to get the timestamp of the last info
struct cleanstream_audio_info info_from_buf = {0};
circlebuf_peek_front(&gf->info_buffer, &info_from_buf,
sizeof(struct cleanstream_audio_info));
start_timestamp = info_from_buf.timestamp;
circlebuf_peek_back(&gf->info_buffer, &info_from_buf,
sizeof(struct cleanstream_audio_info));
end_timestamp = info_from_buf.timestamp;
start_timestamp =
end_timestamp - (int)(gf->frames * 1000 / gf->sample_rate) * 1000000;
}

obs_log(gf->log_level, "processing %lu frames (%d ms), start timestamp %llu ", gf->frames,
(int)(gf->frames * 1000 / gf->sample_rate), start_timestamp);
obs_log(gf->log_level,
"processing %lu frames (%d ms), start timestamp %llu, end timestamp %llu ",
gf->frames, (int)(gf->frames * 1000 / gf->sample_rate), start_timestamp,
end_timestamp);

// time the audio processing
auto start = std::chrono::high_resolution_clock::now();
Expand All @@ -349,8 +354,7 @@ long long process_audio_from_buffer(struct cleanstream_data *gf)

std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
if (stamps.size() == 0) {
obs_log(gf->log_level, "VAD detected no speech in %d frames",
whisper_buffer_16khz);
obs_log(gf->log_level, "VAD detected no speech");
skipped_inference = true;
}
}
Expand All @@ -362,8 +366,13 @@ long long process_audio_from_buffer(struct cleanstream_data *gf)
{
std::lock_guard<std::mutex> lock(gf->whisper_outbuf_mutex);
gf->current_result = inference_result;
if (gf->current_result == DETECTION_RESULT_BEEP) {
gf->current_result_start_timestamp = start_timestamp;
gf->current_result_end_timestamp = end_timestamp;
}
}
} else {
gf->current_result = DETECTION_RESULT_SILENCE;
if (gf->log_words) {
obs_log(LOG_INFO, "skipping inference");
}
Expand All @@ -377,6 +386,15 @@ long long process_audio_from_buffer(struct cleanstream_data *gf)
obs_log(gf->log_level, "audio processing of %u ms new data took %d ms", audio_processed_ms,
(int)duration);

if (duration > (gf->delay_ms - audio_processed_ms)) {
obs_log(gf->log_level,
"audio processing (%d ms) longer than delay (%lu ms), increase delay",
(int)duration, gf->delay_ms);
gf->delay_ms += 100;
} else {
gf->delay_ms -= 100;
}

return duration;
}

Expand Down

0 comments on commit a8d173c

Please sign in to comment.