From 7b93cc5ab4374efd58ef57003206d07a2d836027 Mon Sep 17 00:00:00 2001 From: PAB Date: Tue, 1 Aug 2023 15:04:00 +0200 Subject: [PATCH] ENH Use type aliases for nested `std::vector` types (#31) --- bark.cpp | 48 +++++++++++++++++------------------ bark.h | 15 ++++++----- tests/test-coarse-encoder.cpp | 18 ++++++------- tests/test-fine-encoder.cpp | 18 ++++++------- tests/test-text-encoder.cpp | 24 +++++++++--------- tests/test-tokenizer.cpp | 6 ++--- 6 files changed, 66 insertions(+), 63 deletions(-) diff --git a/bark.cpp b/bark.cpp index fc958ce..7f23dcb 100644 --- a/bark.cpp +++ b/bark.cpp @@ -522,7 +522,7 @@ bool fine_gpt_eval( const gpt_model & model, const int n_threads, const int codebook_ix, - const std::vector> & embd_inp, + const bark_codes & embd_inp, std::vector> & logits, size_t & mem_per_token) { // embd_inp: (n_channels, seq_length) @@ -854,7 +854,7 @@ bool gpt_eval( const int n_threads, const int n_past, const bool merge_ctx, - const std::vector & embd_inp, + const bark_sequence & embd_inp, std::vector & embd_w, size_t & mem_per_token) { int N = embd_inp.size(); @@ -1227,8 +1227,8 @@ bark_vocab::id gpt_sample( return gpt_multinomial_sample(logits, rng, temp, eos_p); } -std::vector bark_forward_text_encoder( - const std::vector & tokens, +bark_sequence bark_forward_text_encoder( + const bark_sequence & tokens, const gpt_model model, std::mt19937 & rng, const int n_threads, @@ -1236,11 +1236,11 @@ std::vector bark_forward_text_encoder( const bool early_stop, const float min_eos_p) { - std::vector out; + bark_sequence out; int n_past = 0; float eos_p = 0; - std::vector input = tokens; + bark_sequence input = tokens; std::vector logits; // dry run to estimate mem_per_token @@ -1280,16 +1280,16 @@ std::vector bark_forward_text_encoder( return out; } -std::vector> bark_forward_coarse_encoder( - const std::vector & tokens, +bark_codes bark_forward_coarse_encoder( + const bark_sequence & tokens, const gpt_model model, std::mt19937 & rng, const int n_threads, const float temp, const int max_coarse_history, const int sliding_window_size) { - std::vector> out_coarse(N_COARSE_CODEBOOKS); - std::vector out; + bark_codes out_coarse(N_COARSE_CODEBOOKS); + bark_sequence out; float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS; int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio); @@ -1302,7 +1302,7 @@ std::vector> bark_forward_coarse_encoder( int n_window_steps = ceilf(static_cast(n_steps) / sliding_window_size); - std::vector input = tokens; + bark_sequence input = tokens; std::vector logits; // dry run to estimate mem_per_token @@ -1312,7 +1312,7 @@ std::vector> bark_forward_coarse_encoder( for(int i = 0; i < n_window_steps; i++) { int semantic_ix = roundf(n_steps / semantic_to_coarse_ratio); - std::vector input_in( + bark_sequence input_in( input.begin() + std::max(semantic_ix-max_semantic_history, 0), input.end() ); @@ -1377,13 +1377,13 @@ std::vector> bark_forward_coarse_encoder( return out_coarse; } -std::vector> bark_forward_fine_encoder( - const std::vector> & tokens, +bark_codes bark_forward_fine_encoder( + const bark_codes & tokens, const gpt_model model, std::mt19937 & rng, const int n_threads, const float temp) { - std::vector> input = tokens; + bark_codes input = tokens; std::vector> logits; size_t mem_per_token = 0; @@ -1394,7 +1394,7 @@ std::vector> bark_forward_fine_encoder( // channel padding for(int i = N_COARSE_CODEBOOKS; i < N_FINE_CODEBOOKS; i++) { - std::vector tmp(original_seq_len, CODEBOOK_SIZE); + bark_sequence tmp(original_seq_len, CODEBOOK_SIZE); input.push_back(tmp); } @@ -1413,23 +1413,23 @@ std::vector> bark_forward_fine_encoder( int n_loops = std::max(0, (int) ceilf((input[0].size() - 1024)/512.f)) + 1; - std::vector> in_arr = input; + bark_codes in_arr = input; for (int n = 0; n < n_loops; n++) { int start_ix = std::min(n * 512, (int) in_arr[0].size() - 1024); int start_fill_ix = std::min(n * 512, (int) in_arr[0].size() - 512); int rel_start_fill_ix = start_fill_ix - start_ix; - std::vector> in_buffer(in_arr.size()); + bark_codes in_buffer(in_arr.size()); for (int ix = 0; ix < (int) in_buffer.size(); ix++) { - std::vector buf(in_arr[ix].begin() + start_ix, in_arr[ix].begin() + start_ix + 1024); + bark_sequence buf(in_arr[ix].begin() + start_ix, in_arr[ix].begin() + start_ix + 1024); in_buffer[ix] = buf; } for (int nn = n_coarse; nn < N_FINE_CODEBOOKS; nn++) { fine_gpt_eval(model, n_threads, nn, in_buffer, logits, mem_per_token); - std::vector predictions(CODEBOOK_SIZE - rel_start_fill_ix); + bark_sequence predictions(CODEBOOK_SIZE - rel_start_fill_ix); for (int i = 0; i < (int) logits.size(); i++) { logits[i].resize(CODEBOOK_SIZE); @@ -1462,7 +1462,7 @@ bool bark_generate_audio( const bark_vocab& vocab, const char * text, const int n_threads) { - std::vector tokens; + bark_sequence tokens; // TODO move into params // const int top_k = 10; @@ -1519,15 +1519,15 @@ bool bark_generate_audio( printf("\n\n"); // encode text (text model) - std::vector out_semantic = bark_forward_text_encoder( + bark_sequence out_semantic = bark_forward_text_encoder( tokens, model.text_model, rng, n_threads, temp, early_stop, min_eos_p); // coarse encoding (coarse model) - std::vector> out_coarse = bark_forward_coarse_encoder( + bark_codes out_coarse = bark_forward_coarse_encoder( out_semantic, model.coarse_model, rng, n_threads, temp, max_coarse_history, sliding_window_size); // fine encoding (fine model) - std::vector> out_fine = bark_forward_fine_encoder( + bark_codes out_fine = bark_forward_fine_encoder( out_coarse, model.fine_model, rng, n_threads, fine_temp); return true; diff --git a/bark.h b/bark.h index 09f0501..13634cc 100644 --- a/bark.h +++ b/bark.h @@ -47,6 +47,9 @@ struct bark_vocab { std::map id_to_subword_token; }; +typedef std::vector bark_sequence; +typedef std::vector> bark_codes; + struct gpt_layer { // normalization struct ggml_tensor * ln_1_g; @@ -120,7 +123,7 @@ bool gpt_eval( const int n_threads, const int n_past, const bool merge_ctx, - const std::vector & embd_inp, + const bark_sequence & embd_inp, std::vector & embd_w, size_t & mem_per_token); @@ -147,8 +150,8 @@ bool bark_generate_audio( const char * text, const int n_threads); -std::vector bark_forward_text_encoder( - const std::vector & tokens, +bark_sequence bark_forward_text_encoder( + const bark_sequence & tokens, const gpt_model model, std::mt19937 & rng, const int n_threads, @@ -156,7 +159,7 @@ std::vector bark_forward_text_encoder( const bool early_stop, const float min_eos_p); -std::vector> bark_forward_coarse_encoder( +bark_codes bark_forward_coarse_encoder( const std::vector & tokens, const gpt_model model, std::mt19937 & rng, @@ -165,8 +168,8 @@ std::vector> bark_forward_coarse_encoder( const int max_coarse_history, const int sliding_window_size); -std::vector> bark_forward_fine_encoder( - const std::vector> & tokens, +bark_codes bark_forward_fine_encoder( + const bark_codes & tokens, const gpt_model model, std::mt19937 & rng, const int n_threads, diff --git a/tests/test-coarse-encoder.cpp b/tests/test-coarse-encoder.cpp index 4a637bd..cd8fe16 100644 --- a/tests/test-coarse-encoder.cpp +++ b/tests/test-coarse-encoder.cpp @@ -6,17 +6,17 @@ #include #include -static const std::map, std::vector>> & k_tests() +static const std::map & k_tests() { - static const std::vector seq1 = { 215, 1988, 3275, 1898, 1898, 1898, 9372, 9372, 222, 334, 8568, 8568, 7963, 222, 8568, 55, 7963, 1270, 55, 1283, 1283, 222, 1283, 1283, 1283, 55, 1283, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 231, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 340, 5960, 5960, 5960, 5960, 1374, 4193, 4193, 9323, 1374, 1374, 1374, 1374, 4193, 1374, 4193, 1374, 1374, 4193, 1374, 231, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 8328, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 9318, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374 }; - static const std::vector seq2 = { 59, 28, 28, 107, 7409, 1999, 7695, 6486, 6486, 5836, 5836, 5836, 873, 2585, 92, 92, 59, 28, 28, 107, 315, 5623, 1025, 10, 173, 125, 7385, 147, 147, 3689, 302, 9600, 6876, 6876, 321, 41, 164, 1367, 739, 41, 10, 140, 140, 6202, 6051, 6051, 4071, 9804, 8583, 677, 3, 17, 113, 9414, 5419, 5419, 3831, 3663, 3663, 3663, 2224, 2224, 2224, 73, 9144, 9144, 1667, 1997, 1957, 1093, 825, 175, 175, 1087, 736, 1233, 230, 147, 147, 230, 230, 230, 230, 230, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 1613, 528, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 2009, 2009 }; - static const std::vector seq3 = { 10, 10, 560, 10, 9602, 10, 10, 10, 302, 2363, 2919, 6860, 5127, 7134, 7134, 3934, 3934, 3352, 3352, 3507, 50, 10, 27, 27, 3320, 6107, 9891, 9891, 9891, 321, 41, 4287, 5667, 6152, 6152, 557, 1228, 12, 12, 200, 59, 28, 28, 28, 28, 1133, 9569, 5920, 1424, 1424, 51, 51, 682, 3820, 2107, 6059, 348, 210, 10, 10, 5, 2187, 7842, 988, 1728, 1728, 438, 366, 50, 27, 27, 181, 181, 7352, 9725, 4431, 6445, 2428, 41, 41, 41, 5119, 6557, 4212, 3963, 26, 26, 934, 1025, 1024, 173, 10, 41, 5467, 6684, 6684, 6684, 4958, 41, 298, 5982, 5982, 526, 3219, 122, 181, 10, 10, 884, 3446, 2599, 4478, 4478, 2549 }; + static const bark_sequence seq1 = { 215, 1988, 3275, 1898, 1898, 1898, 9372, 9372, 222, 334, 8568, 8568, 7963, 222, 8568, 55, 7963, 1270, 55, 1283, 1283, 222, 1283, 1283, 1283, 55, 1283, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 231, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 340, 5960, 5960, 5960, 5960, 1374, 4193, 4193, 9323, 1374, 1374, 1374, 1374, 4193, 1374, 4193, 1374, 1374, 4193, 1374, 231, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 8328, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 9318, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374 }; + static const bark_sequence seq2 = { 59, 28, 28, 107, 7409, 1999, 7695, 6486, 6486, 5836, 5836, 5836, 873, 2585, 92, 92, 59, 28, 28, 107, 315, 5623, 1025, 10, 173, 125, 7385, 147, 147, 3689, 302, 9600, 6876, 6876, 321, 41, 164, 1367, 739, 41, 10, 140, 140, 6202, 6051, 6051, 4071, 9804, 8583, 677, 3, 17, 113, 9414, 5419, 5419, 3831, 3663, 3663, 3663, 2224, 2224, 2224, 73, 9144, 9144, 1667, 1997, 1957, 1093, 825, 175, 175, 1087, 736, 1233, 230, 147, 147, 230, 230, 230, 230, 230, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 1613, 528, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 2009, 2009 }; + static const bark_sequence seq3 = { 10, 10, 560, 10, 9602, 10, 10, 10, 302, 2363, 2919, 6860, 5127, 7134, 7134, 3934, 3934, 3352, 3352, 3507, 50, 10, 27, 27, 3320, 6107, 9891, 9891, 9891, 321, 41, 4287, 5667, 6152, 6152, 557, 1228, 12, 12, 200, 59, 28, 28, 28, 28, 1133, 9569, 5920, 1424, 1424, 51, 51, 682, 3820, 2107, 6059, 348, 210, 10, 10, 5, 2187, 7842, 988, 1728, 1728, 438, 366, 50, 27, 27, 181, 181, 7352, 9725, 4431, 6445, 2428, 41, 41, 41, 5119, 6557, 4212, 3963, 26, 26, 934, 1025, 1024, 173, 10, 41, 5467, 6684, 6684, 6684, 4958, 41, 298, 5982, 5982, 526, 3219, 122, 181, 10, 10, 884, 3446, 2599, 4478, 4478, 2549 }; - static const std::vector> ans1 = { {}, {} }; - static const std::vector> ans2 = { {}, {} }; - static const std::vector> ans3 = { {}, {} }; + static const bark_codes ans1 = { {}, {} }; + static const bark_codes ans2 = { {}, {} }; + static const bark_codes ans3 = { {}, {} }; - static std::map, std::vector>> _k_tests = { + static std::map _k_tests = { // { seq1, ans1 }, // hello world // { seq2, ans2 }, // this is an audio { seq3, ans3 }, // You cannot, sir, take from me anything @@ -48,7 +48,7 @@ int main(int argc, char** argv) { } for (const auto & test_kv : k_tests()) { - std::vector> res = bark_forward_coarse_encoder( + bark_codes res = bark_forward_coarse_encoder( test_kv.first, model, rng, n_threads, temp, max_coarse_history, sliding_window_size); bool correct = res.size() == test_kv.second.size(); diff --git a/tests/test-fine-encoder.cpp b/tests/test-fine-encoder.cpp index 4533546..db68f8e 100644 --- a/tests/test-fine-encoder.cpp +++ b/tests/test-fine-encoder.cpp @@ -6,17 +6,17 @@ #include #include -static const std::map>, std::vector>> & k_tests() +static const std::map & k_tests() { - static const std::vector> seq1 = {}; - static const std::vector> seq2 = {}; - static const std::vector> seq3 = {}; + static const bark_codes seq1 = {}; + static const bark_codes seq2 = {}; + static const bark_codes seq3 = {}; - static const std::vector> ans1 = { {}, {} }; - static const std::vector> ans2 = { {}, {} }; - static const std::vector> ans3 = { {}, {} }; + static const bark_codes ans1 = { {}, {} }; + static const bark_codes ans2 = { {}, {} }; + static const bark_codes ans3 = { {}, {} }; - static std::map>, std::vector>> _k_tests = { + static std::map _k_tests = { // { seq1, ans1 }, // hello world // { seq2, ans2 }, // this is an audio { seq3, ans3 }, // You cannot, sir, take from me anything @@ -45,7 +45,7 @@ int main(int argc, char** argv) { } for (const auto & test_kv : k_tests()) { - std::vector> res = bark_forward_fine_encoder( + bark_codes res = bark_forward_fine_encoder( test_kv.first, model, rng, n_threads, temp); bool correct = res.size() == test_kv.second.size(); diff --git a/tests/test-text-encoder.cpp b/tests/test-text-encoder.cpp index 5a9ed90..a960ed4 100644 --- a/tests/test-text-encoder.cpp +++ b/tests/test-text-encoder.cpp @@ -6,17 +6,17 @@ #include #include -static const std::map, std::vector> & k_tests() +static const std::map & k_tests() { - static const std::vector seq1 = { 71742, 20181, 21404 }; - static const std::vector seq2 = { 20579, 20172, 20199, 33733 }; - static const std::vector seq3 = { 21113, 35307, 10165, 62571, 10165, 23622, 20236, 20959, 52867 }; + static const bark_sequence seq1 = { 71742, 20181, 21404 }; + static const bark_sequence seq2 = { 20579, 20172, 20199, 33733 }; + static const bark_sequence seq3 = { 21113, 35307, 10165, 62571, 10165, 23622, 20236, 20959, 52867 }; - static const std::vector ans1 = { 3264, 6121, 6414, 7799, 6121, 1907, 1888, 206, 1888, 1888, 6143, 2131, 728, 6328, 3393, 5990, 2992, 8837, 206, 7799, 2533, 7374, 2992, 4059, 5990, 5990, 5236, 5939, 10, 2137, 2137, 6021, 176, 176, 2584, 499, 2382, 499, 8051, 4218, 8051, 8909, 306, 6804, 6804, 292, 292, 675, 3927, 3819, 624, 6583, 8843, 6583, 6583, 9628, 3062, 6652, 6652, 6652, 4792, 6608, 4792, 6608, 9025, 8534, 709, 7431, 7431, 9693, 8858, 8858, 3820, 8858, 682, 4076, 8996, 4909, 5682, 6139, 6139, 9133, 445, 971, 7542, 4564, 9931, 9931, 785, 785, 157, 5897, 5897, 9527, 1233, 138, 131, 10, 266, 266, 1572, 1572, 206, 206, 3533, 206, 4874, 1444, 3533, 206, 206, 7397, 206, 3252, 206, 2314, 91, 206, 7567, 841, 5346, 3252, 206, 841, 3366, 517, 517, 3252, 344, 344, 1278, 3950, 57, 57, 597, 7160, 121, 7334, 631, 292, 41, 41, 8944, 1991, 1408, 1408, 1408, 1462, 3, 166, 8745, 17, 2332, 1574, 7443, 50, 17, 27, 429, 9225, 713, 4099, 4099, 4099, 75, 555, 5932, 8870, 7627, 7627, 5661, 3088, 26, 288, 262 }; - static const std::vector ans2 = { 3205, 6179, 7731, 6972, 5722, 602, 441, 125, 147, 991, 1573, 402, 402, 6774, 1913, 8020, 8572, 8572, 1722, 5681, 1133, 4694, 1133, 7517, 9575, 8125, 5905, 6486, 1797, 6486, 5138, 5138, 4150, 2630, 2879, 59, 28, 28, 385, 1741, 4042, 9898, 302, 9600, 7231, 5673, 5475, 321, 171, 321, 164, 1025, 4681, 6202, 6752, 8288, 6747, 7656, 9804, 9804, 2411, 178, 50, 441, 6401, 5899, 79, 6511, 6511, 9629, 6511, 6154, 2224, 2224, 73, 73, 9814, 6303, 1997, 1997, 7396, 8062, 825, 441}; - static const std::vector ans3 = { 3946, 8514, 7741, 9262, 5153, 4400, 4509, 512, 1136, 4631, 8486, 4631, 3954, 7234, 993, 4412, 993, 9161, 332, 8209, 5565, 4224, 4344, 6152, 6152, 2704, 2285, 4438, 232, 131, 10, 5038, 2430, 59, 28, 28, 28, 28, 1310, 4449, 5920, 9449, 2002, 9693, 7939, 4049, 4049, 6059, 210, 100, 10, 10, 282, 3968, 988, 9790, 1728, 2587, 4405, 2948, 232, 232, 100, 3621, 8680, 417, 10, 2595, 7352, 9725, 6445, 2428, 41, 41, 10, 41, 9261, 4212, 3963, 6261, 8210, 9588, 934, 441, 1025, 2875, 8558, 6968, 116, 41, 41, 7789, 5721, 5721, 267, 2116, 579, 100, 1133, 3446, 2599, 7503, 3390, 3390, 4485, 657, 1385, 1385, 7691, 7557, 5272, 8887, 10, 6619, 3592, 6394, 5272, 5272, 8887, 1841, 602, 441, 217, 4542, 5861, 5861, 3803, 4542, 4542, 4542, 6675, 7204, 131, 100, 790, 2832, 266, 6115, 4209, 1739, 1739, 1444, 8659, 1739, 1739, 1739, 1133, 1739, 1739, 2556, 2556, 413, 413, 10, 3373, 7966, 2330, 1588, 409, 2942, 59, 28, 28, 28, 10, 28, 3160, 9569, 5920, 5887, 9693, 6290, 3458, 1242, 50, 210, 2977, 1433, 1433, 6150, 6150, 1136, 6413, 9693, 3441, 9598, 9061, 7949, 9137, 5615, 131, 100, 652, 7863, 7344, 8899, 7765, 50, 10, 100, 7399, 9915, 7557, 4509, 8486, 6264, 6133, 6133, 6133, 6619, 5210, 5210, 9629, 2555, 2339, 9486, 1425, 2762, 2466, 1079, 10 }; + static const bark_sequence ans1 = { 3264, 6121, 6414, 7799, 6121, 1907, 1888, 206, 1888, 1888, 6143, 2131, 728, 6328, 3393, 5990, 2992, 8837, 206, 7799, 2533, 7374, 2992, 4059, 5990, 5990, 5236, 5939, 10, 2137, 2137, 6021, 176, 176, 2584, 499, 2382, 499, 8051, 4218, 8051, 8909, 306, 6804, 6804, 292, 292, 675, 3927, 3819, 624, 6583, 8843, 6583, 6583, 9628, 3062, 6652, 6652, 6652, 4792, 6608, 4792, 6608, 9025, 8534, 709, 7431, 7431, 9693, 8858, 8858, 3820, 8858, 682, 4076, 8996, 4909, 5682, 6139, 6139, 9133, 445, 971, 7542, 4564, 9931, 9931, 785, 785, 157, 5897, 5897, 9527, 1233, 138, 131, 10, 266, 266, 1572, 1572, 206, 206, 3533, 206, 4874, 1444, 3533, 206, 206, 7397, 206, 3252, 206, 2314, 91, 206, 7567, 841, 5346, 3252, 206, 841, 3366, 517, 517, 3252, 344, 344, 1278, 3950, 57, 57, 597, 7160, 121, 7334, 631, 292, 41, 41, 8944, 1991, 1408, 1408, 1408, 1462, 3, 166, 8745, 17, 2332, 1574, 7443, 50, 17, 27, 429, 9225, 713, 4099, 4099, 4099, 75, 555, 5932, 8870, 7627, 7627, 5661, 3088, 26, 288, 262 }; + static const bark_sequence ans2 = { 3205, 6179, 7731, 6972, 5722, 602, 441, 125, 147, 991, 1573, 402, 402, 6774, 1913, 8020, 8572, 8572, 1722, 5681, 1133, 4694, 1133, 7517, 9575, 8125, 5905, 6486, 1797, 6486, 5138, 5138, 4150, 2630, 2879, 59, 28, 28, 385, 1741, 4042, 9898, 302, 9600, 7231, 5673, 5475, 321, 171, 321, 164, 1025, 4681, 6202, 6752, 8288, 6747, 7656, 9804, 9804, 2411, 178, 50, 441, 6401, 5899, 79, 6511, 6511, 9629, 6511, 6154, 2224, 2224, 73, 73, 9814, 6303, 1997, 1997, 7396, 8062, 825, 441}; + static const bark_sequence ans3 = { 3946, 8514, 7741, 9262, 5153, 4400, 4509, 512, 1136, 4631, 8486, 4631, 3954, 7234, 993, 4412, 993, 9161, 332, 8209, 5565, 4224, 4344, 6152, 6152, 2704, 2285, 4438, 232, 131, 10, 5038, 2430, 59, 28, 28, 28, 28, 1310, 4449, 5920, 9449, 2002, 9693, 7939, 4049, 4049, 6059, 210, 100, 10, 10, 282, 3968, 988, 9790, 1728, 2587, 4405, 2948, 232, 232, 100, 3621, 8680, 417, 10, 2595, 7352, 9725, 6445, 2428, 41, 41, 10, 41, 9261, 4212, 3963, 6261, 8210, 9588, 934, 441, 1025, 2875, 8558, 6968, 116, 41, 41, 7789, 5721, 5721, 267, 2116, 579, 100, 1133, 3446, 2599, 7503, 3390, 3390, 4485, 657, 1385, 1385, 7691, 7557, 5272, 8887, 10, 6619, 3592, 6394, 5272, 5272, 8887, 1841, 602, 441, 217, 4542, 5861, 5861, 3803, 4542, 4542, 4542, 6675, 7204, 131, 100, 790, 2832, 266, 6115, 4209, 1739, 1739, 1444, 8659, 1739, 1739, 1739, 1133, 1739, 1739, 2556, 2556, 413, 413, 10, 3373, 7966, 2330, 1588, 409, 2942, 59, 28, 28, 28, 10, 28, 3160, 9569, 5920, 5887, 9693, 6290, 3458, 1242, 50, 210, 2977, 1433, 1433, 6150, 6150, 1136, 6413, 9693, 3441, 9598, 9061, 7949, 9137, 5615, 131, 100, 652, 7863, 7344, 8899, 7765, 50, 10, 100, 7399, 9915, 7557, 4509, 8486, 6264, 6133, 6133, 6133, 6619, 5210, 5210, 9629, 2555, 2339, 9486, 1425, 2762, 2466, 1079, 10 }; - static std::map, std::vector> _k_tests = { + static std::map _k_tests = { { seq1, ans1 }, // hello world { seq2, ans2 }, // this is an audio { seq3, ans3 }, // You cannot, sir, take from me anything @@ -24,9 +24,9 @@ static const std::map, std::vector> return _k_tests; }; -std::vector pad_input(const std::vector & input) { +bark_sequence pad_input(const bark_sequence & input) { int original_sz = input.size(); - std::vector pad(input.begin(), input.end()); + bark_sequence pad(input.begin(), input.end()); pad.resize(513); for (int i = original_sz; i < 256; i++) @@ -60,8 +60,8 @@ int main(int argc, char** argv) { } for (const auto & test_kv : k_tests()) { - std::vector pad = pad_input(test_kv.first); - std::vector res = bark_forward_text_encoder( + bark_sequence pad = pad_input(test_kv.first); + bark_sequence res = bark_forward_text_encoder( pad, model, rng, n_threads, temp, true, min_eos_p); bool correct = res.size() == test_kv.second.size(); diff --git a/tests/test-tokenizer.cpp b/tests/test-tokenizer.cpp index f21423a..3f25a75 100644 --- a/tests/test-tokenizer.cpp +++ b/tests/test-tokenizer.cpp @@ -5,9 +5,9 @@ #include #include -static const std::map> & k_tests() +static const std::map & k_tests() { - static std::map> _k_tests = { + static std::map _k_tests = { { "Hello world!", { 31178, 11356, 106, }, }, { "Hello world", { 31178, 11356, }, }, { " Hello world!", { 31178, 11356, 106, }, }, @@ -35,7 +35,7 @@ int main(int argc, char **argv) { } for (const auto & test_kv : k_tests()) { - std::vector res(test_kv.first.size()); + bark_sequence res(test_kv.first.size()); int n_tokens; bert_tokenize(vocab, test_kv.first.c_str(), res.data(), &n_tokens, max_ctx_size); res.resize(n_tokens);