Skip to content

Commit

Permalink
ENH/API Add bark context params (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Sep 9, 2023
1 parent 5e2826d commit dbe6aac
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 39 deletions.
76 changes: 47 additions & 29 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ struct bark_context {
bark_codes fine_tokens;

audio_arr_t audio_arr;

float temp;
float fine_temp;

float min_eos_p;
int sliding_window_size;
int max_coarse_history;
};

struct bark_progress {
Expand All @@ -172,18 +179,41 @@ struct bark_progress {
}
};

// Creates a new bark_context for `model` and copies the settings carried by
// `params` (RNG seed, temperatures, EOS threshold, coarse-encoder window and
// history sizes) into it.
//
// Returns nullptr when `model` is null; otherwise returns a context allocated
// with `new`.
//
// NOTE(review): this span comes from a rendered diff. The one-argument
// signature directly below and the `std::mt19937(0)` assignment further down
// appear to be the removed (pre-change) lines, kept here next to their
// replacements -- confirm against the repository before compiling.
struct bark_context * bark_new_context_with_model(struct bark_model * model) {
struct bark_context * bark_new_context_with_model(
struct bark_model * model,
struct bark_context_params params) {

// Nothing can be built without a model.
if (!model) {
return nullptr;
}

bark_context * ctx = new bark_context(*model);

// Presumably the previous fixed seed; immediately superseded by the line after it.
ctx->rng = std::mt19937(0);
// Seed the context RNG from the caller-supplied parameters.
ctx->rng = std::mt19937(params.seed);

// Sampling temperatures: `temp` is read by the text and coarse encoders,
// `fine_temp` by the fine encoder.
ctx->temp = params.temp;
ctx->fine_temp = params.fine_temp;

// Coarse-encoder history/window sizes and the text-encoder EOS threshold.
ctx->max_coarse_history = params.max_coarse_history;
ctx->sliding_window_size = params.sliding_window_size;
ctx->min_eos_p = params.min_eos_p;

return ctx;
}

struct bark_context_params bark_context_default_params() {
struct bark_context_params result = {
/*.seed =*/ 0,
/*.temp =*/ 0.7,
/*.fine_temp =*/ 0.5,
/*.min_eos_p =*/ 0.2,
/*.sliding_window_size =*/ 60,
/*.max_coarse_history =*/ 630,
};

return result;
}

void bark_seed_rng(struct bark_context * ctx, int32_t seed) {
if (ctx) {
ctx->rng.seed(seed);
Expand Down Expand Up @@ -1672,11 +1702,7 @@ static void bark_print_statistics(gpt_model * model) {
printf("\n");
}

void bark_forward_text_encoder(
struct bark_context * ctx,
float temp,
float min_eos_p,
int n_threads) {
void bark_forward_text_encoder(struct bark_context * ctx, int n_threads) {
const int64_t t_main_start_us = ggml_time_us();

bark_sequence out;
Expand All @@ -1689,6 +1715,9 @@ void bark_forward_text_encoder(
auto & hparams = model->hparams;
const int n_vocab = hparams.n_out_vocab;

float min_eos_p = ctx->min_eos_p;
float temp = ctx->temp;

bark_sequence input = ctx->tokens;

std::vector<float> logits;
Expand Down Expand Up @@ -1733,12 +1762,7 @@ void bark_forward_text_encoder(
bark_print_statistics(model);
}

void bark_forward_coarse_encoder(
struct bark_context * ctx,
int max_coarse_history,
int sliding_window_size,
float temp,
int n_threads) {
void bark_forward_coarse_encoder(struct bark_context * ctx, int n_threads) {
const int64_t t_main_start_us = ggml_time_us();

bark_codes out_coarse;
Expand All @@ -1747,6 +1771,10 @@ void bark_forward_coarse_encoder(
bark_progress progress;
progress.func = __func__;

int max_coarse_history = ctx->max_coarse_history;
int sliding_window_size = ctx->sliding_window_size;
float temp = ctx->temp;

float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);

Expand Down Expand Up @@ -1851,12 +1879,14 @@ void bark_forward_coarse_encoder(

}

void bark_forward_fine_encoder(struct bark_context * ctx,float temp, int n_threads) {
void bark_forward_fine_encoder(struct bark_context * ctx, int n_threads) {
// input shape: (N, n_codes)
const int64_t t_main_start_us = ggml_time_us();

bark_codes input = ctx->coarse_tokens;

float temp = ctx->fine_temp;

std::vector<float> logits;
logits.resize(1024*1056);

Expand Down Expand Up @@ -2063,26 +2093,14 @@ int bark_generate_audio(
const char * text,
const char * dest_wav_path,
int n_threads) {
const float temp = 0.7;
const float fine_temp = 0.5;

const int sliding_window_size = 60;
const int max_coarse_history = 630;

const float min_eos_p = 0.2;

// tokenize input (bert tokenizer)
bark_tokenize_input(ctx, text);

// forward pass
bark_forward_text_encoder(ctx, temp, min_eos_p, n_threads);
bark_forward_coarse_encoder(ctx, max_coarse_history, sliding_window_size, temp, n_threads);
bark_forward_fine_encoder(ctx, fine_temp, n_threads);
bark_forward_text_encoder (ctx, n_threads);
bark_forward_coarse_encoder(ctx, n_threads);
bark_forward_fine_encoder (ctx, n_threads);

// encode audio
bark_forward_encodec(ctx);

// write wav file
write_wav_on_disk(ctx->audio_arr, dest_wav_path);

return 0;
Expand Down
27 changes: 18 additions & 9 deletions bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,34 @@ extern "C" {
// C interface
//

typedef int32_t bark_token;

struct bark_context;
struct bark_progress;

// Parameters used when creating a bark_context.
struct bark_context_params {
    // Seed for the context's random number generator.
    uint32_t seed;

    // Sampling temperature used by the text and coarse encoders.
    float temp;
    // Sampling temperature used by the fine encoder.
    float fine_temp;

    // Minimum probability for the EOS token (text encoder).
    float min_eos_p;
    // Sliding window size for the coarse encoder.
    int sliding_window_size;
    // Maximum history for the coarse encoder.
    int max_coarse_history;
};

struct bark_model;
struct bark_vocab;

typedef int32_t bark_token;

struct gpt_hparams;
struct gpt_layer;
struct gpt_model;

BARK_API struct bark_context * bark_new_context_with_model(struct bark_model * model);
BARK_API struct bark_context_params bark_context_default_params(void);

BARK_API struct bark_context * bark_new_context_with_model(
struct bark_model * model,
struct bark_context_params params);

BARK_API void bark_seed_rng(struct bark_context * ctx, int32_t seed);

Expand Down Expand Up @@ -128,20 +143,14 @@ extern "C" {

void bark_forward_text_encoder(
struct bark_context * ctx,
float temp,
float min_eos_p,
int n_threads);

void bark_forward_coarse_encoder(
struct bark_context * ctx,
int max_coarse_history,
int sliding_window_size,
float temp,
int n_threads);

void bark_forward_fine_encoder(
struct bark_context * ctx,
float temp,
int n_threads);

void bark_forward_encodec(struct bark_context * ctx);
Expand Down
3 changes: 2 additions & 1 deletion examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ std::tuple<struct bark_model *, struct bark_context *> bark_init_from_params(bar
return std::make_tuple(nullptr, nullptr);
}

bark_context * bctx = bark_new_context_with_model(model);
bark_context_params bctx_params = bark_context_default_params();
bark_context * bctx = bark_new_context_with_model(model, bctx_params);
if (bctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model_path.c_str());
bark_free_model(model);
Expand Down

0 comments on commit dbe6aac

Please sign in to comment.