Skip to content

Commit

Permalink
DOC/EX Expose params to user for the main script (PABannier#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Aug 12, 2023
1 parent 00ff99b commit 7fcd8be
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 9 deletions.
36 changes: 36 additions & 0 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1703,3 +1703,39 @@ bool bark_generate_audio(

return true;
}

bool bark_params_parse(int argc, char ** argv, bark_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];

if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--prompt") {
params.prompt = argv[++i];
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-h" || arg == "--help") {
bark_print_usage(argv, params);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
bark_print_usage(argv, params);
exit(0);
}
}

return true;
}

void bark_print_usage(char ** argv, const bark_params & params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: random)\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");
}
17 changes: 15 additions & 2 deletions bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <map>
#include <random>
#include <thread>
#include <vector>

#define SAMPLE_RATE 24000
Expand All @@ -26,6 +27,14 @@
#define COARSE_SEMANTIC_PAD_TOKEN 12048
#define COARSE_INFER_TOKEN 12050

struct bark_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

std::string model = "./ggml_weights/"; // weights location

std::string prompt; // user prompt
};

struct gpt_hparams {
int32_t n_in_vocab;
int32_t n_out_vocab;
Expand Down Expand Up @@ -150,8 +159,8 @@ bool bark_model_load(const std::string & dirname, bark_model & model);
bool bark_vocab_load(const std::string & fname, bark_vocab& vocab, int32_t expected_size);

void bert_tokenize(
const bark_vocab & vocab,
const char * text,
const bark_vocab & vocab,
const char * text,
int32_t * tokens,
int32_t * n_tokens,
int32_t n_max_tokens);
Expand Down Expand Up @@ -212,3 +221,7 @@ struct bark_progress {
}
}
};

bool bark_params_parse(int argc, char ** argv, bark_params & params);

void bark_print_usage(char ** argv, const bark_params & params);
26 changes: 19 additions & 7 deletions examples/main.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
#include "ggml.h"
#include "bark.h"

int main() {

int main(int argc, char **argv) {
const int64_t t_main_start_us = ggml_time_us();

bark_params params;

if (bark_params_parse(argc, argv, params) == false) {
return 1;
}

int64_t t_load_us = 0;
int64_t t_eval_us = 0;

bark_model model;
std::string fname = "./ggml_weights";

if (!params.model.empty()) {
fname = params.model;
}

// load the model
{
const int64_t t_start_us = ggml_time_us();
Expand All @@ -24,14 +35,15 @@ int main() {

printf("\n");

// forward pass
const std::string prompt = "hi! i'm john and i'm a software engineer.";
{
const int64_t t_eval_us_start = ggml_time_us();
bark_generate_audio(model, model.vocab, prompt.data(), 4);
t_eval_us = ggml_time_us() - t_eval_us_start;
std::string prompt = "this is an audio";
if (!params.prompt.empty()) {
prompt = params.prompt;
}

const int64_t t_eval_us_start = ggml_time_us();
bark_generate_audio(model, model.vocab, prompt.data(), params.n_threads);
t_eval_us = ggml_time_us() - t_eval_us_start;

// report timing
{
const int64_t t_main_end_us = ggml_time_us();
Expand Down

0 comments on commit 7fcd8be

Please sign in to comment.