Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: seed is passed to bark_context #161

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,7 @@ static bool bark_load_model_from_file(
return true;
}

struct bark_context* bark_load_model(const std::string& model_path, bark_verbosity_level verbosity) {
struct bark_context* bark_load_model(const std::string& model_path, bark_verbosity_level verbosity, uint32_t seed) {
int64_t t_load_start_us = ggml_time_us();

struct bark_context* bctx = new bark_context();
Expand All @@ -1083,7 +1083,7 @@ struct bark_context* bark_load_model(const std::string& model_path, bark_verbosi

bark_context_params params = bark_context_default_params();
params.verbosity = verbosity;
bctx->rng = std::mt19937(params.seed);
bctx->rng = std::mt19937(seed);
bctx->params = params;
bctx->t_load_us = ggml_time_us() - t_load_start_us;

Expand Down Expand Up @@ -2136,7 +2136,6 @@ void bark_free(struct bark_context* bctx) {

struct bark_context_params bark_context_default_params() {
struct bark_context_params result = {
/*.seed =*/0,
/*.verbosity =*/bark_verbosity_level::LOW,
/*.temp =*/0.7,
/*.fine_temp =*/0.5,
Expand Down
6 changes: 3 additions & 3 deletions bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ struct bark_model {
};

struct bark_context_params {
// RNG seed
uint32_t seed;
// Verbosity level
bark_verbosity_level verbosity;

Expand Down Expand Up @@ -217,11 +215,13 @@ struct bark_context_params bark_context_default_params(void);
*
* @param model_path The directory path of the bark model to load.
* @param verbosity The verbosity level when loading the model.
* @param seed The seed to use for random number generation.
* @return A pointer to the loaded bark model context.
*/
struct bark_context *bark_load_model(
const std::string &model_path,
bark_verbosity_level verbosity);
bark_verbosity_level verbosity,
uint32_t seed);

/**
* Generates an audio file from the given text using the specified Bark context.
Expand Down
2 changes: 1 addition & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ int main(int argc, char **argv) {
<< "\n";

// initialize bark context
struct bark_context *bctx = bark_load_model(params.model_path, verbosity);
struct bark_context *bctx = bark_load_model(params.model_path, verbosity, params.seed);
if (!bctx) {
fprintf(stderr, "%s: Could not load model\n", __func__);
exit(1);
Expand Down
2 changes: 1 addition & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ int main(int argc, char **argv) {

bark_params_parse(argc, argv, params);

struct bark_context *bctx = bark_load_model(params.model_path.c_str(), bark_verbosity_level::LOW);
struct bark_context *bctx = bark_load_model(params.model_path.c_str(), bark_verbosity_level::LOW, params.seed);
if (!bctx) {
fprintf(stderr, "%s: Could not load model\n", __func__);
return 1;
Expand Down
2 changes: 1 addition & 1 deletion examples/wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ EMSCRIPTEN_BINDINGS(bark) {

for (size_t i = 0; i < g_contexts.size(); i++) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = bark_load_model(path_model.c_str(), bark_verbosity_level::LOW);
g_contexts[i] = bark_load_model(path_model.c_str(), bark_verbosity_level::LOW, 0);
if (g_contexts[i] != nullptr) {
return i + 1;
} else {
Expand Down
Loading