diff --git a/.gitignore b/.gitignore
index 2c0fe21..32fdd66 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ bark
 encodec
 main
 tests/test-tokenizer
+tests/test-text-encoder

 *.o
 *.plist
diff --git a/Makefile b/Makefile
index 227693a..05911c3 100644
--- a/Makefile
+++ b/Makefile
@@ -1,8 +1,8 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main
+BUILD_TARGETS = main

 # Binaries only useful for tests
-TEST_TARGETS = tests/test-tokenizer
+TEST_TARGETS = tests/test-tokenizer tests/test-text-encoder

 default: $(BUILD_TARGETS)

@@ -315,3 +315,6 @@ tests: $(TEST_TARGETS)

 tests/test-tokenizer: tests/test-tokenizer.cpp ggml.o bark.o encodec.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
+
+tests/test-text-encoder: tests/test-text-encoder.cpp ggml.o bark.o encodec.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
diff --git a/bark.cpp b/bark.cpp
index f1e11a1..39b57a4 100644
--- a/bark.cpp
+++ b/bark.cpp
@@ -1160,7 +1160,6 @@ bool gpt_eval(
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
     }
-    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));

     ggml_free(ctx0);

@@ -1215,6 +1214,60 @@ bark_vocab::id gpt_sample(
     return logits_id[idx].second;
 }

+std::vector<bark_vocab::id> bark_forward_text_encoder(
+    const std::vector<bark_vocab::id> & tokens,
+    const gpt_model & model,
+    const int n_threads,
+    const float temp,
+    const bool early_stop,
+    const float min_eos_p) {
+
+    std::vector<bark_vocab::id> out;
+    int n_past = 0;
+    float eos_p = 0;
+
+    std::vector<bark_vocab::id> input = tokens;
+    std::vector<float> logits;
+
+    std::mt19937 rng(0);
+
+    // dry run to estimate mem_per_token
+    size_t mem_per_token = 0;
+    gpt_eval(model, n_threads, 0, false, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    for (int i = 0; i < 768; i++) {
+        const bool merge_ctx = i == 0;
+        gpt_eval(model, n_threads, n_past, merge_ctx, input, logits, mem_per_token);
+
+        float logits_pad_token = logits[SEMANTIC_PAD_TOKEN];
+        logits.resize(SEMANTIC_VOCAB_SIZE);
+
+        if (early_stop)
+            logits.push_back(logits_pad_token);
+
+        n_past += input.size();
+        if (i == 0)
+            n_past -= 256;  // first step only: the text and semantic contexts are merged
+
+        input.clear();
+
+        bark_vocab::id sampled_id = gpt_sample(logits, temp, rng, &eos_p);
+
+        if (early_stop && ((sampled_id == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
+            break;
+
+        input.push_back(sampled_id);
+        out.push_back(sampled_id);
+
+        printf("%d ", sampled_id);
+        fflush(stdout);
+    }
+
+    printf("\n\nsemantic sequence length: %zu\n\n", out.size());
+
+    return out;
+}
+
 bool bark_generate_audio(
     bark_model model,
     const bark_vocab& vocab,
@@ -1227,17 +1280,15 @@
     const int seed = 0;

     // const float top_p = 0.2;
-    const float temp = 0.7;
+    const float temp = 1.0;
     const float fine_temp = 0.5;

-    const int early_stop = true;
+    const bool early_stop = true;

     const int sliding_window_size = 60;
     const int max_coarse_history = 630;

-    // in the original implementation, min_eos_p=0.2, yet for bark.cpp this seems too
-    // high and this generates overly long sequence.
-    const float min_eos_p = 0.15;
+    const float min_eos_p = 0.2;

     std::mt19937 rng(seed);

@@ -1279,47 +1330,8 @@ bool bark_generate_audio(
     printf("\n\n");

     // encode text (text model)
-    std::vector<bark_vocab::id> out_semantic;
-    {
-        int n_past = 0;
-        float eos_p = 0;
-
-        std::vector<bark_vocab::id> input = tokens;
-        std::vector<float> logits;
-
-        // dry run to estimate mem_per_token
-        size_t mem_per_token = 0;
-        gpt_eval(model.text_model, n_threads, 0, false, { 0, 1, 2, 3 }, logits, mem_per_token);
-
-        for (int i = 0; i < 768; i++) {
-            const bool merge_ctx = i == 0;
-            gpt_eval(model.text_model, n_threads, n_past, merge_ctx, input, logits, mem_per_token);
-
-            float logits_pad_token = logits[SEMANTIC_PAD_TOKEN];
-            logits.resize(SEMANTIC_VOCAB_SIZE);
-
-            if (early_stop)
-                logits.push_back(logits[logits_pad_token]);
-
-            n_past += input.size();
-            if (i == 0)
-                n_past -= 256; // first step, context are merged
-
-            input.clear();
-
-            bark_vocab::id sampled_id = gpt_sample(logits, temp, rng, &eos_p);
-            input.push_back(sampled_id);
-            out_semantic.push_back(sampled_id);
-
-            printf("%d ", sampled_id);
-            fflush(stdout);
-
-            if (early_stop && ((sampled_id == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
-                break;
-        }
-
-        printf("\n\nsemantic sequence length: %zu\n\n", out_semantic.size());
-    }
+    std::vector<bark_vocab::id> out_semantic = bark_forward_text_encoder(
+        tokens, model.text_model, n_threads, temp, early_stop, min_eos_p);

     // coarse encoding (coarse model)
     std::vector<std::vector<bark_vocab::id>> out_coarse(N_COARSE_CODEBOOKS);
diff --git a/bark.h b/bark.h
index f03cb2f..93f4968 100644
--- a/bark.h
+++ b/bark.h
@@ -1,6 +1,7 @@
 #include "encodec.h"

 #include <map>
+#include <random>
 #include <vector>

 #define CLS_TOKEN_ID 101
@@ -114,6 +115,21 @@ struct bark_model {

 bool gpt_model_load(const std::string& fname, gpt_model& model);

+bool gpt_eval(
+    const gpt_model & model,
+    const int n_threads,
+    const int n_past,
+    const bool merge_ctx,
+    const std::vector<bark_vocab::id> & embd_inp,
+    std::vector<float> & embd_w,
+    size_t & mem_per_token);
+
+bark_vocab::id gpt_sample(
+    const std::vector<float>& logits,
+    double temp,
+    std::mt19937 & rng,
+    float * eos_p);
+
 bool bark_model_load(const std::string & dirname, bark_model & model);

 bool bark_vocab_load(const std::string& fname, bark_vocab& vocab, int32_t expected_size);
@@ -130,3 +146,11 @@ bool bark_generate_audio(
     const bark_vocab& vocab,
     const char * text,
     const int n_threads);
+
+std::vector<bark_vocab::id> bark_forward_text_encoder(
+    const std::vector<bark_vocab::id> & tokens,
+    const gpt_model & model,
+    const int n_threads,
+    const float temp,
+    const bool early_stop,
+    const float min_eos_p);
diff --git a/examples/main.cpp b/examples/main.cpp
index af2b089..ed35418 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -25,7 +25,7 @@ int main() {
     printf("\n");

     // forward pass
-    const std::string prompt = "This is an audio";
+    const std::string prompt = "hello world";

     {
         const int64_t t_eval_us_start = ggml_time_us();
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index f859a36..0c4e2ae 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -7,3 +7,4 @@ function(bark_add_test source)
 endfunction()

 bark_add_test(test-tokenizer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../ggml_weights/ggml_vocab.bin)
+bark_add_test(test-text-encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../ggml_weights/ggml_weights_text.bin)
diff --git a/tests/test-text-encoder.cpp b/tests/test-text-encoder.cpp
new file mode 100644
index 0000000..a06214e
--- /dev/null
+++ b/tests/test-text-encoder.cpp
@@ -0,0 +1,92 @@
+#include "bark.h"
+
+#include <cstdio>
+#include <map>
+#include <random>
+#include <string>
+#include <vector>
+
+static const std::map<std::vector<bark_vocab::id>, std::vector<bark_vocab::id>> & k_tests()
+{
+    static const std::vector<bark_vocab::id> seq1 = { 71742, 20181, 21404 };
+    static const std::vector<bark_vocab::id> seq2 = { 20579, 20172, 20199, 33733 };
+    static const std::vector<bark_vocab::id> seq3 = { 21113, 35307, 10165, 62571, 10165, 23622, 20236, 20959, 52867 };
+
+    static const std::vector<bark_vocab::id> ans1 = { 3264, 6121, 6414, 7799, 6121, 1907, 1888, 206, 1888, 1888, 6143, 2131, 728, 6328, 3393, 5990, 2992, 8837, 206, 7799, 2533, 7374, 2992, 4059, 5990, 5990, 5236, 5939, 10, 2137, 2137, 6021, 176, 176, 2584, 499, 2382, 499, 8051, 4218, 8051, 8909, 306, 6804, 6804, 292, 292, 675, 3927, 3819, 624, 6583, 8843, 6583, 6583, 9628, 3062, 6652, 6652, 6652, 4792, 6608, 4792, 6608, 9025, 8534, 709, 7431, 7431, 9693, 8858, 8858, 3820, 8858, 682, 4076, 8996, 4909, 5682, 6139, 6139, 9133, 445, 971, 7542, 4564, 9931, 9931, 785, 785, 157, 5897, 5897, 9527, 1233, 138, 131, 10, 266, 266, 1572, 1572, 206, 206, 3533, 206, 4874, 1444, 3533, 206, 206, 7397, 206, 3252, 206, 2314, 91, 206, 7567, 841, 5346, 3252, 206, 841, 3366, 517, 517, 3252, 344, 344, 1278, 3950, 57, 57, 597, 7160, 121, 7334, 631, 292, 41, 41, 8944, 1991, 1408, 1408, 1408, 1462, 3, 166, 8745, 17, 2332, 1574, 7443, 50, 17, 27, 429, 9225, 713, 4099, 4099, 4099, 75, 555, 5932, 8870, 7627, 7627, 5661, 3088, 26, 288, 262 };
+    static const std::vector<bark_vocab::id> ans2 = { 3205, 6179, 7731, 6972, 5722, 602, 441, 125, 147, 991, 1573, 402, 402, 6774, 1913, 8020, 8572, 8572, 1722, 5681, 1133, 4694, 1133, 7517, 9575, 8125, 5905, 6486, 1797, 6486, 5138, 5138, 4150, 2630, 2879, 59, 28, 28, 385, 1741, 4042, 9898, 302, 9600, 7231, 5673, 5475, 321, 171, 321, 164, 1025, 4681, 6202, 6752, 8288, 6747, 7656, 9804, 9804, 2411, 178, 50, 441, 6401, 5899, 79, 6511, 6511, 9629, 6511, 6154, 2224, 2224, 73, 73, 9814, 6303, 1997, 1997, 7396, 8062, 825, 441 };
+    static const std::vector<bark_vocab::id> ans3 = { 3946, 8514, 7741, 9262, 5153, 4400, 4509, 512, 1136, 4631, 8486, 4631, 3954, 7234, 993, 4412, 993, 9161, 332, 8209, 5565, 4224, 4344, 6152, 6152, 2704, 2285, 4438, 232, 131, 10, 5038, 2430, 59, 28, 28, 28, 28, 1310, 4449, 5920, 9449, 2002, 9693, 7939, 4049, 4049, 6059, 210, 100, 10, 10, 282, 3968, 988, 9790, 1728, 2587, 4405, 2948, 232, 232, 100, 3621, 8680, 417, 10, 2595, 7352, 9725, 6445, 2428, 41, 41, 10, 41, 9261, 4212, 3963, 6261, 8210, 9588, 934, 441, 1025, 2875, 8558, 6968, 116, 41, 41, 7789, 5721, 5721, 267, 2116, 579, 100, 1133, 3446, 2599, 7503, 3390, 3390, 4485, 657, 1385, 1385, 7691, 7557, 5272, 8887, 10, 6619, 3592, 6394, 5272, 5272, 8887, 1841, 602, 441, 217, 4542, 5861, 5861, 3803, 4542, 4542, 4542, 6675, 7204, 131, 100, 790, 2832, 266, 6115, 4209, 1739, 1739, 1444, 8659, 1739, 1739, 1739, 1133, 1739, 1739, 2556, 2556, 413, 413, 10, 3373, 7966, 2330, 1588, 409, 2942, 59, 28, 28, 28, 10, 28, 3160, 9569, 5920, 5887, 9693, 6290, 3458, 1242, 50, 210, 2977, 1433, 1433, 6150, 6150, 1136, 6413, 9693, 3441, 9598, 9061, 7949, 9137, 5615, 131, 100, 652, 7863, 7344, 8899, 7765, 50, 10, 100, 7399, 9915, 7557, 4509, 8486, 6264, 6133, 6133, 6133, 6619, 5210, 5210, 9629, 2555, 2339, 9486, 1425, 2762, 2466, 1079, 10 };
+
+    static std::map<std::vector<bark_vocab::id>, std::vector<bark_vocab::id>> _k_tests = {
+        { seq1, ans1 },  // hello world
+        { seq2, ans2 },  // this is an audio
+        { seq3, ans3 },  // You cannot, sir, take from me anything
+    };
+    return _k_tests;
+}
+
+std::vector<bark_vocab::id> pad_input(const std::vector<bark_vocab::id> & input) {
+    int original_sz = input.size();
+    std::vector<bark_vocab::id> pad(input.begin(), input.end());
+    pad.resize(513);
+
+    for (int i = original_sz; i < 256; i++)
+        pad[i] = TEXT_PAD_TOKEN;
+    for (int i = 256; i < 512; i++)
+        pad[i] = SEMANTIC_PAD_TOKEN;
+    pad[512] = SEMANTIC_INFER_TOKEN;
+
+    return pad;
+}
+
+int main(int argc, char** argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s <model-file>\n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+
+    gpt_model model;
+    const int n_threads = 4;
+    const float min_eos_p = 0.2;
+
+    printf("%s: reading bark text model\n", __func__);
+    if (!gpt_model_load(fname, model)) {
+        fprintf(stderr, "%s: invalid model file '%s'\n", __func__, fname.c_str());
+        return 1;
+    }
+
+    for (const auto & test_kv : k_tests()) {
+        std::vector<bark_vocab::id> pad = pad_input(test_kv.first);
+        std::vector<bark_vocab::id> res = bark_forward_text_encoder(
+            pad, model, n_threads, 1.0f, true, min_eos_p);
+
+        bool correct = res.size() == test_kv.second.size();
+
+        for (int i = 0; i < (int) res.size() && correct; ++i) {
+            if (res[i] != test_kv.second[i]) {
+                correct = false;
+            }
+        }
+
+        if (!correct) {
+            fprintf(stderr, "%s : failed test\n", __func__);
+            fprintf(stderr, "%s : expected tokens (n=%zu): ", __func__, test_kv.second.size());
+            for (const auto & t : test_kv.second) {
+                fprintf(stderr, "%d ", t);
+            }
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s : got tokens (n=%zu): ", __func__, res.size());
+            for (const auto & t : res) {
+                fprintf(stderr, "%d ", t);
+            }
+            fprintf(stderr, "\n");
+
+            return 3;
+        }
+    }
+
+    fprintf(stderr, "%s : tests passed successfully.\n", __func__);
+
+    return 0;
+}
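
Testing: a minimal sketch of how to build and run the new test with the Makefile target added above. The weights path mirrors the one hard-coded in tests/CMakeLists.txt and is an assumption about the local setup:

    # build only the new test binary (also covered by `make tests`)
    make tests/test-text-encoder

    # run it against the exported text-model weights
    ./tests/test-text-encoder ./ggml_weights/ggml_weights_text.bin

Although gpt_sample draws from the softmax distribution, the output is reproducible: bark_forward_text_encoder seeds its own std::mt19937 with 0, so the generated semantic tokens can be compared against the fixed expectations in k_tests().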