FIX Tokenizer (ct'd) (PABannier#52)
PABannier committed Aug 11, 2023
1 parent 3b56361 commit 9965598
Showing 3 changed files with 10 additions and 9 deletions.
14 changes: 8 additions & 6 deletions bark.cpp
@@ -460,13 +460,15 @@ void word_piece(std::string word, bark_sequence & tokens, const bark_vocab & voc
tokens.push_back(token_id);
word.erase(0, i);

- if (!word.empty()) {
+ if (word.size() > 0) {
word = "##" + word;
}
}
}

- void bert_tokenize(const bark_vocab & vocab, const char * text, bark_sequence & tokens) {
+ bark_sequence bert_tokenize(const bark_vocab & vocab, const char * text) {
+ bark_sequence tokens;
+
std::string str = text;
std::vector<std::string> words;

@@ -493,6 +495,8 @@ void bert_tokenize(const bark_vocab & vocab, const char * text, bark_sequence &

word_piece(word, tokens, vocab);
}
+
+ return tokens;
}

bool fine_gpt_eval(
@@ -1197,17 +1201,15 @@ bark_vocab::id gpt_sample(std::vector<float> & logits, std::mt19937 & rng, float
bark_sequence bark_tokenize_input(const char * text, const bark_vocab & vocab, int32_t block_size) {
int32_t max_ctx_size = std::min(block_size, 256);

- bark_sequence tokens;
- bert_tokenize(vocab, text, tokens);
-
+ bark_sequence tokens = bert_tokenize(vocab, text);
int n_tokens = tokens.size();

for (int i = 0; i < (int) tokens.size(); i++)
tokens[i] += TEXT_ENCODING_OFFSET;

if (n_tokens < max_ctx_size) {
for (int i = n_tokens; i < max_ctx_size; i++)
- tokens[i] = TEXT_PAD_TOKEN;
+ tokens.push_back(TEXT_PAD_TOKEN);
} else if (n_tokens > max_ctx_size) {
fprintf(stderr, "%s: input sequence is too long (%d > 256), truncating sequence", __func__, n_tokens);
}
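
The padding change above is the substantive fix in bark_tokenize_input: after tokenization the vector holds only n_tokens elements, so the old tokens[i] = TEXT_PAD_TOKEN assignment indexed past the end of the vector, while push_back grows it up to max_ctx_size. A minimal standalone sketch of the new behavior follows; the TEXT_PAD_TOKEN value and the sample token ids are stand-ins chosen for illustration, not values taken from the repository.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Stand-in constants for illustration; bark.cpp defines its own values.
    const int32_t TEXT_PAD_TOKEN = 129595;
    const int32_t max_ctx_size   = 8;

    std::vector<int32_t> tokens = {71742, 20181, 21404};  // pretend tokenizer output
    int n_tokens = tokens.size();

    if (n_tokens < max_ctx_size) {
        // push_back appends, so the vector grows to max_ctx_size;
        // writing tokens[i] for i >= n_tokens would be out of bounds.
        for (int i = n_tokens; i < max_ctx_size; i++)
            tokens.push_back(TEXT_PAD_TOKEN);
    }

    printf("padded length: %zu\n", tokens.size());  // prints 8
    return 0;
}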
2 changes: 1 addition & 1 deletion bark.h
@@ -149,7 +149,7 @@ bool bark_model_load(const std::string & dirname, bark_model & model);

bool bark_vocab_load(const std::string & fname, bark_vocab& vocab, int32_t expected_size);

- void bert_tokenize(const bark_vocab & vocab, const char * text, bark_sequence & tokens);
+ bark_sequence bert_tokenize(const bark_vocab & vocab, const char * text);

bool bark_generate_audio(
bark_model model,
3 changes: 1 addition & 2 deletions tests/test-tokenizer.cpp
@@ -35,8 +35,7 @@ int main(int argc, char **argv) {
}

for (const auto & test_kv : k_tests()) {
- bark_sequence res;
- bert_tokenize(vocab, test_kv.first.c_str(), res);
+ bark_sequence res = bert_tokenize(vocab, test_kv.first.c_str());

bool correct = res.size() == test_kv.second.size();

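For background on the word_piece helper touched in the first bark.cpp hunk: it repeatedly matches the longest vocabulary prefix of the word, emits that token id, and prefixes the remainder with "##" before matching again, which is the usual WordPiece scheme. The sketch below illustrates that loop with a plain std::map standing in for bark_vocab; it is an approximation under that assumption, not the repository's implementation.

#include <cstdint>
#include <map>
#include <string>
#include <vector>

using bark_sequence = std::vector<int32_t>;

// Greedy longest-prefix (WordPiece-style) split; `vocab` is a stand-in lookup table.
static void word_piece_sketch(std::string word, bark_sequence & tokens,
                              const std::map<std::string, int32_t> & vocab) {
    while (!word.empty()) {
        size_t i = word.size();
        // shrink the candidate prefix until it appears in the vocabulary
        while (i > 0 && vocab.find(word.substr(0, i)) == vocab.end())
            --i;
        if (i == 0)
            return;  // no matching piece; real code would emit an unknown-token id

        tokens.push_back(vocab.at(word.substr(0, i)));
        word.erase(0, i);

        if (word.size() > 0)
            word = "##" + word;  // mark the remainder as a continuation piece
    }
}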
