FEAT Add argmax sampling #29

Merged
merged 1 commit on Aug 1, 2023
added sampling
PABannier committed Aug 1, 2023
commit e0f80f630dbad79ad038530c9b38939a14422685
95 changes: 53 additions & 42 deletions bark.cpp
@@ -1166,57 +1166,71 @@ bool gpt_eval(
     return true;
 }
 
-bark_vocab::id gpt_sample(
-        const std::vector<float>& logits,
-        double temp,
+bark_vocab::id gpt_multinomial_sample(
+        std::vector<float> & logits,
         std::mt19937 & rng,
+        float temp,
         float * eos_p) {
     int n_logits = logits.size();
 
-    std::vector<std::pair<double, bark_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const double scale = 1.0/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
-        }
-    }
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
+    for (int i = 0; i < n_logits; ++i)
+        logits[i] /= temp;
 
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
+    // for numerical stability
+    float maxl = -INFINITY;
+    for (const auto & l : logits)
+        maxl = std::max(maxl, l);
 
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
+    // softmax
+    float sum = 0.0;
+    for (auto & l : logits) {
+        l = exp(l - maxl);
+        sum += l;
     }
 
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
+    for (auto & l : logits)
+        l /= sum;
 
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
+    std::discrete_distribution<> dist(logits.begin(), logits.end());
+    int next = dist(rng);
 
     // likelihood of EOS token
     if (eos_p)
-        *eos_p = probs.back();
+        *eos_p = logits.back();
 
-    return logits_id[idx].second;
+    return next;
 }
 
+bark_vocab::id gpt_argmax_sample(std::vector<float> & logits) {
+    int n_logits = logits.size();
+
+    int next = 0;
+    float maxl = -INFINITY;
+
+    for (int i = 0; i < n_logits; i++) {
+        if (logits[i] > maxl) {
+            maxl = logits[i];
+            next = i;
+        }
+    }
+
+    return next;
+}
+
+bark_vocab::id gpt_sample(
+        std::vector<float> & logits,
+        std::mt19937 & rng,
+        float temp,
+        float * eos_p) {
+    if (temp == 0.0f)
+        return gpt_argmax_sample(logits);
+    return gpt_multinomial_sample(logits, rng, temp, eos_p);
+}
 
 std::vector<bark_vocab::id> bark_forward_text_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -1229,8 +1243,6 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
     std::vector<bark_vocab::id> input = tokens;
     std::vector<float> logits;
 
-    std::mt19937 rng(0);
-
     // dry run to estimate mem_per_token
     size_t mem_per_token = 0;
     gpt_eval(model, n_threads, 0, false, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -1251,7 +1263,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 
         input.clear();
 
-        bark_vocab::id next = gpt_sample(logits, temp, rng, &eos_p);
+        bark_vocab::id next = gpt_sample(logits, rng, temp, &eos_p);
 
         if (early_stop && ((next == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
             break;
@@ -1271,6 +1283,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -1280,8 +1293,6 @@ std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     std::vector<std::vector<bark_vocab::id>> out_coarse(N_COARSE_CODEBOOKS);
     std::vector<bark_vocab::id> out;
 
-    std::mt19937 rng(0);
-
     float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
     int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);
 
@@ -1340,7 +1351,7 @@ std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
             int end_ix = SEMANTIC_VOCAB_SIZE + (2 - is_major) * CODEBOOK_SIZE;
             std::vector<float> relevant_logits(logits.begin() + start_ix, logits.begin() + end_ix);
 
-            bark_vocab::id next = gpt_sample(relevant_logits, temp, rng, NULL);
+            bark_vocab::id next = gpt_sample(relevant_logits, rng, temp, NULL);
             next += start_ix;
 
             input_in.push_back(next);
@@ -1432,11 +1443,11 @@ bool bark_generate_audio(
 
     // encode text (text model)
     std::vector<bark_vocab::id> out_semantic = bark_forward_text_encoder(
-        tokens, model.text_model, n_threads, temp, early_stop, min_eos_p);
+        tokens, model.text_model, rng, n_threads, temp, early_stop, min_eos_p);
 
     // coarse encoding (coarse model)
     std::vector<std::vector<bark_vocab::id>> out_coarse = bark_forward_coarse_encoder(
-        out_semantic, model.coarse_model, n_threads, temp, early_stop, min_eos_p, max_coarse_history, sliding_window_size);
+        out_semantic, model.coarse_model, rng, n_threads, temp, early_stop, min_eos_p, max_coarse_history, sliding_window_size);
 
     // fine encoding (fine model)
     std::vector<std::vector<bark_vocab::id>> out_fine;
@@ -1492,7 +1503,7 @@ bool bark_generate_audio(
             std::vector<float> relevant_logits = logits[i];
             relevant_logits.resize(CODEBOOK_SIZE);
 
-            bark_vocab::id sampled_id = gpt_sample(relevant_logits, fine_temp, rng, NULL);
+            bark_vocab::id sampled_id = gpt_sample(relevant_logits, rng, fine_temp, NULL);
             in_buffer[nn][rel_start_fill_ix+i] = sampled_id;
         }
     }
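For reference, the sampling scheme this diff implements can be exercised on its own. The sketch below is not part of the PR; the function name `sample` and the toy logit values are illustrative. It reimplements the same logic — temperature-scaled softmax fed to std::discrete_distribution, with a greedy argmax path when temp == 0 — so the determinism of the temp == 0 case is easy to verify.

// Standalone sketch of the sampling logic added in this PR (not the repo's API).
// Greedy argmax when temp == 0, temperature-scaled multinomial otherwise.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

int sample(std::vector<float> logits, std::mt19937 & rng, float temp) {
    if (temp == 0.0f) {
        // argmax: deterministic, the rng is never consumed
        int next = 0;
        for (size_t i = 1; i < logits.size(); i++)
            if (logits[i] > logits[next])
                next = (int) i;
        return next;
    }

    // scale by temperature, then softmax (shift by the max for numerical stability)
    float maxl = -INFINITY;
    for (auto & l : logits) {
        l /= temp;
        maxl = std::max(maxl, l);
    }
    float sum = 0.0f;
    for (auto & l : logits) {
        l = std::exp(l - maxl);
        sum += l;
    }
    for (auto & l : logits)
        l /= sum;

    // draw one index with probability proportional to the softmax weights
    std::discrete_distribution<> dist(logits.begin(), logits.end());
    return dist(rng);
}

int main() {
    std::mt19937 rng(0);
    std::vector<float> logits = {1.0f, 3.0f, 2.0f};
    std::printf("argmax: %d\n", sample(logits, rng, 0.0f));       // always 1
    std::printf("multinomial: %d\n", sample(logits, rng, 0.7f));  // usually 1, sometimes 0 or 2
}

Note that std::discrete_distribution normalizes its weights itself, so the explicit division by sum is redundant for the draw; it is kept here to mirror the diff, where the normalized probabilities are also what lets *eos_p be read directly off logits.back().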
18 changes: 10 additions & 8 deletions bark.h
@@ -125,20 +125,20 @@ bool gpt_eval(
     size_t & mem_per_token);
 
 bark_vocab::id gpt_sample(
-        const std::vector<float>& logits,
-        double temp,
-        std::mt19937 & rng,
+        std::vector<float> & logits,
+        std::mt19937 & rng,
+        float temp,
         float * eos_p);
 
 bool bark_model_load(const std::string & dirname, bark_model & model);
 
-bool bark_vocab_load(const std::string& fname, bark_vocab& vocab, int32_t expected_size);
+bool bark_vocab_load(const std::string & fname, bark_vocab& vocab, int32_t expected_size);
 
 void bert_tokenize(
-        const bark_vocab& vocab,
-        const char * text,
-        int32_t * tokens,
-        int32_t * n_tokens,
+    const bark_vocab& vocab,
+    const char * text,
+    int32_t * tokens,
+    int32_t * n_tokens,
     int32_t n_max_tokens);
 
 bool bark_generate_audio(
@@ -150,6 +150,7 @@ bool bark_generate_audio(
 std::vector<bark_vocab::id> bark_forward_text_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
@@ -158,6 +159,7 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
 std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
     const std::vector<bark_vocab::id> & tokens,
     const gpt_model model,
+    std::mt19937 & rng,
     const int n_threads,
     const float temp,
     const bool early_stop,
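The header changes also show the second effect of this commit: bark_forward_text_encoder and bark_forward_coarse_encoder no longer construct their own std::mt19937 rng(0) internally but take one by reference, so a single generator seeded by the caller drives every sampling step across the pipeline. A minimal sketch of that calling pattern, with a hypothetical `forward_stage` standing in for the real bark_forward_* functions:

// Sketch of threading one RNG through successive sampling stages
// (toy stand-ins, not the real bark_forward_* signatures).
#include <cstdio>
#include <random>
#include <vector>

// hypothetical stage: consumes the shared rng once per sampled token
std::vector<int> forward_stage(std::vector<int> in, std::mt19937 & rng) {
    std::uniform_int_distribution<> pick(0, 9);
    for (auto & t : in)
        t = pick(rng);  // each draw advances the shared generator state
    return in;
}

int main() {
    std::mt19937 rng(42);  // one seed controls the whole pipeline
    std::vector<int> semantic = forward_stage({0, 1, 2}, rng);
    std::vector<int> coarse   = forward_stage(semantic, rng);  // continues the same stream
    std::printf("semantic[0]=%d coarse[0]=%d\n", semantic[0], coarse[0]);
    // re-running with std::mt19937 rng(42) reproduces both stages exactly
}

Compared with the old behavior, this removes the hard-coded seed and stops the two encoders from independently restarting the same random stream from zero.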
2 changes: 1 addition & 1 deletion examples/main.cpp
@@ -25,7 +25,7 @@ int main() {
     printf("\n");
 
     // forward pass
-    const std::string prompt = "hello world";
+    const std::string prompt = "this is an audio";
     {
         const int64_t t_eval_us_start = ggml_time_us();
 