Skip to content

Commit

Permalink
enh: make bark.h a C header (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed May 8, 2024
1 parent 2ff7c57 commit 1c39c4c
Show file tree
Hide file tree
Showing 5 changed files with 403 additions and 284 deletions.
188 changes: 170 additions & 18 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <random>
#include <regex>
#include <string>
#include <thread>
#include <vector>

#include "bark.h"
#include "encodec.h"
Expand All @@ -30,6 +32,121 @@

static const size_t MB = 1024 * 1024;

// A single predicted token id.
using bark_token = int32_t;
// A flat stream of token ids (text / semantic token sequences).
using bark_sequence = std::vector<int32_t>;
// A 2-D grid of token ids, e.g. one row per codebook (coarse / fine streams).
using bark_codes = std::vector<std::vector<int32_t>>;

// Vocabulary for the semantic (text) encoder: bidirectional mapping
// between token strings and their integer ids.
struct bark_vocab {
using id = int32_t;      // integer token identifier
using token = std::string; // textual form of a token

// forward and reverse lookup tables used during tokenization
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
};

struct gpt_layer {
// normalization
struct ggml_tensor *ln_1_g;
struct ggml_tensor *ln_1_b;

struct ggml_tensor *ln_2_g;
struct ggml_tensor *ln_2_b;

// attention
struct ggml_tensor *c_attn_attn_w;
struct ggml_tensor *c_attn_attn_b;

struct ggml_tensor *c_attn_proj_w;
struct ggml_tensor *c_attn_proj_b;

// mlp
struct ggml_tensor *c_mlp_fc_w;
struct ggml_tensor *c_mlp_fc_b;

struct ggml_tensor *c_mlp_proj_w;
struct ggml_tensor *c_mlp_proj_b;
};

// One full GPT model (weights, KV cache, backend buffers) plus the
// per-model timing/sampling counters reported by bark_print_statistics.
struct gpt_model {
gpt_hparams hparams;

// normalization (final layer norm)
struct ggml_tensor *ln_f_g;
struct ggml_tensor *ln_f_b;

struct ggml_tensor *wpe; // position embedding
std::vector<struct ggml_tensor *> wtes; // token embedding
std::vector<struct ggml_tensor *> lm_heads; // language model head

std::vector<gpt_layer> layers;

// key + value memory (KV cache for autoregressive decoding)
struct ggml_tensor *memory_k;
struct ggml_tensor *memory_v;

struct ggml_context *ctx;

ggml_backend_t backend = NULL;

ggml_backend_buffer_t buffer_w;
ggml_backend_buffer_t buffer_kv;

std::map<std::string, struct ggml_tensor *> tensors;

// wall-clock timings (microseconds) accumulated during generation
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
int64_t t_main_us = 0;

// number of tokens sampled so far
int64_t n_sample = 0;

// total bytes used by the model weights
int64_t memsize = 0;
};

// The three-stage Bark text-to-token pipeline: semantic -> coarse -> fine
// encoders, plus the vocabulary used by the semantic stage.
struct bark_model {
// The token encoders
struct gpt_model semantic_model;
struct gpt_model coarse_model;
struct gpt_model fine_model;

// The vocabulary for the semantic encoder
struct bark_vocab vocab;
};

// Opaque context handed to the public C API: owns the models, the
// intermediate token streams of the pipeline, the generated audio and the
// run statistics.
struct bark_context {
struct bark_model text_model;

// neural audio codec that turns fine tokens into a waveform
struct encodec_context *encodec_ctx;

// buffer for model evaluation
ggml_backend_buffer_t buf_compute;

// custom allocator
struct ggml_allocr *allocr = NULL;
int n_gpu_layers = 0;

// RNG used for token sampling (seeded in bark_load_model)
std::mt19937 rng;

// input text tokens and output of the semantic encoder
bark_sequence tokens;
bark_sequence semantic_tokens;

// outputs of the coarse and fine encoders (one row per codebook)
bark_codes coarse_tokens;
bark_codes fine_tokens;

// final decoded waveform samples
std::vector<float> audio_arr;

// hyperparameters
bark_context_params params;

// encodec parameters
std::string encodec_model_path;

// statistics (timings and sample counts, see bark_get_statistics)
bark_statistics stats;
};

class BarkProgressBar {
public:
BarkProgressBar(std::string func_name, double needed_progress) {
Expand Down Expand Up @@ -1070,22 +1187,23 @@ static bool bark_load_model_from_file(
return true;
}

struct bark_context* bark_load_model(const std::string& model_path, bark_verbosity_level verbosity, uint32_t seed) {
struct bark_context* bark_load_model(const char *model_path, bark_verbosity_level verbosity, uint32_t seed) {
int64_t t_load_start_us = ggml_time_us();

struct bark_context* bctx = new bark_context();

bctx->text_model = bark_model();
if (!bark_load_model_from_file(model_path, bctx, verbosity)) {
fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path.c_str());
std::string model_path_str(model_path);
if (!bark_load_model_from_file(model_path_str, bctx, verbosity)) {
fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path);
return nullptr;
}

bark_context_params params = bark_context_default_params();
params.verbosity = verbosity;
bctx->rng = std::mt19937(seed);
bctx->params = params;
bctx->t_load_us = ggml_time_us() - t_load_start_us;
bctx->stats.t_load_us = ggml_time_us() - t_load_start_us;

return bctx;
}
Expand Down Expand Up @@ -1629,6 +1747,7 @@ static bool bark_eval_text_encoder(struct bark_context* bctx, int n_threads) {
}

bctx->semantic_tokens = output;
bctx->stats.n_sample_semantic = model.n_sample;

return true;
}
Expand Down Expand Up @@ -1672,6 +1791,7 @@ bool bark_forward_text_encoder(struct bark_context* bctx, int n_threads) {
}

model.t_main_us = ggml_time_us() - t_main_start_us;
bctx->stats.t_semantic_us = model.t_main_us;

bark_print_statistics(&model);

Expand Down Expand Up @@ -1797,6 +1917,7 @@ static bool bark_eval_coarse_encoder(struct bark_context* bctx, int n_threads) {
}

bctx->coarse_tokens = out_coarse;
bctx->stats.n_sample_coarse = model.n_sample;

return true;
}
Expand Down Expand Up @@ -1840,6 +1961,7 @@ bool bark_forward_coarse_encoder(struct bark_context* bctx, int n_threads) {
}

model.t_main_us = ggml_time_us() - t_main_start_us;
bctx->stats.t_coarse_us = model.t_main_us;

bark_print_statistics(&model);

Expand Down Expand Up @@ -1989,6 +2111,7 @@ static bool bark_eval_fine_encoder(struct bark_context* bctx, int n_threads) {
assert(bctx->coarse_tokens.size() == in_arr.size());

bctx->fine_tokens = in_arr;
bctx->stats.n_sample_fine = model.n_sample;

return true;
}
Expand Down Expand Up @@ -2034,6 +2157,7 @@ bool bark_forward_fine_encoder(struct bark_context* bctx, int n_threads) {
}

model.t_main_us = ggml_time_us() - t_main_start_us;
bctx->stats.t_fine_us = model.t_main_us;

bark_print_statistics(&model);

Expand Down Expand Up @@ -2062,15 +2186,16 @@ static bool bark_forward_eval(struct bark_context* bctx, int n_threads) {
return true;
}

bool bark_generate_audio(struct bark_context* bctx, const std::string& text, int n_threads) {
bool bark_generate_audio(struct bark_context* bctx, const char * text, int n_threads) {
if (!bctx) {
fprintf(stderr, "%s: invalid bark context\n", __func__);
return false;
}

int64_t t_start_eval_us = ggml_time_us();

bark_tokenize_input(bctx, text);
std::string text_str(text);
bark_tokenize_input(bctx, text_str);

if (!bark_forward_eval(bctx, n_threads)) {
fprintf(stderr, "%s: failed to forward eval\n", __func__);
Expand Down Expand Up @@ -2101,8 +2226,7 @@ bool bark_generate_audio(struct bark_context* bctx, const std::string& text, int
}

bctx->audio_arr = bctx->encodec_ctx->out_audio;

bctx->t_eval_us = ggml_time_us() - t_start_eval_us;
bctx->stats.t_eval_us = ggml_time_us() - t_start_eval_us;

return true;
}
Expand Down Expand Up @@ -2230,21 +2354,21 @@ bool bark_model_weights_quantize(std::ifstream& fin, std::ofstream& fout, ggml_f
return true;
}

bool bark_model_quantize(
const std::string& fname_inp,
const std::string& fname_out,
ggml_ftype ftype) {
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
bool bark_model_quantize(const char * fname_inp, const char * fname_out, ggml_ftype ftype) {
printf("%s: loading model from '%s'\n", __func__, fname_inp);

auto fin = std::ifstream(fname_inp, std::ios::binary);
std::string fname_inp_str(fname_inp);
std::string fname_out_str(fname_out);

auto fin = std::ifstream(fname_inp_str, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp);
return false;
}

auto fout = std::ofstream(fname_out, std::ios::binary);
auto fout = std::ofstream(fname_out_str, std::ios::binary);
if (!fout) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out);
return false;
}

Expand All @@ -2253,7 +2377,7 @@ bool bark_model_quantize(
uint32_t magic;
fin.read((char*)&magic, sizeof(magic));
if (magic != GGML_FILE_MAGIC) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp);
return false;
}

Expand Down Expand Up @@ -2301,3 +2425,31 @@ bool bark_model_quantize(

return true;
}

// Return a pointer to the generated audio samples, or nullptr when no
// context exists or no audio has been generated yet. The pointer stays
// valid until the next generation call mutates the context.
float * bark_get_audio_data(struct bark_context *bctx) {
    const bool has_audio = bctx != nullptr && !bctx->audio_arr.empty();
    return has_audio ? bctx->audio_arr.data() : nullptr;
}

// Number of audio samples available from bark_get_audio_data, or 0 when
// there is no context or no generated audio.
int bark_get_audio_data_size(struct bark_context *bctx) {
    if (!bctx || bctx->audio_arr.empty()) {
        return 0;
    }
    // audio_arr.size() is size_t; the C API exposes an int, so make the
    // narrowing explicit (audio lengths stay far below INT_MAX in practice).
    return static_cast<int>(bctx->audio_arr.size());
}

// Expose the run statistics stored in the context; nullptr-safe.
// The returned pointer aliases the context and must not outlive it.
const bark_statistics * bark_get_statistics(struct bark_context *bctx) {
    return bctx ? &bctx->stats : nullptr;
}

// Zero all timing/sampling statistics in the context; no-op on nullptr.
void bark_reset_statistics(struct bark_context *bctx) {
    if (!bctx) {
        return;
    }
    // Value-initialization zeroes every member without relying on memset
    // (which needs <cstring> and only works while bark_statistics stays
    // trivially copyable).
    bctx->stats = bark_statistics();
}

0 comments on commit 1c39c4c

Please sign in to comment.