FEAT Add history prompts for custom voices #84

Open · wants to merge 8 commits into main
Changes from 1 commit
bark prompt load
PABannier committed Aug 17, 2023
commit 849727384f24587d92b80a1c74426b24dea41559
125 changes: 125 additions & 0 deletions bark.cpp
@@ -228,6 +228,131 @@ bool bark_vocab_load(const std::string& fname, bark_vocab& vocab, int32_t expect
return true;
}

bool bark_prompt_load(const std::string & fname, bark_history_prompts & history_prompts) {
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: faield to open '%s'\n", __func__, fname.c_str());
return false;
}

// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != GGML_FILE_MAGIC) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}

// upper bound on the ctx size needed to store all prompts (not very large)
size_t ctx_size = 10*MB;

// create the ggml context
{
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};

history_prompts.ctx = ggml_init(params);
if (!history_prompts.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}

auto & ctx = history_prompts.ctx;

int32_t n_prompts;
read_safe(fin, n_prompts);

std::string prompt_name;
std::vector<char> tmp;

tmp.reserve(128);

for (int i = 0; i < n_prompts; i++) {
uint32_t len;
read_safe(fin, len);

if (len > 0) {
tmp.resize(len);
fin.read(&tmp[0], tmp.size()); // read to buffer
prompt_name.assign(&tmp[0], tmp.size());
} else {
fprintf(stderr, "%s: invalid prompt name\n", __func__);
return false;
}

int64_t memsize = 0;

struct ggml_tensor * semantic_prompt = nullptr;
struct ggml_tensor * coarse_prompt = nullptr;
struct ggml_tensor * fine_prompt = nullptr;

std::map<std::string, struct ggml_tensor *> prompt_tensors = {
{ "semantic_prompt", semantic_prompt },
{ "coarse_prompt" , coarse_prompt },
{ "fine_prompt" , fine_prompt },
};

int32_t n_keys;
read_safe(fin, n_keys);

for (int k = 0; k < n_keys; k++) {
int32_t n_dims;
int32_t length;

read_safe(fin, n_dims);
read_safe(fin, length);

int64_t nelements = 1;
int64_t ne[4] = { 1, 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
read_safe(fin, ne[i]);
nelements *= ne[i];
}

std::string name(length, 0);
fin.read(&name[0], length);

if ((name != "semantic_prompt") && (name != "coarse_prompt") && (name != "fine_prompt")) {
fprintf(stderr, "%s: tensor '%s' has an unknown key: '%s'\n", __func__, prompt_name, name);
return false;
}

const size_t bpe = ggml_type_size(GGML_TYPE_I32);

auto & tensor = prompt_tensors[name];
tensor = ggml_new_tensor(ctx, GGML_TYPE_I32, 4, ne);

if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}

fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));

memsize += ggml_nbytes(tensor);
}

// allocate on the heap so the pointer stored in the voices map stays valid
// after this loop iteration ends
struct bark_voice * voice = new bark_voice {
/*.name =*/ prompt_name,
/*.semantic_prompt =*/ prompt_tensors["semantic_prompt"],
/*.coarse_prompt =*/ prompt_tensors["coarse_prompt"],
/*.fine_prompt =*/ prompt_tensors["fine_prompt"],
/*.memsize =*/ memsize,
};

history_prompts.voices[prompt_name] = voice;
history_prompts.memsize += memsize;
}

return true;
}

bool gpt_model_load(const std::string& fname, gpt_model& model) {
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
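For reference, the on-disk layout that bark_prompt_load expects can be read off from the calls above: magic, n_prompts, then per prompt a length-prefixed name, n_keys, and per key the number of dims, the key-name length, the shape, the key name, and the raw I32 data. The sketch below is a hypothetical writer that emits a single voice in that order; write_prompt_file, prompt_tensor_dump, and the int64 width of the shape fields (matching the loader's int64_t ne[4] read through read_safe) are assumptions, not part of this PR.

```cpp
// Hypothetical writer sketch matching the read order in bark_prompt_load.
// Not part of this PR; the layout is inferred from the loader above.
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

struct prompt_tensor_dump {
    std::string key;            // "semantic_prompt", "coarse_prompt" or "fine_prompt"
    std::vector<int64_t> ne;    // shape (up to 4 dims), written as int64 like the loader reads
    std::vector<int32_t> data;  // raw I32 codes, one per element of the shape
};

static bool write_prompt_file(const std::string & fname,
                              uint32_t magic,
                              const std::string & voice_name,
                              const std::vector<prompt_tensor_dump> & tensors) {
    std::ofstream out(fname, std::ios::binary);
    if (!out) {
        return false;
    }

    out.write((const char *) &magic, sizeof(magic));           // checked against GGML_FILE_MAGIC

    const int32_t n_prompts = 1;
    out.write((const char *) &n_prompts, sizeof(n_prompts));

    const uint32_t name_len = (uint32_t) voice_name.size();
    out.write((const char *) &name_len, sizeof(name_len));     // prompt name length
    out.write(voice_name.data(), name_len);                    // prompt name bytes

    const int32_t n_keys = (int32_t) tensors.size();
    out.write((const char *) &n_keys, sizeof(n_keys));

    for (const auto & t : tensors) {
        const int32_t n_dims  = (int32_t) t.ne.size();
        const int32_t key_len = (int32_t) t.key.size();
        out.write((const char *) &n_dims,  sizeof(n_dims));
        out.write((const char *) &key_len, sizeof(key_len));
        for (int64_t d : t.ne) {
            out.write((const char *) &d, sizeof(d));            // one int64 per dimension
        }
        out.write(t.key.data(), key_len);                       // tensor key string
        out.write((const char *) t.data.data(), t.data.size()*sizeof(int32_t));
    }

    return (bool) out;
}
```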
21 changes: 21 additions & 0 deletions bark.h
@@ -81,6 +81,24 @@ struct gpt_hparams {
int32_t n_codes_given = 1;
};

struct bark_voice {
std::string name;

struct ggml_tensor * semantic_prompt;
struct ggml_tensor * coarse_prompt;
struct ggml_tensor * fine_prompt;

int64_t memsize;
};

struct bark_history_prompts {
struct ggml_context * ctx;

std::map<std::string, struct bark_voice *> voices;

int64_t memsize = 0;
};

struct bark_vocab {
using id = int32_t;
using token = std::string;
@@ -159,6 +177,9 @@ struct bark_model {
// vocab
bark_vocab vocab;

// history prompts
bark_history_prompts history_prompts;

int64_t memsize = 0;
};

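To show how the new structures fit together, here is a minimal usage sketch, not part of this PR: it loads a prompt file (the path ggml_prompts.bin is hypothetical) and walks bark_history_prompts.voices. The forward declaration of bark_prompt_load is included in case this commit does not yet expose it in bark.h, and ggml.h is assumed to be on the include path for ggml_free.

```cpp
// Minimal usage sketch, not part of this PR.
#include <cstdio>
#include <string>

#include "bark.h"
#include "ggml.h"

// Defined in bark.cpp in this commit; declared here in case it is not yet
// exposed in bark.h.
bool bark_prompt_load(const std::string & fname, bark_history_prompts & history_prompts);

int main() {
    bark_history_prompts history_prompts;

    // "ggml_prompts.bin" is a placeholder path for a prompt file in the
    // layout that bark_prompt_load reads.
    if (!bark_prompt_load("ggml_prompts.bin", history_prompts)) {
        fprintf(stderr, "failed to load history prompts\n");
        return 1;
    }

    for (const auto & kv : history_prompts.voices) {
        const bark_voice * voice = kv.second;
        printf("voice '%s': %lld bytes of prompt tensors\n",
               kv.first.c_str(), (long long) voice->memsize);
    }

    printf("total prompt memory: %lld bytes\n", (long long) history_prompts.memsize);

    // release the ggml context that owns the prompt tensors
    ggml_free(history_prompts.ctx);

    return 0;
}
```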