Skip to content

Commit

Permalink
FIX Text encoder (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Jul 14, 2023
1 parent 68964f5 commit daf11f6
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,12 +510,7 @@ bool gpt_eval(
// memory_k and memory_v)
embd = ggml_get_rows(ctx0, model.wtes[0], input);
} else {
// first step

// tok_emb = torch.cat([
// self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
// self.transformer.wte(idx[:,256+256:])
// ], dim=1)
// first step (context merging)
struct ggml_tensor * seq_embd = ggml_get_rows(ctx0, model.wtes[0], ggml_view_1d(ctx0, input, 256, 0));
struct ggml_tensor * ctx_embd = ggml_get_rows(ctx0, model.wtes[0], ggml_view_1d(ctx0, input, 256, 256*ggml_element_size(input)));
struct ggml_tensor * rem_embd = ggml_get_rows(ctx0, model.wtes[0], ggml_view_1d(ctx0, input, 1, 512*ggml_element_size(input)));
Expand Down Expand Up @@ -769,6 +764,18 @@ bool gpt_eval(
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);

// if (toy) {
// for (int i = 0; i < toy->ne[1]; i++) {
// for (int j = 0; j < toy->ne[0]; j++) {
// float v = *(float *) ((char *)toy->data + i*toy->nb[1] + j*toy->nb[0]);
// printf("%.4f ", v);
// }
// printf("\n\n");
// }

// printf("dim: [%d, %d]\n", toy->ne[0], toy->ne[1]);
// }

// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
Expand Down Expand Up @@ -847,7 +854,10 @@ bool bark_generate_audio(
const float temp = 0.7;

const int early_stop = true;
const float min_eos_p = 0.2;

// in the original implementation, min_eos_p=0.2, yet for bark.cpp this seems too
// high and this generates overly long sequence.
const float min_eos_p = 0.15;

std::mt19937 rng(seed);

Expand Down Expand Up @@ -921,14 +931,15 @@ bool bark_generate_audio(
input.push_back(sampled_id);
output.push_back(sampled_id);

printf("%d ", sampled_id);
fflush(stdout);

if (early_stop && ((sampled_id == SEMANTIC_VOCAB_SIZE) || (eos_p > min_eos_p)))
break;
}

for(int i = 0; i < output.size(); i++) {
assert((output[i] > 0) && (output[i] < SEMANTIC_VOCAB_SIZE));
printf("%d ", output[i]);
}
printf("\n\ntext semantic sequence length: %d\n", output.size());

}

// generate audio (encodec)
Expand Down

0 comments on commit daf11f6

Please sign in to comment.