FEAT Add argmax sampling (#29)
PABannier committed Aug 1, 2023
1 parent 8222128 commit 99f1544
Showing 3 changed files with 64 additions and 51 deletions.
95 changes: 53 additions & 42 deletions bark.cpp
@@ -1166,57 +1166,71 @@ bool gpt_eval(
     return true;
 }
 
-bark_vocab::id gpt_sample(
-        const std::vector<float>& logits,
-        double temp,
+bark_vocab::id gpt_multinomial_sample(
+        std::vector<float> & logits,
         std::mt19937 & rng,
+        float temp,
         float * eos_p) {
     int n_logits = logits.size();
 
-    std::vector<std::pair<double, bark_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const double scale = 1.0/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
-        }
-    }
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
+    for (int i = 0; i < n_logits; ++i)
+        logits[i] /= temp;
 
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
+    // for numerical stability
+    float maxl = -INFINITY;
+    for (const auto & l : logits)
+        maxl = std::max(maxl, l);
 
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
+    // softmax
+    float sum = 0.0;
+    for (auto & l : logits) {
+        l = exp(l - maxl);
+        sum += l;
     }
 
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
+    for (auto & l : logits)
+        l /= sum;
 
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
+    std::discrete_distribution<> dist(logits.begin(), logits.end());
+    int next = dist(rng);
 
     // likelihood of EOS token
     if (eos_p)
-        *eos_p = probs.back();
+        *eos_p = logits.back();
 
-    return logits_id[idx].second;
+    return next;
 }
 
+bark_vocab::id gpt_argmax_sample(std::vector<float> & logits) {
+    int n_logits = logits.size();
+
+    int next = 0;
+    float maxl = -INFINITY;
+
+    for (int i = 0; i < n_logits; i++) {
+        if (logits[i] > maxl) {
+            maxl = logits[i];
+            next = i;
+        }
+    }
+
+    return next;
+}
+
+bark_vocab::id gpt_sample(
+        std::vector<float> & logits,
+        std::mt19937 & rng,
+        float temp,
+        float * eos_p) {
+    if (temp == 0.0f)
+        return gpt_argmax_sample(logits);
+    return gpt_multinomial_sample(logits, rng, temp, eos_p);
+}
+
 std::vector<bark_vocab::id> bark_forward_text_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -1229,8 +1243,6 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
     std::vector<bark_vocab::id> input = tokens;
     std::vector<float> logits;
 
-    std::mt19937 rng(0);
-
     // dry run to estimate mem_per_token
     size_t mem_per_token = 0;
     gpt_eval(model, n_threads, 0, false, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -1251,7 +1263,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 
         input.clear();
 
-        bark_vocab::id next = gpt_sample(logits, temp, rng, &eos_p);
+        bark_vocab::id next = gpt_sample(logits, rng, temp, &eos_p);
 
         if (early_stop && ((next == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
             break;
@@ -1271,6 +1283,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -1280,8 +1293,6 @@ std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     std::vector<std::vector<bark_vocab::id>> out_coarse(N_COARSE_CODEBOOKS);
     std::vector<bark_vocab::id> out;
 
-    std::mt19937 rng(0);
-
     float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
     int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);
 
@@ -1340,7 +1351,7 @@ std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
             int end_ix = SEMANTIC_VOCAB_SIZE + (2 - is_major) * CODEBOOK_SIZE;
             std::vector<float> relevant_logits(logits.begin() + start_ix, logits.begin() + end_ix);
 
-            bark_vocab::id next = gpt_sample(relevant_logits, temp, rng, NULL);
+            bark_vocab::id next = gpt_sample(relevant_logits, rng, temp, NULL);
             next += start_ix;
 
             input_in.push_back(next);
@@ -1432,11 +1443,11 @@ bool bark_generate_audio(
 
     // encode text (text model)
     std::vector<bark_vocab::id> out_semantic = bark_forward_text_encoder(
-        tokens, model.text_model, n_threads, temp, early_stop, min_eos_p);
+        tokens, model.text_model, rng, n_threads, temp, early_stop, min_eos_p);
 
     // coarse encoding (coarse model)
     std::vector<std::vector<bark_vocab::id>> out_coarse = bark_forward_coarse_encoder(
-        out_semantic, model.coarse_model, n_threads, temp, early_stop, min_eos_p, max_coarse_history, sliding_window_size);
+        out_semantic, model.coarse_model, rng, n_threads, temp, early_stop, min_eos_p, max_coarse_history, sliding_window_size);
 
     // fine encoding (fine model)
     std::vector<std::vector<bark_vocab::id>> out_fine;
@@ -1492,7 +1503,7 @@ bool bark_generate_audio(
             std::vector<float> relevant_logits = logits[i];
             relevant_logits.resize(CODEBOOK_SIZE);
 
-            bark_vocab::id sampled_id = gpt_sample(relevant_logits, fine_temp, rng, NULL);
+            bark_vocab::id sampled_id = gpt_sample(relevant_logits, rng, fine_temp, NULL);
             in_buffer[nn][rel_start_fill_ix+i] = sampled_id;
         }
     }
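For reference, here is a minimal standalone sketch of the sampling logic this commit introduces: gpt_sample dispatches to greedy argmax when temp == 0.0f, and otherwise applies temperature scaling, a numerically stable softmax, and a multinomial draw. The function bodies mirror the diff above; the main() harness, the toy logits, and the `using id = int` alias are illustrative assumptions, not part of the commit.

// Standalone sketch of the sampling paths added in this commit.
// Compile with e.g.: g++ -std=c++11 sample_sketch.cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

using id = int;  // illustrative stand-in for bark_vocab::id

// Greedy path: index of the largest logit.
id gpt_argmax_sample(std::vector<float> & logits) {
    int next = 0;
    float maxl = -INFINITY;
    for (int i = 0; i < (int) logits.size(); i++) {
        if (logits[i] > maxl) {
            maxl = logits[i];
            next = i;
        }
    }
    return next;
}

// Stochastic path: temperature scaling, stable softmax, multinomial draw.
id gpt_multinomial_sample(std::vector<float> & logits, std::mt19937 & rng,
                          float temp, float * eos_p) {
    for (auto & l : logits)
        l /= temp;                           // temperature scaling
    float maxl = -INFINITY;                  // subtract max for numerical stability
    for (const auto & l : logits)
        maxl = std::max(maxl, l);
    float sum = 0.0f;                        // softmax, in place
    for (auto & l : logits) {
        l = std::exp(l - maxl);
        sum += l;
    }
    for (auto & l : logits)
        l /= sum;
    std::discrete_distribution<> dist(logits.begin(), logits.end());
    if (eos_p)
        *eos_p = logits.back();              // EOS assumed to sit at the last index
    return dist(rng);
}

// Dispatcher: temp == 0 means deterministic argmax decoding.
id gpt_sample(std::vector<float> & logits, std::mt19937 & rng,
              float temp, float * eos_p) {
    if (temp == 0.0f)
        return gpt_argmax_sample(logits);
    return gpt_multinomial_sample(logits, rng, temp, eos_p);
}

int main() {
    std::mt19937 rng(0);
    std::vector<float> logits = { 0.1f, 2.0f, 0.5f };  // toy values

    std::vector<float> a = logits;
    printf("temp=0.0 -> token %d (always the argmax)\n", gpt_sample(a, rng, 0.0f, nullptr));

    std::vector<float> b = logits;
    float eos_p = 0.0f;
    printf("temp=0.7 -> token %d, eos_p=%.3f\n", gpt_sample(b, rng, 0.7f, &eos_p), eos_p);
    return 0;
}

A note on the RNG plumbing visible in the diff: passing std::mt19937 down from bark_generate_audio, rather than constructing std::mt19937 rng(0) inside each encoder as before, keeps all three stages on a single seed and makes a full run reproducible from one entry point.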
18 changes: 10 additions & 8 deletions bark.h
@@ -125,20 +125,20 @@ bool gpt_eval(
     size_t & mem_per_token);
 
 bark_vocab::id gpt_sample(
-    const std::vector<float>& logits,
-    double temp,
-    std::mt19937 & rng,
+    std::vector<float> & logits,
+    std::mt19937 & rng,
+    float temp,
     float * eos_p);
 
 bool bark_model_load(const std::string & dirname, bark_model & model);
 
-bool bark_vocab_load(const std::string& fname, bark_vocab& vocab, int32_t expected_size);
+bool bark_vocab_load(const std::string & fname, bark_vocab& vocab, int32_t expected_size);
 
 void bert_tokenize(
-        const bark_vocab& vocab,
-        const char * text,
-        int32_t * tokens,
-        int32_t * n_tokens,
+    const bark_vocab& vocab,
+    const char * text,
+    int32_t * tokens,
+    int32_t * n_tokens,
     int32_t n_max_tokens);
 
 bool bark_generate_audio(
@@ -150,6 +150,7 @@ bool bark_generate_audio(
 std::vector<bark_vocab::id> bark_forward_text_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -158,6 +159,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
2 changes: 1 addition & 1 deletion examples/main.cpp
@@ -25,7 +25,7 @@ int main() {
     printf("\n");
 
     // forward pass
-    const std::string prompt = "hello world";
+    const std::string prompt = "this is an audio";
     {
         const int64_t t_eval_us_start = ggml_time_us();