Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH/API Add bark context params #106

Merged
merged 1 commit into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
initial commit
  • Loading branch information
PABannier committed Sep 9, 2023
commit 3a81f7a898531a7377bff279075339e889c9c312
76 changes: 47 additions & 29 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ struct bark_context {
bark_codes fine_tokens;

audio_arr_t audio_arr;

float temp;
float fine_temp;

float min_eos_p;
int sliding_window_size;
int max_coarse_history;
};

struct bark_progress {
Expand All @@ -172,18 +179,41 @@ struct bark_progress {
}
};

struct bark_context * bark_new_context_with_model(struct bark_model * model) {
struct bark_context * bark_new_context_with_model(
struct bark_model * model,
struct bark_context_params params) {

if (!model) {
return nullptr;
}

bark_context * ctx = new bark_context(*model);

ctx->rng = std::mt19937(0);
ctx->rng = std::mt19937(params.seed);

ctx->temp = params.temp;
ctx->fine_temp = params.fine_temp;

ctx->max_coarse_history = params.max_coarse_history;
ctx->sliding_window_size = params.sliding_window_size;
ctx->min_eos_p = params.min_eos_p;

return ctx;
}

struct bark_context_params bark_context_default_params() {
struct bark_context_params result = {
/*.seed =*/ 0,
/*.temp =*/ 0.7,
/*.fine_temp =*/ 0.5,
/*.min_eos_p =*/ 0.2,
/*.sliding_window_size =*/ 60,
/*.max_coarse_history =*/ 630,
};

return result;
}

void bark_seed_rng(struct bark_context * ctx, int32_t seed) {
if (ctx) {
ctx->rng.seed(seed);
Expand Down Expand Up @@ -1672,11 +1702,7 @@ static void bark_print_statistics(gpt_model * model) {
printf("\n");
}

void bark_forward_text_encoder(
struct bark_context * ctx,
float temp,
float min_eos_p,
int n_threads) {
void bark_forward_text_encoder(struct bark_context * ctx, int n_threads) {
const int64_t t_main_start_us = ggml_time_us();

bark_sequence out;
Expand All @@ -1689,6 +1715,9 @@ void bark_forward_text_encoder(
auto & hparams = model->hparams;
const int n_vocab = hparams.n_out_vocab;

float min_eos_p = ctx->min_eos_p;
float temp = ctx->temp;

bark_sequence input = ctx->tokens;

std::vector<float> logits;
Expand Down Expand Up @@ -1733,12 +1762,7 @@ void bark_forward_text_encoder(
bark_print_statistics(model);
}

void bark_forward_coarse_encoder(
struct bark_context * ctx,
int max_coarse_history,
int sliding_window_size,
float temp,
int n_threads) {
void bark_forward_coarse_encoder(struct bark_context * ctx, int n_threads) {
const int64_t t_main_start_us = ggml_time_us();

bark_codes out_coarse;
Expand All @@ -1747,6 +1771,10 @@ void bark_forward_coarse_encoder(
bark_progress progress;
progress.func = __func__;

int max_coarse_history = ctx->max_coarse_history;
int sliding_window_size = ctx->sliding_window_size;
float temp = ctx->temp;

float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesnt the converstion to int already floor it? (seems to be always positive)


Expand Down Expand Up @@ -1851,12 +1879,14 @@ void bark_forward_coarse_encoder(

}

void bark_forward_fine_encoder(struct bark_context * ctx,float temp, int n_threads) {
void bark_forward_fine_encoder(struct bark_context * ctx, int n_threads) {
// input shape: (N, n_codes)
const int64_t t_main_start_us = ggml_time_us();

bark_codes input = ctx->coarse_tokens;

float temp = ctx->fine_temp;

std::vector<float> logits;
logits.resize(1024*1056);

Expand Down Expand Up @@ -2063,26 +2093,14 @@ int bark_generate_audio(
const char * text,
const char * dest_wav_path,
int n_threads) {
const float temp = 0.7;
const float fine_temp = 0.5;

const int sliding_window_size = 60;
const int max_coarse_history = 630;

const float min_eos_p = 0.2;

// tokenize input (bert tokenizer)
bark_tokenize_input(ctx, text);

// forward pass
bark_forward_text_encoder(ctx, temp, min_eos_p, n_threads);
bark_forward_coarse_encoder(ctx, max_coarse_history, sliding_window_size, temp, n_threads);
bark_forward_fine_encoder(ctx, fine_temp, n_threads);
bark_forward_text_encoder (ctx, n_threads);
bark_forward_coarse_encoder(ctx, n_threads);
bark_forward_fine_encoder (ctx, n_threads);

// encode audio
bark_forward_encodec(ctx);

// write wav file
write_wav_on_disk(ctx->audio_arr, dest_wav_path);

return 0;
Expand Down
27 changes: 18 additions & 9 deletions bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,34 @@ extern "C" {
// C interface
//

typedef int32_t bark_token;

struct bark_context;
struct bark_progress;

struct bark_context_params {
uint32_t seed; // RNG seed

float temp; // Temperature for sampling (text and coarse encoders)
float fine_temp; // Temperature for sampling (fine encoder)

float min_eos_p; // Minimum probability for EOS token (text encoder)
int sliding_window_size; // Sliding window size for coarse encoder
int max_coarse_history; // Max history for coarse encoder
};

struct bark_model;
struct bark_vocab;

typedef int32_t bark_token;

struct gpt_hparams;
struct gpt_layer;
struct gpt_model;

BARK_API struct bark_context * bark_new_context_with_model(struct bark_model * model);
BARK_API struct bark_context_params bark_context_default_params(void);

BARK_API struct bark_context * bark_new_context_with_model(
struct bark_model * model,
struct bark_context_params params);

BARK_API void bark_seed_rng(struct bark_context * ctx, int32_t seed);

Expand Down Expand Up @@ -128,20 +143,14 @@ extern "C" {

void bark_forward_text_encoder(
struct bark_context * ctx,
float temp,
float min_eos_p,
int n_threads);

void bark_forward_coarse_encoder(
struct bark_context * ctx,
int max_coarse_history,
int sliding_window_size,
float temp,
int n_threads);

void bark_forward_fine_encoder(
struct bark_context * ctx,
float temp,
int n_threads);

void bark_forward_encodec(struct bark_context * ctx);
Expand Down
3 changes: 2 additions & 1 deletion examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ std::tuple<struct bark_model *, struct bark_context *> bark_init_from_params(bar
return std::make_tuple(nullptr, nullptr);
}

bark_context * bctx = bark_new_context_with_model(model);
bark_context_params bctx_params = bark_context_default_params();
bark_context * bctx = bark_new_context_with_model(model, bctx_params);
if (bctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model_path.c_str());
bark_free_model(model);
Expand Down