Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT: Add unit tests for coarse model #28

Merged
merged 5 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ encodec
main
tests/test-tokenizer
tests/test-text-encoder
tests/test-coarse-encoder

*.o
*.plist
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
BUILD_TARGETS = main

# Binaries only useful for tests
TEST_TARGETS = tests/test-tokenizer tests/test-text-encoder
TEST_TARGETS = tests/test-tokenizer tests/test-text-encoder tests/test-coarse-encoder

default: $(BUILD_TARGETS)

Expand Down Expand Up @@ -318,3 +318,6 @@ tests/test-tokenizer: tests/test-tokenizer.cpp ggml.o bark.o encodec.o $(OBJS)

tests/test-text-encoder: tests/test-text-encoder.cpp ggml.o bark.o encodec.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)

tests/test-coarse-encoder: tests/test-coarse-encoder.cpp ggml.o bark.o encodec.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
199 changes: 108 additions & 91 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1251,15 +1251,15 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(

input.clear();

bark_vocab::id sampled_id = gpt_sample(logits, temp, rng, &eos_p);
bark_vocab::id next = gpt_sample(logits, temp, rng, &eos_p);

if (early_stop && ((sampled_id == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
if (early_stop && ((next == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
break;

input.push_back(sampled_id);
out.push_back(sampled_id);
input.push_back(next);
out.push_back(next);

printf("%d ", sampled_id);
printf("%d ", next);
fflush(stdout);
}

Expand All @@ -1268,6 +1268,107 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
return out;
}

std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
        const std::vector<bark_vocab::id> & tokens,
        const gpt_model model,
        const int n_threads,
        const float temp,
        const bool early_stop,
        const float min_eos_p,
        const int max_coarse_history,
        const int sliding_window_size) {
    // Autoregressively generates coarse acoustic tokens from the semantic tokens
    // produced by the text encoder, sliding a generation window over the sequence.
    //
    //   tokens              - semantic token sequence (input prompt)
    //   model               - coarse GPT model (by value to match the bark.h
    //                         declaration; NOTE(review): a const reference would
    //                         avoid copying the weights -- change both together)
    //   n_threads           - number of threads passed to gpt_eval
    //   temp                - sampling temperature for gpt_sample
    //   early_stop/min_eos_p - unused by the coarse pass; kept for signature
    //                         parity with bark_forward_text_encoder
    //   max_coarse_history  - max previously generated coarse tokens fed back
    //   sliding_window_size - coarse tokens generated per window
    //
    // Returns N_COARSE_CODEBOOKS vectors, one per codebook, with token values
    // shifted back into [0, CODEBOOK_SIZE).
    std::vector<std::vector<bark_vocab::id>> out_coarse(N_COARSE_CODEBOOKS);
    std::vector<bark_vocab::id> out;

    // Fixed seed: generation is deterministic for a given input.
    std::mt19937 rng(0);

    // Each semantic token expands to semantic_to_coarse_ratio coarse tokens,
    // interleaved across the codebooks.
    float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
    int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);

    // Total number of coarse tokens to generate, rounded down to a whole
    // number of codebook groups.
    int n_steps = floorf(tokens.size() * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS) * N_COARSE_CODEBOOKS;
    int step_ix = 0;

    BARK_ASSERT(n_steps > 0);
    BARK_ASSERT(n_steps % N_COARSE_CODEBOOKS == 0);

    int n_window_steps = ceilf(static_cast<float>(n_steps) / sliding_window_size);

    std::vector<bark_vocab::id> input = tokens;
    std::vector<float> logits;

    // dry run to estimate mem_per_token
    size_t mem_per_token = 0;
    gpt_eval(model, n_threads, 0, false, { 0, 1, 2, 3 }, logits, mem_per_token);

    for (int i = 0; i < n_window_steps; i++) {
        // BUGFIX: the semantic anchor must follow the current generation step
        // (step_ix), not the constant n_steps -- otherwise every window
        // conditions on the same tail slice of the semantic context instead of
        // sliding along it (matches the reference Suno generate_coarse logic).
        int semantic_ix = roundf(step_ix / semantic_to_coarse_ratio);

        std::vector<bark_vocab::id> input_in(
            input.begin() + std::max(semantic_ix - max_semantic_history, 0),
            input.end()
        );
        size_t original_size = input_in.size();
        input_in.resize(256);

        // padding from the right side
        for (int ix = original_size; ix < 256; ix++)
            input_in[ix] = COARSE_SEMANTIC_PAD_TOKEN;

        input_in.push_back(COARSE_INFER_TOKEN);

        // Append up to max_coarse_history previously generated coarse tokens.
        // Plain copy: `out` is read again on later windows, so move iterators
        // (as before) were semantically wrong even though moving an integral
        // id degenerates to a copy.
        input_in.insert(
            input_in.end(),
            out.end() - std::min(max_coarse_history, (int) out.size()),
            out.end()
        );

        int n_past = 0;
        mem_per_token *= 1.1;    // context length is growing, mem_per_token must grow as well

        for (int j = 0; j < sliding_window_size; j++) {
            if (step_ix >= n_steps)
                break;    // window drained; remaining iterations would be no-ops

            gpt_eval(model, n_threads, n_past, false, input_in, logits, mem_per_token);

            n_past += input_in.size();
            input_in.clear();

            // Alternate codebooks: "major" steps sample from the first
            // codebook's logit range, the others from the second.
            bool is_major = step_ix % N_COARSE_CODEBOOKS == 0;
            int start_ix = SEMANTIC_VOCAB_SIZE + (1 - is_major) * CODEBOOK_SIZE;
            int end_ix = SEMANTIC_VOCAB_SIZE + (2 - is_major) * CODEBOOK_SIZE;
            std::vector<float> relevant_logits(logits.begin() + start_ix, logits.begin() + end_ix);

            bark_vocab::id next = gpt_sample(relevant_logits, temp, rng, NULL);
            next += start_ix;    // shift back into the global vocab range

            input_in.push_back(next);
            out.push_back(next);

            printf("%d ", next);
            fflush(stdout);

            step_ix += 1;
        }
    }

    BARK_ASSERT((int) out.size() == n_steps);
    BARK_ASSERT(out.size() % N_COARSE_CODEBOOKS == 0);

    // De-interleave into per-codebook streams and shift values back into
    // [0, CODEBOOK_SIZE). NOTE: the offset arithmetic assumes
    // N_COARSE_CODEBOOKS == 2, like the sampling loop above.
    for (int i = 0; i < (int) out.size(); i++) {
        if (i % 2 == 0)
            out_coarse[0].push_back(out[i] - SEMANTIC_VOCAB_SIZE);
        else
            out_coarse[1].push_back(out[i] - SEMANTIC_VOCAB_SIZE - CODEBOOK_SIZE);
    }

    printf("\n\ncoarse sequence length: %zu\n\n", out.size());

    return out_coarse;
}


bool bark_generate_audio(
bark_model model,
const bark_vocab& vocab,
Expand Down Expand Up @@ -1334,92 +1435,8 @@ bool bark_generate_audio(
tokens, model.text_model, n_threads, temp, early_stop, min_eos_p);

// coarse encoding (coarse model)
std::vector<std::vector<bark_vocab::id>> out_coarse(N_COARSE_CODEBOOKS);
{
std::vector<bark_vocab::id> out;
float semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS;
int max_semantic_history = floorf(max_coarse_history / semantic_to_coarse_ratio);

int n_steps = floorf(out_semantic.size() * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS) * N_COARSE_CODEBOOKS;
int step_ix = 0;

BARK_ASSERT(n_steps > 0);
BARK_ASSERT(n_steps % N_COARSE_CODEBOOKS == 0);

int n_window_steps = ceilf(static_cast<float>(n_steps) / sliding_window_size);

std::vector<bark_vocab::id> input = out_semantic;
std::vector<float> logits;

// dry run to estimate mem_per_token
size_t mem_per_token = 0;
gpt_eval(model.coarse_model, n_threads, 0, false, { 0, 1, 2, 3 }, logits, mem_per_token);

for(int i = 0; i < n_window_steps; i++) {
int semantic_ix = roundf(n_steps / semantic_to_coarse_ratio);

std::vector<bark_vocab::id> input_in(input.begin() + std::max(semantic_ix-max_semantic_history, 0), input.end());
size_t original_size = input_in.size();
input_in.resize(256);

// padding from the right side
for (int ix = original_size; ix < 256; ix++)
input_in[ix] = COARSE_SEMANTIC_PAD_TOKEN;

input_in.push_back(COARSE_INFER_TOKEN);

// concatenate input_in and input_coarse
input_in.insert(
input_in.end(),
std::make_move_iterator(out.end() - std::min(max_coarse_history, (int) out.size())),
std::make_move_iterator(out.end())
);

int n_past = 0;
mem_per_token *= 1.1; // context length is growing, mem_per_token must grow as well

for (int j = 0; j < sliding_window_size; j++) {
if (step_ix >= n_steps)
continue;

gpt_eval(model.coarse_model, n_threads, n_past, false, input_in, logits, mem_per_token);

n_past += input_in.size();
input_in.clear();

bool is_major = step_ix % N_COARSE_CODEBOOKS == 0;
int start_ix = SEMANTIC_VOCAB_SIZE + (1 - is_major) * CODEBOOK_SIZE;
int end_ix = SEMANTIC_VOCAB_SIZE + (2 - is_major) * CODEBOOK_SIZE;
std::vector<float> relevant_logits(logits.begin() + start_ix, logits.begin() + end_ix);

bark_vocab::id sampled_id = gpt_sample(relevant_logits, temp, rng, NULL);
sampled_id += start_ix;

input_in.push_back(sampled_id);
out.push_back(sampled_id);

printf("%d ", sampled_id);
fflush(stdout);

step_ix += 1;
}
}

BARK_ASSERT((int) out.size() == n_steps);
BARK_ASSERT(out.size() % N_COARSE_CODEBOOKS == 0);

// gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
int mid = out.size() / 2;
for(int ix = 0; ix < (int) out.size(); ix++) {
if(ix < mid)
out_coarse[0].push_back(out[ix] - SEMANTIC_VOCAB_SIZE);
else
// !! this only works when N_COARSE_CODEBOOKS = 2
out_coarse[1].push_back(out[ix] - SEMANTIC_VOCAB_SIZE - CODEBOOK_SIZE);
}

printf("\n\ncoarse sequence length: %zu\n\n", out.size());
}
std::vector<std::vector<bark_vocab::id>> out_coarse = bark_forward_coarse_encoder(
out_semantic, model.coarse_model, n_threads, temp, early_stop, min_eos_p, max_coarse_history, sliding_window_size);

// fine encoding (fine model)
std::vector<std::vector<bark_vocab::id>> out_fine;
Expand Down
10 changes: 10 additions & 0 deletions bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,13 @@ std::vector<bark_vocab::id> bark_forward_text_encoder(
const float temp,
const bool early_stop,
const float min_eos_p);

// Runs the coarse GPT model over a sequence of semantic tokens and returns
// N_COARSE_CODEBOOKS vectors of coarse acoustic tokens (one per codebook).
// max_coarse_history bounds how many already-generated coarse tokens are fed
// back as context; sliding_window_size is the number of tokens generated per
// window. early_stop and min_eos_p mirror bark_forward_text_encoder's
// signature; presumably unused by the coarse pass -- confirm in bark.cpp.
std::vector<std::vector<bark_vocab::id>> bark_forward_coarse_encoder(
    const std::vector<bark_vocab::id> & tokens,
    const gpt_model model,
    const int n_threads,
    const float temp,
    const bool early_stop,
    const float min_eos_p,
    const int max_coarse_history,
    const int sliding_window_size);
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ endfunction()

bark_add_test(test-tokenizer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../ggml_weights/ggml_vocab.bin)
bark_add_test(test-text-encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../ggml_weights/ggml_weights_text.bin)
bark_add_test(test-coarse-encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../ggml_weights/ggml_weights_coarse.bin)
91 changes: 91 additions & 0 deletions tests/test-coarse-encoder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "bark.h"

#include <cstdio>
#include <string>
#include <map>
#include <random>
#include <vector>

// Returns the table of regression tests: semantic-token prompt -> expected
// per-codebook coarse tokens. Built once (static locals) and returned by
// reference. NOTE: the ans* vectors are empty placeholders, so any enabled
// entry currently fails the size comparison in main() -- fill them in with
// reference outputs before enabling more sequences.
static const std::map<std::vector<bark_vocab::id>, std::vector<std::vector<bark_vocab::id>>> & k_tests()
{
    static const std::vector<bark_vocab::id> seq1 = { 215, 1988, 3275, 1898, 1898, 1898, 9372, 9372, 222, 334, 8568, 8568, 7963, 222, 8568, 55, 7963, 1270, 55, 1283, 1283, 222, 1283, 1283, 1283, 55, 1283, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 231, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 5960, 340, 5960, 5960, 5960, 5960, 1374, 4193, 4193, 9323, 1374, 1374, 1374, 1374, 4193, 1374, 4193, 1374, 1374, 4193, 1374, 231, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 8328, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 9318, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374, 1374 };
    static const std::vector<bark_vocab::id> seq2 = { 59, 28, 28, 107, 7409, 1999, 7695, 6486, 6486, 5836, 5836, 5836, 873, 2585, 92, 92, 59, 28, 28, 107, 315, 5623, 1025, 10, 173, 125, 7385, 147, 147, 3689, 302, 9600, 6876, 6876, 321, 41, 164, 1367, 739, 41, 10, 140, 140, 6202, 6051, 6051, 4071, 9804, 8583, 677, 3, 17, 113, 9414, 5419, 5419, 3831, 3663, 3663, 3663, 2224, 2224, 2224, 73, 9144, 9144, 1667, 1997, 1957, 1093, 825, 175, 175, 1087, 736, 1233, 230, 147, 147, 230, 230, 230, 230, 230, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 1613, 528, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 1613, 2009, 2009 };
    static const std::vector<bark_vocab::id> seq3 = { 10, 10, 560, 10, 9602, 10, 10, 10, 302, 2363, 2919, 6860, 5127, 7134, 7134, 3934, 3934, 3352, 3352, 3507, 50, 10, 27, 27, 3320, 6107, 9891, 9891, 9891, 321, 41, 4287, 5667, 6152, 6152, 557, 1228, 12, 12, 200, 59, 28, 28, 28, 28, 1133, 9569, 5920, 1424, 1424, 51, 51, 682, 3820, 2107, 6059, 348, 210, 10, 10, 5, 2187, 7842, 988, 1728, 1728, 438, 366, 50, 27, 27, 181, 181, 7352, 9725, 4431, 6445, 2428, 41, 41, 41, 5119, 6557, 4212, 3963, 26, 26, 934, 1025, 1024, 173, 10, 41, 5467, 6684, 6684, 6684, 4958, 41, 298, 5982, 5982, 526, 3219, 122, 181, 10, 10, 884, 3446, 2599, 4478, 4478, 2549 };

    // Expected outputs: one vector per coarse codebook (placeholders for now).
    static const std::vector<std::vector<bark_vocab::id>> ans1 = { {}, {} };
    static const std::vector<std::vector<bark_vocab::id>> ans2 = { {}, {} };
    static const std::vector<std::vector<bark_vocab::id>> ans3 = { {}, {} };

    static std::map<std::vector<bark_vocab::id>, std::vector<std::vector<bark_vocab::id>>> _k_tests = {
        // { seq1, ans1 },   // hello world
        // { seq2, ans2 },   // this is an audio
        { seq3, ans3 },      // You cannot, sir, take from me anything
    };
    return _k_tests;
}

int main(int argc, char** argv) {
if (argc < 2) {
fprintf(stderr, "Usage: %s <model-file>\n", argv[0]);
return 1;
}

const std::string fname = argv[1];

gpt_model model;

const int n_threads = 4;
const float min_eos_p = 0.2;
const float temp = 0.7;

const int max_coarse_history = 630;
const int sliding_window_size = 60;

printf("%s: reading bark coarse model\n", __func__);
if(!gpt_model_load(fname, model)) {
fprintf(stderr, "%s: invalid model file '%s'\n", __func__, fname.c_str());
return 1;
}

for (const auto & test_kv : k_tests()) {
std::vector<std::vector<bark_vocab::id>> res = bark_forward_coarse_encoder(
test_kv.first, model, n_threads, temp, true, min_eos_p, max_coarse_history, sliding_window_size);

bool correct = res.size() == test_kv.second.size();

for (int i = 0; i < (int) res.size() && correct; ++i) {
correct = res[i].size() == test_kv.second[i].size();
for (int j = 0; j < (int) res[i].size() && correct; j++) {
if (res[i][j] != test_kv.second[i][j]) {
correct = false;
}
}
}

if (!correct) {
fprintf(stderr, "%s : failed test \n", __func__);
fprintf(stderr, "%s : expected tokens (n=%zu): ", __func__, test_kv.second.size());
for (int i = 0; i < (int) test_kv.second.size(); i++) {
for (int j = 0; j < (int) test_kv.second[i].size(); j++) {
fprintf(stderr, "%d ", test_kv.second[i][j]);
}
fprintf(stderr, "\n");
}
fprintf(stderr, "\n");
fprintf(stderr, "%s : got tokens (n=%zu): ", __func__, res.size());
for (int i = 0; i < (int) res.size(); i++) {
for (int j = 0; j < (int) res[i].size(); j++) {
fprintf(stderr, "%d ", res[i][j]);
}
fprintf(stderr, "\n");
}
fprintf(stderr, "\n");

return 3;
}
}

fprintf(stderr, "%s : tests passed successfully.\n", __func__);

return 0;
}
Loading