From 3a81f7a898531a7377bff279075339e889c9c312 Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Bannier
Date: Sat, 9 Sep 2023 21:10:13 +0200
Subject: [PATCH] initial commit

---
Notes (ignored by `git am`; everything below is derived from the diff itself):

  * Adds `struct bark_context_params` (seed, temp, fine_temp, min_eos_p,
    sliding_window_size, max_coarse_history) and
    `bark_context_default_params()`, which returns seed 0 together with the
    values that `bark_generate_audio()` used to hard-code
    (0.7 / 0.5 / 0.2 / 60 / 630).
  * `bark_new_context_with_model()` now takes the params by value, seeds the
    RNG from `params.seed` instead of the constant 0, and copies the other
    fields into new members of `struct bark_context`.
  * `bark_forward_text_encoder()`, `bark_forward_coarse_encoder()` and
    `bark_forward_fine_encoder()` drop their sampling arguments and read them
    from the context instead. The text and coarse encoders use `ctx->temp`;
    the fine encoder uses `ctx->fine_temp`.
  * These are signature changes: every caller has to be updated, as
    examples/main.cpp is here.

  NOTE(review): this patch was recovered from a copy whose line breaks and
  indentation had been collapsed. Line structure was rebuilt so that every
  hunk matches its `@@` counts and the diffstat; indentation and column
  alignment are reconstructed, not original. The context lines
  `std::vector logits;`, the main.cpp hunk-header text
  `std::tuple bark_init_from_params(bar` and the `From:` line appear to have
  lost `<...>` spans in the same way -- confirm against the repository before
  applying (or apply with whitespace-tolerant options and fix up by hand).

 bark.cpp          | 76 +++++++++++++++++++++++++++++------------------
 bark.h            | 27 +++++++++++------
 examples/main.cpp |  3 +-
 3 files changed, 67 insertions(+), 39 deletions(-)

diff --git a/bark.cpp b/bark.cpp
index a4aadf9..03053ef 100644
--- a/bark.cpp
+++ b/bark.cpp
@@ -148,6 +148,13 @@ struct bark_context {
     bark_codes fine_tokens;
 
     audio_arr_t audio_arr;
+
+    float temp;
+    float fine_temp;
+
+    float min_eos_p;
+    int sliding_window_size;
+    int max_coarse_history;
 };
 
 struct bark_progress {
@@ -172,18 +179,41 @@ struct bark_progress {
     }
 };
 
-struct bark_context * bark_new_context_with_model(struct bark_model * model) {
+struct bark_context * bark_new_context_with_model(
+        struct bark_model * model,
+        struct bark_context_params params) {
+
     if (!model) {
         return nullptr;
     }
 
     bark_context * ctx = new bark_context(*model);
 
-    ctx->rng = std::mt19937(0);
+    ctx->rng = std::mt19937(params.seed);
+
+    ctx->temp = params.temp;
+    ctx->fine_temp = params.fine_temp;
+
+    ctx->max_coarse_history = params.max_coarse_history;
+    ctx->sliding_window_size = params.sliding_window_size;
+    ctx->min_eos_p = params.min_eos_p;
 
     return ctx;
 }
 
+struct bark_context_params bark_context_default_params() {
+    struct bark_context_params result = {
+        /*.seed =*/ 0,
+        /*.temp =*/ 0.7,
+        /*.fine_temp =*/ 0.5,
+        /*.min_eos_p =*/ 0.2,
+        /*.sliding_window_size =*/ 60,
+        /*.max_coarse_history =*/ 630,
+    };
+
+    return result;
+}
+
 void bark_seed_rng(struct bark_context * ctx, int32_t seed) {
     if (ctx) {
         ctx->rng.seed(seed);
@@ -1672,11 +1702,7 @@ static void bark_print_statistics(gpt_model * model) {
     printf("\n");
 }
 
-void bark_forward_text_encoder(
-        struct bark_context * ctx,
-        float temp,
-        float min_eos_p,
-        int n_threads) {
+void bark_forward_text_encoder(struct bark_context * ctx, int n_threads) {
     const int64_t t_main_start_us = ggml_time_us();
 
     bark_sequence out;
@@ -1689,6 +1715,9 @@ void bark_forward_text_encoder(
     auto & hparams = model->hparams;
     const int n_vocab = hparams.n_out_vocab;
 
+    float min_eos_p = ctx->min_eos_p;
+    float temp = ctx->temp;
+
     bark_sequence input = ctx->tokens;
 
     std::vector logits;
@@ -1733,12 +1762,7 @@ void bark_forward_text_encoder(
     bark_print_statistics(model);
 }
 
-void bark_forward_coarse_encoder(
-        struct bark_context * ctx,
-        int max_coarse_history,
-        int sliding_window_size,
-        float temp,
-        int n_threads) {
+void bark_forward_coarse_encoder(struct bark_context * ctx, int n_threads) {
     const int64_t t_main_start_us = ggml_time_us();
 
     bark_codes out_coarse;
@@ -1747,6 +1771,10 @@ void bark_forward_coarse_encoder(
     bark_progress progress;
     progress.func = __func__;
 
+    int max_coarse_history = ctx->max_coarse_history;
+    int sliding_window_size = ctx->sliding_window_size;
+    float temp = ctx->temp;
+
     float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
     int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);
 
@@ -1851,12 +1879,14 @@ void bark_forward_coarse_encoder(
 }
 
 
-void bark_forward_fine_encoder(struct bark_context * ctx,float temp, int n_threads) {
+void bark_forward_fine_encoder(struct bark_context * ctx, int n_threads) {
     // input shape: (N, n_codes)
     const int64_t t_main_start_us = ggml_time_us();
 
     bark_codes input = ctx->coarse_tokens;
 
+    float temp = ctx->fine_temp;
+
     std::vector logits;
     logits.resize(1024*1056);
 
@@ -2063,26 +2093,14 @@ int bark_generate_audio(
         const char * text,
         const char * dest_wav_path,
         int n_threads) {
-    const float temp = 0.7;
-    const float fine_temp = 0.5;
-
-    const int sliding_window_size = 60;
-    const int max_coarse_history = 630;
-
-    const float min_eos_p = 0.2;
-
-    // tokenize input (bert tokenizer)
     bark_tokenize_input(ctx, text);
 
-    // forward pass
-    bark_forward_text_encoder(ctx, temp, min_eos_p, n_threads);
-    bark_forward_coarse_encoder(ctx, max_coarse_history, sliding_window_size, temp, n_threads);
-    bark_forward_fine_encoder(ctx, fine_temp, n_threads);
+    bark_forward_text_encoder  (ctx, n_threads);
+    bark_forward_coarse_encoder(ctx, n_threads);
+    bark_forward_fine_encoder  (ctx, n_threads);
 
-    // encode audio
     bark_forward_encodec(ctx);
 
-    // write wav file
     write_wav_on_disk(ctx->audio_arr, dest_wav_path);
 
     return 0;
diff --git a/bark.h b/bark.h
index 71e2af8..33001fa 100644
--- a/bark.h
+++ b/bark.h
@@ -52,19 +52,34 @@ extern "C" {
     // C interface
     //
 
+    typedef int32_t bark_token;
+
     struct bark_context;
     struct bark_progress;
 
+    struct bark_context_params {
+        uint32_t seed;            // RNG seed
+
+        float temp;               // Temperature for sampling (text and coarse encoders)
+        float fine_temp;          // Temperature for sampling (fine encoder)
+
+        float min_eos_p;          // Minimum probability for EOS token (text encoder)
+        int sliding_window_size;  // Sliding window size for coarse encoder
+        int max_coarse_history;   // Max history for coarse encoder
+    };
+
     struct bark_model;
     struct bark_vocab;
 
-    typedef int32_t bark_token;
-
     struct gpt_hparams;
     struct gpt_layer;
     struct gpt_model;
 
-    BARK_API struct bark_context * bark_new_context_with_model(struct bark_model * model);
+    BARK_API struct bark_context_params bark_context_default_params(void);
+
+    BARK_API struct bark_context * bark_new_context_with_model(
+            struct bark_model * model,
+            struct bark_context_params params);
 
     BARK_API void bark_seed_rng(struct bark_context * ctx, int32_t seed);
 
@@ -128,20 +143,14 @@ extern "C" {
 
     void bark_forward_text_encoder(
             struct bark_context * ctx,
-            float temp,
-            float min_eos_p,
             int n_threads);
 
     void bark_forward_coarse_encoder(
             struct bark_context * ctx,
-            int max_coarse_history,
-            int sliding_window_size,
-            float temp,
             int n_threads);
 
     void bark_forward_fine_encoder(
             struct bark_context * ctx,
-            float temp,
             int n_threads);
 
     void bark_forward_encodec(struct bark_context * ctx);
diff --git a/examples/main.cpp b/examples/main.cpp
index 1f5c60f..7a1dab9 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -66,7 +66,8 @@ std::tuple bark_init_from_params(bar
         return std::make_tuple(nullptr, nullptr);
     }
 
-    bark_context * bctx = bark_new_context_with_model(model);
+    bark_context_params bctx_params = bark_context_default_params();
+    bark_context * bctx = bark_new_context_with_model(model, bctx_params);
     if (bctx == NULL) {
         fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model_path.c_str());
         bark_free_model(model);