Skip to content

Commit

Permalink
FIX Transpose coarse tokens (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Aug 10, 2023
1 parent 4b20297 commit f5e60e2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ toy

*.o
*.plist
*.wav

.DS_Store

Expand Down
18 changes: 9 additions & 9 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ bark_codes bark_forward_coarse_encoder(
const float temp,
const int max_coarse_history,
const int sliding_window_size) {
bark_codes out_coarse(N_COARSE_CODEBOOKS);
bark_codes out_coarse;
bark_sequence out;

bark_progress progress;
Expand Down Expand Up @@ -1419,16 +1419,16 @@ bark_codes bark_forward_coarse_encoder(
BARK_ASSERT((int) out.size() == n_steps);
BARK_ASSERT(out.size() % N_COARSE_CODEBOOKS == 0);

// out_coarse: [n_codes, seq_length]
for (int i = 0; i < (int) out.size(); i++) {
if (i % 2 == 0)
out_coarse[0].push_back(out[i] - SEMANTIC_VOCAB_SIZE);
else
out_coarse[1].push_back(out[i] - SEMANTIC_VOCAB_SIZE - CODEBOOK_SIZE);
// out_coarse: [seq_length, n_codes]
for (int i = 0; i < (int) out.size(); i += N_COARSE_CODEBOOKS) {
// this assumes N_COARSE_CODEBOOKS = 2
bark_sequence _tmp = {
out[i] - SEMANTIC_VOCAB_SIZE,
out[i+1] - SEMANTIC_VOCAB_SIZE - CODEBOOK_SIZE
};
out_coarse.push_back(_tmp);
}

// TODO: transpose out_coarse

const int64_t t_main_end_us = ggml_time_us();

printf("\n\n");
Expand Down
3 changes: 2 additions & 1 deletion tests/test-forward-coarse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ int main(int argc, char** argv) {
std::string path = test_data[i];

load_test_data(path, input, truth);
bark_codes truth_t = transpose(truth);

bark_codes output = bark_forward_coarse_encoder(
input, model, rng, n_threads, 0.0f, max_coarse_history, sliding_window_size);

fprintf(stderr, "%s", path.c_str());
if (!run_test_on_codes(truth, output)) {
if (!run_test_on_codes(truth_t, output)) {
success = false;
fprintf(stderr, " TEST %d FAILED.\n", i+1);
} else {
Expand Down

0 comments on commit f5e60e2

Please sign in to comment.