Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT Add history prompts for custom voices #84

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
semantic tokens
  • Loading branch information
PABannier committed Aug 19, 2023
commit c4753ce1200997cb10527d0aa91fe2c02e0750a6
77 changes: 63 additions & 14 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ bool bark_prompt_load(const std::string & fname, bark_history_prompts & history_
fin.read(&name[0], length);

if ((name != "semantic_prompt") && (name != "coarse_prompt") && (name != "fine_prompt")) {
fprintf(stderr, "%s: tensor '%s' has an unknown key: '%s'\n", __func__, prompt_name, name);
fprintf(stderr, "%s: tensor '%s' has an unknown key: '%s'\n", __func__, prompt_name.c_str(), name.c_str());
return false;
}

Expand Down Expand Up @@ -1468,20 +1468,65 @@ bark_sequence bark_tokenize_input(const char * text, const bark_vocab & vocab, i

tokens.resize(max_ctx_size);

// semantic history
for (int i = 0; i < 256; i++)
tokens.push_back(SEMANTIC_PAD_TOKEN);
tokens.push_back(SEMANTIC_INFER_TOKEN);
return tokens;
}

assert(tokens.size() == 256 + 256 + 1);
// Assembles the 513-entry input sequence for the text encoder:
//
//   [ 256 text tokens | 256 semantic-history entries | SEMANTIC_INFER_TOKEN ]
//
// history_prompts  provides the ggml context (`ctx`) and the `voices` table
// tokens           text tokens; must hold exactly 256 entries (asserted)
// out              resized to 513 and filled with the assembled sequence
// voice            key into `history_prompts->voices`; when empty, the
//                  semantic history is filled with SEMANTIC_PAD_TOKEN
//
// Returns 1 on success. Returns 0, after printing a message to stderr, when
// `voice` is non-empty but not present in `history_prompts->voices`.
int bark_get_input_sequence(
        struct bark_history_prompts * history_prompts,
        std::vector<bark_vocab::id> & tokens,
        std::vector<bark_vocab::id> & out,
        const std::string & voice) {
    BARK_ASSERT(tokens.size() == 256);

    out.resize(513);

    struct bark_voice * history_prompt = nullptr;
    if (!voice.empty()) {
        // single lookup: reuse the iterator instead of find() followed by operator[]
        const auto it = history_prompts->voices.find(voice);
        if (it == history_prompts->voices.end()) {
            fprintf(stderr, "Could not find voice '%s'\n", voice.c_str());
            return 0;
        }
        history_prompt = it->second;
    }

    auto & ctx = history_prompts->ctx;
    struct ggml_cgraph gf = {};

    struct ggml_tensor * semantic_history = nullptr;
    if (history_prompt) {
        semantic_history = history_prompt->semantic_prompt;
        if (semantic_history->ne[0] >= 256) {
            // keep only the trailing 256 entries of the stored prompt
            size_t offset = (semantic_history->ne[0] - 256) * semantic_history->nb[0];
            semantic_history = ggml_view_1d(ctx, semantic_history, 256, offset);
        } else {
            // prompt shorter than 256: start from a buffer filled with the pad
            // token and set the prompt at offset 0
            struct ggml_tensor * padded = ggml_new_tensor_1d(ctx, semantic_history->type, 256);
            padded = ggml_set_f32(padded, SEMANTIC_PAD_TOKEN);
            semantic_history = ggml_set_1d(ctx, padded, semantic_history, 0);
        }
    } else {
        // no voice requested: the history is entirely padding
        semantic_history = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 256);
        semantic_history = ggml_set_i32(semantic_history, SEMANTIC_PAD_TOKEN);
    }

    // concatenate tokens, semantic_history and [SEMANTIC_INFER_TOKEN]
    struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 513);
    memcpy(input->data, tokens.data(), tokens.size()*sizeof(int32_t));
    input = ggml_set_1d(ctx, input, semantic_history, tokens.size()*sizeof(int32_t));

    ggml_build_forward_expand(&gf, input);
    ggml_graph_compute_with_ctx(ctx, &gf, 1);

    // The tensor is GGML_TYPE_I32, so the final slot must be stored as an
    // int32 (not through a float pointer). It is written only after the graph
    // has been computed so the set node cannot overwrite it.
    // NOTE(review): presumably ggml_set_1d materialises its result at compute
    // time - confirm against the ggml revision in use.
    static_cast<int32_t *>(ggml_get_data(input))[512] = SEMANTIC_INFER_TOKEN;

    memcpy(out.data(), input->data, 513*sizeof(int32_t));

    return 1;
}

bark_sequence bark_forward_text_encoder(
const bark_sequence & tokens,
const gpt_model model,
bark_sequence & tokens,
struct bark_history_prompts * history_prompts,
const std::string & voice,
const gpt_model model,
std::mt19937 & rng,
const int n_threads,
const float temp,
Expand All @@ -1499,7 +1544,10 @@ bark_sequence bark_forward_text_encoder(

float eos_p = 0;

bark_sequence input = tokens;
// build input token sequence
bark_sequence input;
bark_get_input_sequence(history_prompts, tokens, input, voice);

std::vector<float> logits;

// dry run to estimate mem_per_token
Expand Down Expand Up @@ -1923,15 +1971,16 @@ bool bark_generate_audio(
printf("\n");

bark_sequence semantic_tokens = bark_forward_text_encoder(
tokens, model.text_model, voice, rng, n_threads, temp, min_eos_p);
tokens, &model.history_prompts, voice, model.text_model, rng, n_threads, temp, min_eos_p);
printf("\n");

bark_codes coarse_tokens = bark_forward_coarse_encoder(
semantic_tokens, model.coarse_model, voice, rng, n_threads, temp, max_coarse_history, sliding_window_size);
semantic_tokens, history_prompt, model.coarse_model, rng, n_threads, temp,
max_coarse_history, sliding_window_size);
printf("\n");

bark_codes fine_tokens = bark_forward_fine_encoder(
coarse_tokens, model.fine_model, voice, rng, n_threads, fine_temp);
coarse_tokens, history_prompt, model.fine_model, rng, n_threads, fine_temp);
printf("\n");

audio_arr_t audio_arr = bark_forward_encodec(fine_tokens, model.codec_model);
Expand Down Expand Up @@ -1987,6 +2036,6 @@ void bark_print_usage(char ** argv, const bark_params & params) {
fprintf(stderr, " -o FNAME, --outwav FNAME\n");
fprintf(stderr, " output generated wav (default: %s)\n", params.dest_wav_path.c_str());
fprintf(stderr, " -v VOICE, --voice VOICE\n");
fprintf(stderr, " custom voice (default: none)\n", params.voice.c_str());
fprintf(stderr, " custom voice (default: none)\n");
fprintf(stderr, "\n");
}
7 changes: 5 additions & 2 deletions bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,18 @@ bool bark_generate_audio(
const std::string & voice);

bark_sequence bark_forward_text_encoder(
const bark_sequence & tokens,
const gpt_model model,
bark_sequence & tokens,
struct bark_history_prompts * history_prompts,
const std::string & voice,
const gpt_model model,
std::mt19937 & rng,
const int n_threads,
const float temp,
const float min_eos_p);

bark_codes bark_forward_coarse_encoder(
const bark_sequence & tokens,
struct bark_voice * history_prompt,
const gpt_model model,
const std::string & voice,
std::mt19937 & rng,
Expand All @@ -256,6 +258,7 @@ bark_codes bark_forward_coarse_encoder(

bark_codes bark_forward_fine_encoder(
const bark_codes & tokens,
struct bark_voice * history_prompt,
const gpt_model model,
const std::string & voice,
std::mt19937 & rng,
Expand Down
Loading