From e0f80f630dbad79ad038530c9b38939a14422685 Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Bannier
Date: Tue, 1 Aug 2023 12:56:16 +0200
Subject: [PATCH] added sampling

---
 bark.cpp          | 95 ++++++++++++++++++++++++++---------------
 bark.h            | 18 +++++--
 examples/main.cpp |  2 +-
 3 files changed, 64 insertions(+), 51 deletions(-)

diff --git a/bark.cpp b/bark.cpp
index 004ef21..5b137de 100644
--- a/bark.cpp
+++ b/bark.cpp
@@ -1166,57 +1166,71 @@ bool gpt_eval(
     return true;
 }
 
-bark_vocab::id gpt_sample(
-        const std::vector<float>& logits,
-        double temp,
+bark_vocab::id gpt_multinomial_sample(
+        std::vector<float> & logits,
         std::mt19937 & rng,
+        float temp,
         float * eos_p) {
     int n_logits = logits.size();
 
-    std::vector<std::pair<double, bark_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const double scale = 1.0/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
-        }
-    }
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
+    for (int i = 0; i < n_logits; ++i)
+        logits[i] /= temp;
 
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
+    // for numerical stability
+    float maxl = -INFINITY;
+    for (const auto & l : logits)
+        maxl = std::max(maxl, l);
 
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
+    // softmax
+    float sum = 0.0;
+    for (auto & l : logits) {
+        l = exp(l - maxl);
+        sum += l;
     }
 
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
+    for (auto & l : logits)
+        l /= sum;
 
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
+    std::discrete_distribution<> dist(logits.begin(), logits.end());
+    int next = dist(rng);
 
     // likelihood of EOS token
     if (eos_p)
-        *eos_p = probs.back();
+        *eos_p = logits.back();
 
-    return logits_id[idx].second;
+    return next;
+}
+
+bark_vocab::id gpt_argmax_sample(std::vector<float> & logits) {
+    int n_logits = logits.size();
+
+    int next = 0;
+    float maxl = -INFINITY;
+
+    for (int i = 0; i < n_logits; i++) {
+        if (logits[i] > maxl) {
+            maxl = logits[i];
+            next = i;
+        }
+    }
+
+    return next;
+}
+
+bark_vocab::id gpt_sample(
+        std::vector<float> & logits,
+        std::mt19937 & rng,
+        float temp,
+        float * eos_p) {
+    if (temp == 0.0f)
+        return gpt_argmax_sample(logits);
+    return gpt_multinomial_sample(logits, rng, temp, eos_p);
 }
 
 std::vector<bark_vocab::id> bark_forward_text_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -1229,8 +1243,6 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
     std::vector<bark_vocab::id> input = tokens;
     std::vector<float> logits;
 
-    std::mt19937 rng(0);
-
     // dry run to estimate mem_per_token
     size_t mem_per_token = 0;
     gpt_eval(model, n_threads, 0, false, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -1251,7 +1263,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 
         input.clear();
 
-        bark_vocab::id next = gpt_sample(logits, temp, rng, &eos_p);
+        bark_vocab::id next = gpt_sample(logits, rng, temp, &eos_p);
 
         if (early_stop && ((next == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
             break;
@@ -1271,6 +1283,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -1280,8 +1293,6 @@ std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     std::vector<std::vector<bark_vocab::id>> out_coarse(N_COARSE_CODEBOOKS);
     std::vector<bark_vocab::id> out;
 
-    std::mt19937 rng(0);
-
     float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
     int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);
 
@@ -1340,7 +1351,7 @@ std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
             int end_ix = SEMANTIC_VOCAB_SIZE + (2 - is_major) * CODEBOOK_SIZE;
             std::vector<float> relevant_logits(logits.begin() + start_ix, logits.begin() + end_ix);
 
-            bark_vocab::id next = gpt_sample(relevant_logits, temp, rng, NULL);
+            bark_vocab::id next = gpt_sample(relevant_logits, rng, temp, NULL);
             next += start_ix;
 
             input_in.push_back(next);
@@ -1432,11 +1443,11 @@ bool bark_generate_audio(
 
     // encode text (text model)
     std::vector<bark_vocab::id> out_semantic = bark_forward_text_encoder(
-            tokens, model.text_model, n_threads, temp, early_stop, min_eos_p);
+            tokens, model.text_model, rng, n_threads, temp, early_stop, min_eos_p);
 
     // coarse encoding (coarse model)
     std::vector<std::vector<bark_vocab::id>> out_coarse = bark_forward_coarse_encoder(
-            out_semantic, model.coarse_model, n_threads, temp, early_stop, min_eos_p, max_coarse_history, sliding_window_size);
+            out_semantic, model.coarse_model, rng, n_threads, temp, early_stop, min_eos_p, max_coarse_history, sliding_window_size);
 
     // fine encoding (fine model)
     std::vector<std::vector<bark_vocab::id>> out_fine;
@@ -1492,7 +1503,7 @@ bool bark_generate_audio(
                 std::vector<float> relevant_logits = logits[i];
                 relevant_logits.resize(CODEBOOK_SIZE);
 
-                bark_vocab::id sampled_id = gpt_sample(relevant_logits, fine_temp, rng, NULL);
+                bark_vocab::id sampled_id = gpt_sample(relevant_logits, rng, fine_temp, NULL);
                 in_buffer[nn][rel_start_fill_ix+i] = sampled_id;
             }
         }
diff --git a/bark.h b/bark.h
index 370293f..8738658 100644
--- a/bark.h
+++ b/bark.h
@@ -125,20 +125,20 @@ bool gpt_eval(
         size_t & mem_per_token);
 
 bark_vocab::id gpt_sample(
-        const std::vector<float>& logits,
-        double temp,
-        std::mt19937 & rng,
+        std::vector<float> & logits,
+        std::mt19937 & rng,
+        float temp,
         float * eos_p);
 
 bool bark_model_load(const std::string & dirname, bark_model & model);
 
-bool bark_vocab_load(const std::string& fname, bark_vocab& vocab, int32_t expected_size);
+bool bark_vocab_load(const std::string & fname, bark_vocab& vocab, int32_t expected_size);
 
 void bert_tokenize(
-    const bark_vocab& vocab,
-    const char * text,
-    int32_t * tokens,
-    int32_t * n_tokens,
+    const bark_vocab & vocab,
+    const char       * text,
+    int32_t          * tokens,
+    int32_t          * n_tokens,
     int32_t n_max_tokens);
 
 bool bark_generate_audio(
@@ -150,6 +150,7 @@ bool bark_generate_audio(
 std::vector<bark_vocab::id> bark_forward_text_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -158,6 +159,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
diff --git a/examples/main.cpp b/examples/main.cpp
index ed35418..82a06ee 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -25,7 +25,7 @@ int main() {
     printf("\n");
 
     // forward pass
-    const std::string prompt = "hello world";
+    const std::string prompt = "this is an audio";
 
     {
         const int64_t t_eval_us_start = ggml_time_us();
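
For reference, the sampling scheme this patch introduces reduces to: at temp == 0.0f fall back to greedy argmax, otherwise divide the logits by temp, softmax them in place (shifting by the max logit for numerical stability), and draw the next token from the resulting distribution. Below is a minimal self-contained sketch of that scheme; the free-standing sample() helper, the toy logit values, and the fixed seed are illustrative assumptions, not part of the patch.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// Mirrors the dispatch in gpt_sample above:
// greedy at temperature zero, multinomial otherwise.
static int sample(std::vector<float> & logits, std::mt19937 & rng, float temp) {
    if (temp == 0.0f) {
        // argmax: index of the largest logit
        return (int) (std::max_element(logits.begin(), logits.end()) - logits.begin());
    }

    // temperature scaling
    for (auto & l : logits)
        l /= temp;

    // softmax in place, shifted by the max logit for numerical stability
    const float maxl = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (auto & l : logits) {
        l = std::exp(l - maxl);
        sum += l;
    }
    for (auto & l : logits)
        l /= sum;

    // multinomial draw over the normalized probabilities
    std::discrete_distribution<> dist(logits.begin(), logits.end());
    return dist(rng);
}

int main() {
    std::mt19937 rng(0);  // fixed seed for reproducibility

    std::vector<float> greedy = { 0.1f, 2.0f, 0.5f };  // toy logits
    printf("argmax : %d\n", sample(greedy, rng, 0.0f));  // always prints 1

    std::vector<float> stochastic = { 0.1f, 2.0f, 0.5f };
    printf("sampled: %d\n", sample(stochastic, rng, 0.7f));  // usually 1, sometimes 0 or 2

    return 0;
}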