diff --git a/bark.cpp b/bark.cpp index 2171553..58c7813 100644 --- a/bark.cpp +++ b/bark.cpp @@ -1070,7 +1070,7 @@ static bool bark_load_model_from_file( return true; } -struct bark_context* bark_load_model(const std::string& model_path, bark_verbosity_level verbosity) { +struct bark_context* bark_load_model(const std::string& model_path, bark_verbosity_level verbosity, uint32_t seed) { int64_t t_load_start_us = ggml_time_us(); struct bark_context* bctx = new bark_context(); @@ -1083,7 +1083,7 @@ struct bark_context* bark_load_model(const std::string& model_path, bark_verbosi bark_context_params params = bark_context_default_params(); params.verbosity = verbosity; - bctx->rng = std::mt19937(params.seed); + bctx->rng = std::mt19937(seed); bctx->params = params; bctx->t_load_us = ggml_time_us() - t_load_start_us; @@ -2136,7 +2136,6 @@ void bark_free(struct bark_context* bctx) { struct bark_context_params bark_context_default_params() { struct bark_context_params result = { - /*.seed =*/0, /*.verbosity =*/bark_verbosity_level::LOW, /*.temp =*/0.7, /*.fine_temp =*/0.5, diff --git a/bark.h b/bark.h index 0216372..a3722c9 100644 --- a/bark.h +++ b/bark.h @@ -113,8 +113,6 @@ struct bark_model { }; struct bark_context_params { - // RNG seed - uint32_t seed; // Verbosity level bark_verbosity_level verbosity; @@ -217,11 +215,13 @@ struct bark_context_params bark_context_default_params(void); * * @param model_path The directory path of the bark model to load. * @param verbosity The verbosity level when loading the model. + * @param seed The seed to use for random number generation. * @return A pointer to the loaded bark model context. */ struct bark_context *bark_load_model( const std::string &model_path, - bark_verbosity_level verbosity); + bark_verbosity_level verbosity, + uint32_t seed); /** * Generates an audio file from the given text using the specified Bark context. diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e58f7d8..58fb4d1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -31,7 +31,7 @@ int main(int argc, char **argv) { << "\n"; // initialize bark context - struct bark_context *bctx = bark_load_model(params.model_path, verbosity); + struct bark_context *bctx = bark_load_model(params.model_path, verbosity, params.seed); if (!bctx) { fprintf(stderr, "%s: Could not load model\n", __func__); exit(1); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 4d9c3fd..af726d8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -93,7 +93,7 @@ int main(int argc, char **argv) { bark_params_parse(argc, argv, params); - struct bark_context *bctx = bark_load_model(params.model_path.c_str(), bark_verbosity_level::LOW); + struct bark_context *bctx = bark_load_model(params.model_path.c_str(), bark_verbosity_level::LOW, params.seed); if (!bctx) { fprintf(stderr, "%s: Could not load model\n", __func__); return 1; diff --git a/examples/wasm/emscripten.cpp b/examples/wasm/emscripten.cpp index 46b2cb0..2a53084 100644 --- a/examples/wasm/emscripten.cpp +++ b/examples/wasm/emscripten.cpp @@ -18,7 +18,7 @@ EMSCRIPTEN_BINDINGS(bark) { for (size_t i = 0; i < g_contexts.size(); i++) { if (g_contexts[i] == nullptr) { - g_contexts[i] = bark_load_model(path_model.c_str(), bark_verbosity_level::LOW); + g_contexts[i] = bark_load_model(path_model.c_str(), bark_verbosity_level::LOW, 0); if (g_contexts[i] != nullptr) { return i + 1; } else {