feat: OpenAI: JSON mode #14

Merged (6 commits) on Dec 15, 2023
Changes from 1 commit
feat: OpenAI: Json mode
Butch78 committed Nov 14, 2023
commit a232325ad51209e42ba67a6616083b402324dcb1
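As a quick usage sketch (not part of the diff itself), the JSON-mode option introduced in this commit could be used roughly like this; the builder methods and ResponseFormat type follow the code below, while the model name is taken from the new test and prompt construction and error handling are elided:

// `prompt` is assumed to be built with the crate's template! macro, as in the tests below.
let client = OpenAI::new()
    .with_model("gpt-3.5-turbo-1106")
    .with_response_format(Some(ResponseFormat { type_: "json".to_string() }));
// With a "json" response_format, generate() deserializes the reply into
// JsonModeResponse and returns the LLMResponse::OpenAIJson variant.
let response = client.generate(prompt).await?;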
16 changes: 16 additions & 0 deletions orca/src/llm/mod.rs
@@ -10,6 +10,8 @@ use candle_core::{Device, Result as CandleResult, Tensor};

use crate::prompt::Prompt;

use self::openai::JsonModeResponse;

/// Generate with context trait is used to execute an LLM using a context and a prompt template.
/// The context is a previously created context using the Context struct. The prompt template
/// is a previously created prompt template using the template! macro.
@@ -125,6 +127,9 @@ pub enum LLMResponse {
/// OpenAI response
OpenAI(openai::Response),

/// OpenAI json mode response
OpenAIJson(openai::JsonModeResponse),

/// Quantized model response
Quantized(String),

@@ -140,6 +145,13 @@ impl From<Response> for LLMResponse {
}
}

impl From<JsonModeResponse> for LLMResponse {
fn from(json_response: JsonModeResponse) -> Self {
// Convert JsonModeResponse to LLMResponse variant
LLMResponse::OpenAIJson(json_response)
}
}

impl EmbeddingResponse {
pub fn to_vec(&self) -> Result<Vec<f32>> {
match self {
@@ -198,6 +210,7 @@ impl LLMResponse {
pub fn to_role(&self) -> String {
match self {
LLMResponse::OpenAI(response) => response.to_string(),
LLMResponse::OpenAIJson(response) => response.to_string(),
LLMResponse::Quantized(_) => "ai".to_string(),
LLMResponse::Empty => panic!("empty response does not have a role"),
}
@@ -211,6 +224,9 @@ impl Display for LLMResponse {
LLMResponse::OpenAI(response) => {
write!(f, "{}", response)
}
LLMResponse::OpenAIJson(response) => {
write!(f, "{}", response)
}
LLMResponse::Quantized(response) => {
write!(f, "{}", response)
}
87 changes: 86 additions & 1 deletion orca/src/llm/openai.rs
@@ -5,6 +5,7 @@ use crate::{
prompt::{chat::Message, Prompt},
};
use anyhow::Result;
use pdf::content::Op;
use reqwest::Client;
use serde::{Deserialize, Serialize};

@@ -21,6 +22,8 @@ pub struct Payload {
stop: Option<Vec<String>>,
messages: Vec<Message>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<ResponseFormat>,
}

#[derive(Serialize, Deserialize, Debug)]
@@ -39,6 +42,17 @@ pub struct Response {
choices: Vec<Choice>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct JsonModeResponse {
id: String,
choices: Vec<Choice>,
created: i64,
model: String,
object: String,
system_fingerprint: String,
usage: Usage,
}

#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct OpenAIEmbeddingResponse {
object: String,
@@ -90,13 +104,29 @@ impl Display for Response {
}
}

impl Display for JsonModeResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut s = String::new();
for choice in &self.choices {
s.push_str(&choice.message.content);
}
write!(f, "{}", s)
}
}

#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct Usage {
prompt_tokens: i32,
completion_tokens: Option<i32>,
total_tokens: i32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
#[serde(rename = "type")]
pub type_: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
index: i32,
@@ -151,20 +181,25 @@ pub struct OpenAI {
///
/// The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.
max_tokens: u16,

/// The format of the returned data. With the new update, the response can be set to a JSON object.
/// https://platform.openai.com/docs/guides/text-generation/json-mode
response_format: Option<ResponseFormat>,

Nitpick:
Hmm, I am wondering if it would be better without Option...

Suggested change
response_format: Option<ResponseFormat>,
response_format: ResponseFormat,

Maybe it is possible to override the default zero value, so users don't have to set it and can at the same time read more about the default format.
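A minimal sketch of that suggestion, assuming a plain "text" format is a sensible zero value (the "text" string is an assumption, not something this PR defines):

impl Default for ResponseFormat {
    fn default() -> Self {
        // Assumed default: ordinary text output, so callers who never set
        // response_format keep the current behavior.
        Self { type_: "text".to_string() }
    }
}

With a Default impl in place, the struct field and the builder method could take ResponseFormat by value instead of wrapping it in Option.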

}

impl Default for OpenAI {
fn default() -> Self {
Self {
client: Client::new(),
url: OPENAI_COMPLETIONS_URL.to_string(),
api_key: std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
api_key: "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".to_string(),
model: "gpt-3.5-turbo".to_string(),
emedding_model: "text-embedding-ada-002".to_string(),
temperature: 1.0,
top_p: 1.0,
stream: false,
max_tokens: 1024u16,
response_format: None,
}
}
}
@@ -215,6 +250,11 @@ impl OpenAI {
self
}

pub fn with_response_format(mut self, response_format: Option<ResponseFormat>) -> Self {
self.response_format = response_format;
self
}

/// Generate a request for the OpenAI API and set the parameters
pub fn generate_request(&self, messages: &[Message]) -> Result<reqwest::Request> {
let payload = Payload {
@@ -225,13 +265,18 @@
stop: None,
messages: messages.to_vec(),
stream: self.stream,
response_format: self.response_format.clone(),
};

println!("Payload {:#?}", payload);
let req = self
.client
.post(&self.url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&payload)
.build()?;

println!("Payload {:#?}", req);
Ok(req)
}

@@ -260,7 +305,16 @@ impl LLM for OpenAI {
async fn generate(&self, prompt: Box<dyn Prompt>) -> Result<LLMResponse> {
let messages = prompt.to_chat()?;
let req = self.generate_request(messages.to_vec_ref())?;
println!("<<<<< {:#?}", req);
let res = self.client.execute(req).await?;

println!("<<<<< {:#?}", res);
if let Some(response_format) = &self.response_format {
if response_format.type_ == "json" {
let res = res.json::<JsonModeResponse>().await?;
return Ok(res.into());
}
}
let res = res.json::<Response>().await?;
Ok(res.into())
}
@@ -363,6 +417,35 @@ mod test {
let mut context = HashMap::new();
context.insert("country1", "France");
context.insert("country2", "Germany");
let prompt = template!(
"my template",
r#"
{{#chat}}
{{#user}}
What is the capital of {{country1}}?
{{/user}}
{{#assistant}}
Paris
{{/assistant}}
{{#user}}
What is the capital of {{country2}}?
{{/user}}
{{/chat}}
"#
);
let prompt = prompt.render_context("my template", &context).unwrap();
let response = client.generate(prompt).await.unwrap();
assert!(response.to_string().to_lowercase().contains("berlin"));
}

#[tokio::test]
async fn test_generate_json_mode() {
let client = OpenAI::new().with_model("gpt-3.5-turbo-1106").with_response_format(Some(ResponseFormat {
type_: "json".to_string(),
}));
let mut context = HashMap::new();
context.insert("country1", "France");
context.insert("country2", "Germany");
let prompt = template!(
"my template",
r#"
@@ -382,6 +465,8 @@
let prompt = prompt.render_context("my template", &context).unwrap();
let response = client.generate(prompt).await.unwrap();
assert!(response.to_string().to_lowercase().contains("berlin"));
// Assert response is a JSON object
assert!(response.to_string().starts_with("{"));

I like how concrete this "{" is. It is easy to read.

There is a comment above; is something not obvious from the test name? I would consider an assert instead of a comment.
Something like

assert!(response.is_valid_json())

I would keep the first assert too: assert!(response.to_string().starts_with("{"));

[] is also valid JSON. May OpenAI return a list of dictionaries in JSON mode? If so, the comment is not accurate and the test is better off without it.
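A minimal sketch of that kind of assertion, assuming serde_json is available as a dev-dependency (is_valid_json is not an existing method on the response type, so this parses the rendered string instead):

// Any syntactically valid JSON parses into serde_json::Value.
let parsed: serde_json::Value =
    serde_json::from_str(&response.to_string()).expect("response should be valid JSON");
// JSON mode is expected to return an object rather than an array.
assert!(parsed.is_object());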

Second thought: I remember Uncle Bob's suggestion that source code should be general and tests should be concrete.
Maybe assert_eq!(response.to_string(), "{ \"country\": \"berlin\" }")

}

#[tokio::test]