This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

feat: replace $PROMPT in prompt files in non-REPL #99

Merged
merged 2 commits on Apr 2, 2023
feat: replace $PROMPT in prompt files in non-REPL
philpax committed Apr 2, 2023
commit 5e166854b0475402e52e9aa55a673dcd14b3412d
11 changes: 8 additions & 3 deletions llama-cli/src/cli_args.rs
@@ -15,7 +15,10 @@ pub struct Args {
#[arg(long, short = 'p', default_value = None)]
pub prompt: Option<String>,

/// A file to read the prompt from. Takes precedence over `prompt` if set.
/// A file to read the prompt from.
///
/// If used with `--prompt`/`-p`, the prompt from the file will be used
/// and `$PROMPT` will be replaced with the value of `--prompt`/`-p`.
#[arg(long, short = 'f', default_value = None)]
pub prompt_file: Option<String>,

@@ -111,8 +114,10 @@ pub struct Args {
#[arg(long, default_value_t = false)]
pub ignore_eos: bool,

/// Dumps the prompt to console and exits, first as a comma seperated list of token IDs
/// and then as a list of comma seperated string keys and token ID values.
/// Dumps the prompt to console and exits, first as a comma-separated list of token IDs
/// and then as a list of comma-separated string keys and token ID values.
///
/// This will only work in non-`--repl` mode.
#[arg(long, default_value_t = false)]
pub dump_prompt_tokens: bool,
}
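For context, here is a minimal, self-contained sketch (not part of this diff; the file contents and prompt value are hypothetical) of the substitution behaviour that the new `prompt_file` doc comment describes: the file's contents act as a template, and `$PROMPT` is replaced with the value supplied via `--prompt`/`-p`.

```rust
// Hypothetical illustration of the template substitution described above.
fn main() {
    // Contents a user might place in a file passed with `-f`.
    let raw_prompt = "### Instruction:\n$PROMPT\n### Response:\n";
    // Value a user might pass with `-p`.
    let prompt = "Write a haiku about Rust.";

    // Same replacement rule the PR applies: every `$PROMPT` is substituted.
    let combined = raw_prompt.replace("$PROMPT", prompt);
    assert_eq!(
        combined,
        "### Instruction:\nWrite a haiku about Rust.\n### Response:\n"
    );
    println!("{combined}");
}
```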
27 changes: 18 additions & 9 deletions llama-cli/src/main.rs
@@ -11,7 +11,7 @@ use rustyline::error::ReadlineError;
mod cli_args;

fn repl_mode(
prompt: &str,
raw_prompt: &str,
model: &llama_rs::Model,
vocab: &llama_rs::Vocabulary,
params: &InferenceParameters,
@@ -22,7 +22,7 @@ fn repl_mode(
let readline = rl.readline(">> ");
match readline {
Ok(line) => {
let prompt = prompt.replace("$PROMPT", &line);
let prompt = process_prompt(raw_prompt, &line);
let mut rng = thread_rng();

let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string());
@@ -126,7 +126,7 @@ fn main() {
}
};

let prompt = if let Some(path) = &args.prompt_file {
let raw_prompt = if let Some(path) = &args.prompt_file {
match std::fs::read_to_string(path) {
Ok(mut prompt) => {
// Strip off the last character if it's exactly newline. Also strip off a single
@@ -207,11 +207,6 @@ fn main() {

log::info!("Model fully loaded!");

if args.dump_prompt_tokens {
dump_tokens(&prompt, &vocab).ok();
return;
}

let mut rng = if let Some(seed) = CLI_ARGS.seed {
rand::rngs::StdRng::seed_from_u64(seed)
} else {
@@ -241,8 +236,18 @@ fn main() {
};

if args.repl {
repl_mode(&prompt, &model, &vocab, &inference_params, session);
repl_mode(&raw_prompt, &model, &vocab, &inference_params, session);
} else {
let prompt = match (&args.prompt_file, &args.prompt) {
(Some(_), Some(prompt)) => process_prompt(&raw_prompt, prompt),
_ => raw_prompt,
};

if args.dump_prompt_tokens {
dump_tokens(&prompt, &vocab).ok();
return;
}

let inference_params = if session_loaded {
InferenceParameters {
play_back_previous_tokens: true,
@@ -328,3 +333,7 @@ mod snapshot {
snap.write(&mut writer)
}
}

fn process_prompt(raw_prompt: &str, prompt: &str) -> String {
raw_prompt.replace("$PROMPT", prompt)
}
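
To summarise the non-REPL flow this commit introduces, below is a condensed, standalone sketch (simplified types and a hypothetical `resolve_prompt` helper; the real code inlines this as the `match` in `main`): the prompt file is treated as a template only when `--prompt` is also given, and the prompt-token dump now runs after that resolution, so it reflects the final prompt (and, as the updated `dump_prompt_tokens` doc comment notes, only applies outside `--repl` mode).

```rust
// Standalone sketch of the non-REPL resolution order; `resolve_prompt` is a
// hypothetical name, not part of the PR.
fn resolve_prompt(
    prompt_file: Option<&str>,
    prompt_flag: Option<&str>,
    raw_prompt: String,
) -> String {
    match (prompt_file, prompt_flag) {
        // Both a prompt file and `--prompt`: treat the file as a template.
        (Some(_), Some(p)) => raw_prompt.replace("$PROMPT", p),
        // Otherwise, use whatever prompt text was loaded, unchanged.
        _ => raw_prompt,
    }
}

fn main() {
    // File plus flag: `$PROMPT` in the file is replaced.
    let resolved = resolve_prompt(
        Some("alpaca.txt"),
        Some("What is the capital of France?"),
        "### Instruction:\n$PROMPT\n### Response:\n".to_string(),
    );
    assert!(resolved.contains("capital of France"));

    // Flag only: the prompt passes through untouched.
    let resolved = resolve_prompt(None, Some("hello"), "hello".to_string());
    assert_eq!(resolved, "hello");
}
```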