diff --git a/bark.cpp b/bark.cpp
index 7f23dcb..f4bd34d 100644
--- a/bark.cpp
+++ b/bark.cpp
@@ -1227,6 +1227,54 @@ bark_vocab::id gpt_sample(
     return gpt_multinomial_sample(logits, rng, temp, eos_p);
 }
 
+// Build the text-encoder input sequence for `text`.
+//
+// Layout of the returned sequence:
+//   - max_ctx_size = min(block_size, 256) text slots: ids written by
+//     bert_tokenize, each increased by TEXT_ENCODING_OFFSET, then
+//     TEXT_PAD_TOKEN in every unused slot;
+//   - 256 x SEMANTIC_PAD_TOKEN (semantic history);
+//   - 1 x SEMANTIC_INFER_TOKEN.
+//
+// NOTE(review): the final assert only holds when block_size >= 256 - confirm
+// that callers never pass a smaller block size.
+bark_sequence bark_tokenize_input(
+        const char * text,
+        const bark_vocab & vocab,
+        const int32_t block_size) {
+    const int32_t max_ctx_size = std::min(block_size, 256);
+
+    // start from 0 so the count is well defined even if bert_tokenize returns
+    // without writing it
+    int32_t n_tokens = 0;
+    bark_sequence tokens(max_ctx_size);
+
+    bert_tokenize(vocab, text, tokens.data(), &n_tokens, max_ctx_size);
+
+    if (n_tokens > max_ctx_size) {
+        fprintf(stderr, "%s: input sequence is too long (%d > %d), truncating sequence\n", __func__, n_tokens, max_ctx_size);
+    }
+
+    // number of slots that actually hold tokenizer output
+    const int32_t n_text = std::max<int32_t>(0, std::min(n_tokens, max_ctx_size));
+
+    // add the offset to real token ids only; pad the remaining slots directly
+    // instead of offsetting values that are overwritten anyway
+    for (int32_t i = 0; i < n_text; i++)
+        tokens[i] += TEXT_ENCODING_OFFSET;
+    for (int32_t i = n_text; i < max_ctx_size; i++)
+        tokens[i] = TEXT_PAD_TOKEN;
+
+    // semantic history
+    for (int i = 0; i < 256; i++)
+        tokens.push_back(SEMANTIC_PAD_TOKEN);
+    tokens.push_back(SEMANTIC_INFER_TOKEN);
+
+    assert(tokens.size() == 256 + 256 + 1);
+
+    return tokens;
+}
+
 bark_sequence bark_forward_text_encoder(
     const bark_sequence & tokens,
     const gpt_model model,
@@ -1481,34 +1529,9 @@ bool bark_generate_audio(
 
     std::mt19937 rng(seed);
 
-    // tokenize text (bert tokenizer)
-    {
-        // max bark length: 256
-        int32_t max_ctx_size = std::min(model.text_model.hparams.block_size, 256);
-        int32_t n_tokens;
-        tokens.resize(max_ctx_size);
-
-        bert_tokenize(vocab, text, tokens.data(), &n_tokens, max_ctx_size);
-        for (int i = 0; i < (int) tokens.size(); i++)
-            tokens[i] += TEXT_ENCODING_OFFSET;
-
-        if (n_tokens < max_ctx_size) {
-            for (int i = n_tokens; i < max_ctx_size; i++)
-                tokens[i] = TEXT_PAD_TOKEN;
-        } else if (n_tokens > max_ctx_size) {
-            fprintf(stderr, "%s: input sequence is too long (%d > 256), truncating sequence", __func__, n_tokens);
-        }
-
-        tokens.resize(max_ctx_size);
-
-        // semantic history
-        for (int i = 0; i < 256; i++)
-            tokens.push_back(SEMANTIC_PAD_TOKEN);
-
-        tokens.push_back(SEMANTIC_INFER_TOKEN);
-
-        assert(tokens.size() == 256 + 256 + 1);
-    }
+    // bert tokenizer
+    const int32_t block_size = model.text_model.hparams.block_size;
+    bark_sequence tokens = bark_tokenize_input(text, vocab, block_size);
 
     printf("%s: prompt: '%s'\n", __func__, text);
     printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, tokens.size());