Skip to content

Commit

Permalink
feat: add "stop" keywords as alternative to eot token
Browse files Browse the repository at this point in the history
  • Loading branch information
Claude Doppler committed Apr 8, 2023
1 parent 62cfc54 commit 9fd062f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 5 deletions.
12 changes: 10 additions & 2 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.antiprompt.push_back(argv[i]);
} else if (arg == "--stop") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.stop_keywords.push_back(argv[i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {
Expand Down Expand Up @@ -209,8 +215,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT");
fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n");
fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n");
fprintf(stderr, " (can be specified more than once for multiple keywords).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for <= 0)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
Expand Down
1 change: 1 addition & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct gpt_params {


std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keywords; // string upon seeing which the model will stop

bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
Expand Down
46 changes: 43 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}

if (params.stop_keywords.size()) {
for (auto stop_keyword : params.stop_keywords) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}

fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
Expand Down Expand Up @@ -344,13 +351,28 @@ int main(int argc, char ** argv) {
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {

// check for reverse prompt
if (params.antiprompt.size()) {
std::string last_output;
std::string last_output;
if (params.antiprompt.size() || params.stop_keywords.size()) {
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
}

// Check for stop keywords, a configurable alternative to the end-of-text token
// This should stop also the interactive mode, useful to stop interactive mode without SIGTERM
bool stop = false;
for (std::string stop_keyword : params.stop_keywords) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if (stop) {
break;
}

// check for reverse prompt
if (params.antiprompt.size()) {
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
Expand Down Expand Up @@ -430,6 +452,24 @@ int main(int argc, char ** argv) {
}
}

// Check for stop keywords, a configurable alternative to the end-of-text token
if (!params.interactive && params.stop_keywords.size() && !is_interacting) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
bool stop = false;
for (std::string stop_keyword : params.stop_keywords) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if (stop) {
break;
}
}

// end of text token
if (!embd.empty() && embd.back() == llama_token_eos()) {
if (params.instruct) {
Expand Down

0 comments on commit 9fd062f

Please sign in to comment.