Skip to content

Commit

Permalink
FIX/API Proper constructor for bark_progress (PABannier#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Sep 9, 2023
1 parent dbe6aac commit badd010
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ struct bark_context {

struct bark_progress {
float current = 0.0f;
const char * func = NULL;
const char * func;

bark_progress() {}
bark_progress(const char * func): func(func) {}

void callback(float progress) {
float percentage = progress * 100;
if (percentage == 0.0f && func != NULL) {
if (percentage == 0.0f) {
fprintf(stderr, "%s: ", func);
}
while (percentage > current) {
Expand Down Expand Up @@ -1707,8 +1707,7 @@ void bark_forward_text_encoder(struct bark_context * ctx, int n_threads) {

bark_sequence out;

bark_progress progress;
progress.func = __func__;
bark_progress progress( __func__);

gpt_model * model = &ctx->model.text_model;

Expand Down Expand Up @@ -1768,8 +1767,7 @@ void bark_forward_coarse_encoder(struct bark_context * ctx, int n_threads) {
bark_codes out_coarse;
bark_sequence out;

bark_progress progress;
progress.func = __func__;
bark_progress progress(__func__);

int max_coarse_history = ctx->max_coarse_history;
int sliding_window_size = ctx->sliding_window_size;
Expand Down Expand Up @@ -1880,9 +1878,11 @@ void bark_forward_coarse_encoder(struct bark_context * ctx, int n_threads) {
}

void bark_forward_fine_encoder(struct bark_context * ctx, int n_threads) {
// input shape: (N, n_codes)
// input shape: [N, n_codes]
const int64_t t_main_start_us = ggml_time_us();

bark_progress progress(__func__);

bark_codes input = ctx->coarse_tokens;

float temp = ctx->fine_temp;
Expand All @@ -1892,9 +1892,6 @@ void bark_forward_fine_encoder(struct bark_context * ctx, int n_threads) {

gpt_model * model = &ctx->model.fine_model;

bark_progress progress;
progress.func = __func__;

int n_coarse = input[0].size();
int original_seq_len = input.size();
int n_remove_from_end = 0;
Expand Down

0 comments on commit badd010

Please sign in to comment.