fix bugs with tokenizer (PABannier#54)
PABannier committed Aug 11, 2023
1 parent 9965598 commit 8ec7c72
Showing 4 changed files with 53 additions and 38 deletions.
73 changes: 41 additions & 32 deletions bark.cpp
@@ -443,35 +443,19 @@ std::string strip_accents(const std::string &in_str) {
     return out_str;
 }
 
-void word_piece(std::string word, bark_sequence & tokens, const bark_vocab & vocab) {
-    auto * token_map = &vocab.token_to_id;
-    int i = word.size();
-
-    while (word.size() > 0) {
-        while (i > 0 && token_map->find(word.substr(0, i)) == token_map->end()) { --i; }
-
-        if (i == 0) {
-            tokens.push_back(101); // [UNK] token
-            fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.c_str());
-            return;
-        }
-
-        bark_vocab::id token_id = token_map->find(word.substr(0, i))->second;
-        tokens.push_back(token_id);
-        word.erase(0, i);
-
-        if (word.size() > 0) {
-            word = "##" + word;
-        }
-    }
-}
-
-bark_sequence bert_tokenize(const bark_vocab & vocab, const char * text) {
-    bark_sequence tokens;
-
+void bert_tokenize(
+    const bark_vocab & vocab,
+    const char * text,
+    int32_t * tokens,
+    int32_t * n_tokens,
+    int32_t n_max_tokens) {
     std::string str = text;
     std::vector<std::string> words;
 
+    int32_t t = 0;
+
+    auto * token_map = &vocab.token_to_id;
+
     // split the text into words
     {
         str = strip_accents(text);
@@ -493,10 +477,34 @@ bark_sequence bert_tokenize(const bark_vocab & vocab, const char * text) {
         if (word.size() == 0)
             continue;
 
-        word_piece(word, tokens, vocab);
-    }
+        std::string prefix = "";
+        int i = 0;
+        int n = word.size();
+
+    loop:
+        while (i < n) {
+            if (t >= n_max_tokens - 1)
+                break;
+            int j = n;
+            while (j > i) {
+                auto it = token_map->find(prefix + word.substr(i, j - i));
+                if (it != token_map->end()) {
+                    tokens[t++] = it->second;
+                    i = j;
+                    prefix = "##";
+                    goto loop;
+                }
+                --j;
+            }
+            if (j == i) {
+                fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
+                prefix = "##";
+                ++i;
+            }
+        }
+    }
 
-    return tokens;
+    *n_tokens = t;
 }
 
 bool fine_gpt_eval(
@@ -1200,16 +1208,17 @@ bark_vocab::id gpt_sample(std::vector<float> & logits, std::mt19937 & rng, float
 
 bark_sequence bark_tokenize_input(const char * text, const bark_vocab & vocab, int32_t block_size) {
     int32_t max_ctx_size = std::min(block_size, 256);
+    int32_t n_tokens;
 
-    bark_sequence tokens = bert_tokenize(vocab, text);
-    int n_tokens = tokens.size();
+    bark_sequence tokens(max_ctx_size);
+    bert_tokenize(vocab, text, tokens.data(), &n_tokens, max_ctx_size);
 
     for (int i = 0; i < (int) tokens.size(); i++)
         tokens[i] += TEXT_ENCODING_OFFSET;
 
     if (n_tokens < max_ctx_size) {
         for (int i = n_tokens; i < max_ctx_size; i++)
-            tokens.push_back(TEXT_PAD_TOKEN);
+            tokens[i] = TEXT_PAD_TOKEN;
     } else if (n_tokens > max_ctx_size) {
         fprintf(stderr, "%s: input sequence is too long (%d > 256), truncating sequence", __func__, n_tokens);
     }
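
Note on the tokenizer rewrite above: the new inline loop is greedy longest-match-first WordPiece. For each word it looks up the longest prefix present in the vocabulary, emits that token id, then restarts from the remaining suffix with a "##" continuation prefix; an unmatched character is skipped with a warning instead of aborting the whole word, and the `t >= n_max_tokens - 1` guard keeps every write inside the caller-provided buffer. (The removed word_piece helper computed its candidate length `i` once, before its outer loop, so it searched with a stale, too-short length after the first emitted piece; it also emitted id 101 for unknown tokens, which is [CLS], not [UNK], in standard BERT vocabularies.) A minimal standalone sketch of the same matching strategy, restructured without the goto — the names here are illustrative, not bark.cpp's API:

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Greedy longest-match-first WordPiece over a single word.
    static std::vector<int> wordpiece(const std::string & word,
                                      const std::unordered_map<std::string, int> & vocab) {
        std::vector<int> out;
        std::string prefix;  // becomes "##" once the first piece is emitted
        size_t i = 0;
        while (i < word.size()) {
            size_t j = word.size();
            // shrink the candidate until (prefix + candidate) is in the vocabulary
            while (j > i && vocab.find(prefix + word.substr(i, j - i)) == vocab.end())
                --j;
            if (j == i) {  // nothing matched: warn and skip one character
                fprintf(stderr, "unknown token '%c'\n", word[i]);
                prefix = "##";
                ++i;
                continue;
            }
            out.push_back(vocab.at(prefix + word.substr(i, j - i)));
            i = j;
            prefix = "##";
        }
        return out;
    }

    int main() {
        const std::unordered_map<std::string, int> vocab =
            {{"un", 0}, {"##aff", 1}, {"##able", 2}};
        for (int id : wordpiece("unaffable", vocab))
            printf("%d ", id);  // prints: 0 1 2
        printf("\n");
    }

The bark_tokenize_input change in the last hunk pairs with this: tokens is pre-allocated to max_ctx_size and the tail is overwritten with TEXT_PAD_TOKEN, so the text encoder always receives a fixed-length sequence.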
7 changes: 6 additions & 1 deletion bark.h
@@ -149,7 +149,12 @@ bool bark_model_load(const std::string & dirname, bark_model & model);
 
 bool bark_vocab_load(const std::string & fname, bark_vocab& vocab, int32_t expected_size);
 
-bark_sequence bert_tokenize(const bark_vocab & vocab, const char * text);
+void bert_tokenize(
+    const bark_vocab & vocab,
+    const char * text,
+    int32_t * tokens,
+    int32_t * n_tokens,
+    int32_t n_max_tokens);
 
 bool bark_generate_audio(
     bark_model model,
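
With this header change, the caller owns both the output buffer and the token count instead of receiving a bark_sequence by value. A sketch of a call site under the new signature — mirroring what bark_tokenize_input and the updated test below do; the buffer size and input string are illustrative:

    std::vector<int32_t> buf(256);  // caller-allocated; 256 matches max_ctx_size
    int32_t n_tokens = 0;
    bert_tokenize(vocab, "hello world", buf.data(), &n_tokens, (int32_t) buf.size());
    buf.resize(n_tokens);           // keep only the tokens actually written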
5 changes: 1 addition & 4 deletions examples/main.cpp
@@ -25,13 +25,10 @@ int main() {
     printf("\n");
 
     // forward pass
-    const std::string prompt = "this is an audio";
+    const std::string prompt = "hi! i'm john and i'm a software engineer.";
     {
         const int64_t t_eval_us_start = ggml_time_us();
-
-        // call to generate audio
         bark_generate_audio(model, model.vocab, prompt.data(), 4);
-
         t_eval_us = ggml_time_us() - t_eval_us_start;
     }
 
6 changes: 5 additions & 1 deletion tests/test-tokenizer.cpp
@@ -28,14 +28,18 @@ int main(int argc, char **argv) {
     fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
 
     bark_vocab vocab;
+    int max_ctx_size = 256;
 
     if(!bark_vocab_load(fname, vocab, 119547)) {
         fprintf(stderr, "%s: invalid vocab file '%s'\n", __func__, fname.c_str());
         return 1;
     }
 
     for (const auto & test_kv : k_tests()) {
-        bark_sequence res = bert_tokenize(vocab, test_kv.first.c_str());
+        bark_sequence res(test_kv.first.size());
+        int n_tokens;
+        bert_tokenize(vocab, test_kv.first.c_str(), res.data(), &n_tokens, max_ctx_size);
+        res.resize(n_tokens);
 
         bool correct = res.size() == test_kv.second.size();
 
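
k_tests() pairs input strings with their expected token sequences; res is sized from the string length — an upper bound, since every emitted piece consumes at least one character — and then trimmed to the count the tokenizer reports. A hypothetical shape for a test entry (the ids are placeholders, not values from the real 119547-entry vocab):

    static std::map<std::string, bark_sequence> k_tests() {
        return {
            {"this is an audio", {10531, 10124, 10151, 23264}},  // placeholder ids
        };
    }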
