Skip to content

Commit

Permalink
moved bert tokenizer into its own function
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Aug 1, 2023
1 parent 7b93cc5 commit d1eac2c
Showing 1 changed file with 35 additions and 28 deletions.
63 changes: 35 additions & 28 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,38 @@ bark_vocab::id gpt_sample(
return gpt_multinomial_sample(logits, rng, temp, eos_p);
}

// Tokenize `text` with the BERT tokenizer and build the input sequence that is
// later handed to the text encoder.
//
// Layout of the returned sequence, with max_ctx_size = min(block_size, 256):
//   [0, max_ctx_size)                  text tokens shifted by TEXT_ENCODING_OFFSET,
//                                      right-padded with TEXT_PAD_TOKEN
//   [max_ctx_size, max_ctx_size + 256) SEMANTIC_PAD_TOKEN (semantic history slots)
//   last element                       SEMANTIC_INFER_TOKEN
//
// Parameters:
//   text       - NUL-terminated input prompt
//   vocab      - vocabulary passed through to bert_tokenize
//   block_size - context size of the text model; the text part is capped at 256
//
// Returns the assembled token sequence by value.
bark_sequence bark_tokenize_input(
        const char * text,
        const bark_vocab & vocab,
        const int32_t block_size) {
    const int32_t n_semantic_history = 256;
    const int32_t max_ctx_size = std::min<int32_t>(block_size, 256);

    // Initialised so the clamping below is well defined even if the tokenizer
    // does not write to it.
    int32_t n_tokens = 0;

    bark_sequence tokens(max_ctx_size);
    // Reserve the final size up front so appending the semantic part does not
    // reallocate (must happen before data() is handed to the tokenizer).
    tokens.reserve(max_ctx_size + n_semantic_history + 1);

    bert_tokenize(vocab, text, tokens.data(), &n_tokens, max_ctx_size);

    if (n_tokens > max_ctx_size) {
        fprintf(stderr, "%s: input sequence is too long (%d > %d), truncating sequence\n",
                __func__, n_tokens, max_ctx_size);
    }

    // Number of slots that actually hold text tokens. Clamped on both sides so
    // neither loop below can index outside [0, max_ctx_size).
    const int32_t n_text = std::max<int32_t>(0, std::min<int32_t>(n_tokens, max_ctx_size));

    // Shift only the real text tokens; the remaining slots become padding.
    for (int32_t i = 0; i < n_text; i++) {
        tokens[i] += TEXT_ENCODING_OFFSET;
    }
    for (int32_t i = n_text; i < max_ctx_size; i++) {
        tokens[i] = TEXT_PAD_TOKEN;
    }

    // semantic history
    tokens.resize(max_ctx_size + n_semantic_history, SEMANTIC_PAD_TOKEN);
    tokens.push_back(SEMANTIC_INFER_TOKEN);

    // NOTE(review): this only holds when block_size >= 256; a smaller block_size
    // yields max_ctx_size + 256 + 1 elements and trips the assert in debug builds.
    assert(tokens.size() == 256 + 256 + 1);

    return tokens;
}

bark_sequence bark_forward_text_encoder(
const bark_sequence & tokens,
const gpt_model model,
Expand Down Expand Up @@ -1481,34 +1513,9 @@ bool bark_generate_audio(

std::mt19937 rng(seed);

// tokenize text (bert tokenizer)
{
// max bark length: 256
int32_t max_ctx_size = std::min(model.text_model.hparams.block_size, 256);
int32_t n_tokens;
tokens.resize(max_ctx_size);

bert_tokenize(vocab, text, tokens.data(), &n_tokens, max_ctx_size);
for (int i = 0; i < (int) tokens.size(); i++)
tokens[i] += TEXT_ENCODING_OFFSET;

if (n_tokens < max_ctx_size) {
for (int i = n_tokens; i < max_ctx_size; i++)
tokens[i] = TEXT_PAD_TOKEN;
} else if (n_tokens > max_ctx_size) {
fprintf(stderr, "%s: input sequence is too long (%d > 256), truncating sequence", __func__, n_tokens);
}

tokens.resize(max_ctx_size);

// semantic history
for (int i = 0; i < 256; i++)
tokens.push_back(SEMANTIC_PAD_TOKEN);

tokens.push_back(SEMANTIC_INFER_TOKEN);

assert(tokens.size() == 256 + 256 + 1);
}
// bert tokenizer
const int32_t block_size = model.text_model.hparams.block_size;
bark_sequence tokens = bark_tokenize_input(text, vocab, block_size);

printf("%s: prompt: '%s'\n", __func__, text);
printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, tokens.size());
Expand Down

0 comments on commit d1eac2c

Please sign in to comment.