Skip to content

Commit

Permalink
feat: 添加兼容支持 accesstoken 方式
Browse files Browse the repository at this point in the history
  • Loading branch information
huangzonggui committed Mar 29, 2023
1 parent 8f86a95 commit d340fd6
Show file tree
Hide file tree
Showing 12 changed files with 419 additions and 28 deletions.
1 change: 1 addition & 0 deletions src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ rust-version = "1.59"
tauri-build = { version = "1.2.1", features = [] }

[dependencies]
uuid = "1.3.0"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
tauri = { version = "1.2.4", features = ["dialog-all", "fs-all", "http-all", "os-all", "path-all", "process-exit", "process-relaunch", "shell-open", "updater", "window-close", "window-hide", "window-show", "window-start-dragging"] }
Expand Down
198 changes: 198 additions & 0 deletions src-tauri/src/app/cmd/gpt_access_token.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use eventsource_stream::{EventStreamError, Eventsource};
use futures::TryStreamExt;
use log::{error, info};
use reqwest;
use serde::{ser::Serializer, Deserialize, Serialize};
use serde_json::{json, Value};
use std::{env::consts::OS, time::Duration};
use tauri::{AppHandle, Manager};
use uuid::Uuid;

/// Module-local alias: every fallible function here returns this module's `Error`.
type Result<T> = std::result::Result<T, Error>;

/// Unified error type for the access-token chat command.
///
/// The `#[from]` conversions let `?` propagate I/O, HTTP, JSON and
/// SSE-stream failures directly; `Custom` carries the HTTP status code
/// and response body for non-200 replies.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// OS-level I/O failure.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// HTTP transport failure from `reqwest`.
    #[error(transparent)]
    Request(#[from] reqwest::Error),
    /// JSON (de)serialization failure.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// Server-sent-events stream decoding failure.
    #[error(transparent)]
    Stream(#[from] EventStreamError<reqwest::Error>),
    /// Non-200 HTTP response: `code` is the status, `msg` the body text.
    #[error("Custom Error: (code: {code:?}, message: {msg:?})")]
    Custom { code: u16, msg: String },
}

/// Serialize the error as its `Display` string so it can cross the
/// Tauri IPC boundary to the frontend as plain text.
impl Serialize for Error {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let message = self.to_string();
        serializer.serialize_str(&message)
    }
}

/// Per-chunk streaming progress event sent to the frontend.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ProgressPayload {
    /// Request id echoed back so the frontend can match events to calls.
    pub id: u64,
    /// Message text extracted from the current stream chunk.
    pub detail: String,
    /// Author role reported by the server (e.g. "assistant" — per stream data).
    pub role: String,
    /// Finish type from the chunk's metadata; empty until the stream ends.
    pub finish_reason: String,
    /// Conversation id reported by the server, if present in the chunk.
    pub conversation_id: Option<String>,
    /// Id of the message this reply is parented to.
    pub parent_message_id: String,
}

impl ProgressPayload {
    /// Broadcast this payload to all windows on the `CHAT_FETCHEING_PROGRESS`
    /// event. The event name is kept exactly as-is (typo included) because the
    /// frontend listens for this precise string. Emit failures are ignored.
    pub fn emit_progress(&self, handle: &AppHandle) {
        let _ = handle.emit_all("CHAT_FETCHEING_PROGRESS", self);
    }
}

/// A single chat turn as supplied by the frontend.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    /// Speaker role — presumably "user"/"assistant"/"system"; confirm against the frontend.
    pub role: String,
    /// Plain-text message content.
    pub content: String,
}

/// Request options for the access-token flow. Fields are camelCase
/// (hence `allow(non_snake_case)`) — presumably to mirror the JS invoke
/// payload keys; confirm against the frontend caller.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[allow(non_snake_case)]
pub struct FetchOption {
    /// Optional proxy URL; `None`/empty means a direct connection.
    pub proxy: Option<String>,
    /// Full URL of the conversation endpoint to POST to.
    pub host: String,
    /// Access token, sent as a `Bearer` Authorization header.
    pub apiKey: String,
    /// Model name forwarded in the request body.
    pub model: String,
    /// Sampling temperature — NOTE(review): not included in this command's
    /// request body; verify whether it should be.
    pub temperature: f32,

    /// Existing conversation to continue, if any.
    pub conversationId: Option<String>,
    /// Id of the message being replied to; a UUID is generated when absent.
    pub parentMessageId: Option<String>,
    /// Id for the outgoing message; a UUID is generated when absent.
    pub messageId: Option<String>,
    /// Backend action verb; defaults to "next".
    pub action: Option<String>,
    /// NOTE(review): currently unused — a fixed 600 s timeout is applied instead.
    pub timeoutMs: Option<String>,
    // pub onProgress?: (partialResponse: ChatMessage) => void
    // pub abortSignal?: AbortSignal
}

#[tauri::command]
pub async fn fetch_chat_api_by_access_token(
handle: AppHandle,
id: u64,
messages: Vec<Message>,
option: FetchOption,
) -> Result<u64> {
// https://platform.openai.com/docs/guides/chat/introduction
// let url = "https://api.openai.com/v1/chat/completions";
log::info!(
"> send message: length: {}, option: {:?}",
messages.len(),
option,
);

let _messages_id = option.messageId.unwrap_or(Uuid::new_v4().to_string());
let _parent_message_id = option.parentMessageId.unwrap_or(Uuid::new_v4().to_string());
let conversation_id = option.conversationId;
let action = option.action.unwrap_or("next".to_string());

let last_message = messages.last().unwrap();

let mut body = json!({
"action": action,
"messages": [{
"id": _messages_id,
"role": "user".to_string(),
"content": {
"content_type": "text".to_string(),
"parts": [last_message.content]
}
}],
"model": option.model,
"parent_message_id": _parent_message_id,
});

info!("> conversation_id: {:?}", conversation_id);
if let Some(conversation_id) = conversation_id {
body["conversation_id"] = conversation_id.into();
}
log::info!("> send message: body {}", body);

let proxy_str = option.proxy.unwrap_or(String::from(""));

let client: reqwest::Client = {
log::info!("proxy is: {}", proxy_str);
let mut client_builder = reqwest::Client::builder();
if proxy_str.len() > 0 {
let proxy = reqwest::Proxy::all(proxy_str).unwrap();
client_builder = client_builder.proxy(proxy);
}
client_builder.build().unwrap()
};
info!("> body body: {}", body);

let res = client
.post(option.host)
.header("Accept", "text/event-stream")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", option.apiKey))
.header(
reqwest::header::USER_AGENT,
format!("ChatGPT-Tauri ({})", OS),
)
.timeout(Duration::from_secs(600))
.body(body.to_string())
.send()
.await?;
info!("> receive message: {}", id);

let status_code = res.status().as_u16();
info!("> receive message status code: {}", status_code);
if status_code != 200 {
let error_msg = res.text().await?;
log::error!("{}", error_msg);
return Err(Error::Custom {
code: status_code,
msg: String::from(error_msg),
});
}

let mut stream = res.bytes_stream().eventsource();
while let Some(chunk) = stream.try_next().await? {
let chunk = chunk.data;
if chunk == "[DONE]" {
return Ok(id);
} else {
match serde_json::from_str::<Value>(&chunk) {
Ok(object) => {
// info!("> object: {:?}", object);
let _message = &object["message"];
let _conversation_id =
String::from(object["conversation_id"].as_str().unwrap_or("")); // 从 JSON 对象获取 conversationId
let content =
String::from(_message["content"]["parts"][0].as_str().unwrap_or(""));
let role = String::from(_message["author"]["role"].as_str().unwrap_or(""));
let finish_reason = String::from(
_message["metadata"]["finish_details"]["type"]
.as_str()
.unwrap_or(""),
);
let progress = ProgressPayload {
id,
detail: content,
role,
finish_reason,
conversation_id: Some(_conversation_id),
parent_message_id: _messages_id.clone(),
};
// info!("> progress: {:?}", progress);
progress.emit_progress(&handle);
}
Err(err) => {
// 处理 JSON 转换错误
info!("Failed to parse JSON object: {:?}", err); // 中途会打印一个时间戳,导致无法转换为 JSON
continue; // 跳过当前循环,继续下一个循环
}
}
}
}

Ok(id)
}
129 changes: 129 additions & 0 deletions src-tauri/src/app/cmd/gpt_api_key.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use tauri::{AppHandle, Manager};
use reqwest;
use eventsource_stream::{Eventsource, EventStreamError};
use serde_json::{json, Value};
use serde::{ser::Serializer, Serialize, Deserialize};
use futures::{TryStreamExt};
use std::{ time::Duration, env::consts::OS };
use log::{error, info};

/// Module-local alias: every fallible function here returns this module's `Error`.
type Result<T> = std::result::Result<T, Error>;

/// Unified error type for the API-key chat command.
///
/// The `#[from]` conversions let `?` propagate I/O, HTTP, JSON and
/// SSE-stream failures directly; `Custom` carries the HTTP status code
/// and response body for non-200 replies.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// OS-level I/O failure.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// HTTP transport failure from `reqwest`.
    #[error(transparent)]
    Request(#[from] reqwest::Error),
    /// JSON (de)serialization failure.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// Server-sent-events stream decoding failure.
    #[error(transparent)]
    Stream(#[from] EventStreamError<reqwest::Error>),
    /// Non-200 HTTP response: `code` is the status, `msg` the body text.
    #[error("Custom Error: (code: {code:?}, message: {msg:?})")]
    Custom { code: u16, msg: String }
}

/// Serialize the error as its `Display` string so it can cross the
/// Tauri IPC boundary to the frontend as plain text.
impl Serialize for Error {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let message = self.to_string();
        serializer.serialize_str(&message)
    }
}

/// Per-chunk streaming progress event sent to the frontend.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ProgressPayload {
    /// Request id echoed back so the frontend can match events to calls.
    pub id: u64,
    /// Delta text extracted from the current stream chunk.
    pub detail: String,
    /// Role from the chunk's delta; typically only set on the first chunk — NOTE(review): verify.
    pub role: String,
    /// Finish reason reported by the server; empty until the stream ends.
    pub finish_reason: String,
}

impl ProgressPayload {
    /// Broadcast this payload to all windows on the `CHAT_FETCHEING_PROGRESS`
    /// event. The event name is kept exactly as-is (typo included) because the
    /// frontend listens for this precise string. Emit failures are ignored.
    pub fn emit_progress(&self, handle: &AppHandle) {
        let _ = handle.emit_all("CHAT_FETCHEING_PROGRESS", self);
    }
}

/// A single chat turn as supplied by the frontend and forwarded verbatim
/// in the API request body.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    /// Speaker role — presumably "user"/"assistant"/"system"; confirm against the frontend.
    pub role: String,
    /// Plain-text message content.
    pub content: String
}

/// Request options for the API-key flow. Fields are camelCase
/// (hence `allow(non_snake_case)`) — presumably to mirror the JS invoke
/// payload keys; confirm against the frontend caller.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[allow(non_snake_case)]
pub struct FetchOption {
    /// Optional proxy URL; `None`/empty means a direct connection.
    pub proxy: Option<String>,
    /// Full URL of the chat-completions endpoint to POST to.
    pub host: String,
    /// OpenAI API key, sent as a `Bearer` Authorization header.
    pub apiKey: String,
    /// Model name forwarded in the request body.
    pub model: String,
    /// Sampling temperature forwarded in the request body.
    pub temperature: f32,
}

#[tauri::command]
pub async fn fetch_chat_api_by_api_key(
handle: AppHandle,
id: u64,
messages: Vec<Message>,
option: FetchOption,
) -> Result<u64> {
// https://platform.openai.com/docs/guides/chat/introduction
// let url = "https://api.openai.com/v1/chat/completions";
let data = json!({
"model": option.model,
"messages": messages,
"temperature": option.temperature,
"stream": true,
});
log::info!("> send message: length: {}, option: {:?},", messages.len(), option);
let proxy_str = option.proxy.unwrap_or(String::from(""));

let client : reqwest::Client = {
log::info!("proxy is: {}", proxy_str);
let mut client_builder = reqwest::Client::builder();
if proxy_str.len()>0 {
let proxy = reqwest::Proxy::all(proxy_str).unwrap();
client_builder = client_builder.proxy(proxy);
}
client_builder.build().unwrap()
};
let res = client.post(option.host)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", option.apiKey))
.header(reqwest::header::USER_AGENT, format!("ChatGPT-Tauri ({})", OS))
.timeout(Duration::from_secs(600))
.body(data.to_string())
.send()
.await?;
info!("> receive message: {}", id);

let status_code = res.status().as_u16();
if status_code != 200 {
let error_msg = res.text().await?;
log::error!("{}", error_msg);
return Err(Error::Custom {code: status_code, msg:String::from(error_msg)})
}

let mut stream = res.bytes_stream().eventsource();
while let Some(chunk) = stream.try_next().await? {
let chunk = chunk.data;
if chunk == "[DONE]" {
return Ok(id)
} else {
let object:Value = serde_json::from_str(&chunk)?;
let delta = &object["choices"][0]["delta"];
let content = String::from(delta["content"].as_str().unwrap_or(""));
info!("> receive content: {:?}", content);
let role = String::from(delta["role"].as_str().unwrap_or(""));
let finish_reason = String::from(object["finish_reason"].as_str().unwrap_or(""));
let progress = ProgressPayload {id, detail:content, role, finish_reason};
progress.emit_progress(&handle);
}
}
Ok(id)
}



5 changes: 3 additions & 2 deletions src-tauri/src/app/cmd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod gpt;
pub mod gpt_access_token;
pub mod gpt_api_key;
pub mod download;
pub mod window;
pub mod window;
5 changes: 3 additions & 2 deletions src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ fn main() {
let mut builder = tauri::Builder::default()
.plugin(log.build())
.invoke_handler(tauri::generate_handler![
cmd::gpt::fetch_chat_api,
cmd::gpt_access_token::fetch_chat_api_by_access_token,
cmd::gpt_api_key::fetch_chat_api_by_api_key,
cmd::download::download_img,
cmd::window::new_window
])
Expand All @@ -59,7 +60,7 @@ fn main() {
}
})
}

builder.run(tauri::generate_context!())
.expect("error while running tauri application");
}
5 changes: 3 additions & 2 deletions src-tauri/tauri.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
"all": true,
"request": true,
"scope": [
"https://api.openai.com/v1/*"
"https://api.openai.com/v1/*",
"https://bypass.duti.tech/api/*"
]
},
"shell": {
Expand Down Expand Up @@ -96,4 +97,4 @@
"pubkey": "dW50cnVzdGVkIGNvbW1lbnQ6IG1pbmlzaWduIHB1YmxpYyBrZXk6IDM2OUUyQUQ5QjE1Q0FEMTEKUldRUnJWeXgyU3FlTmxOS0N0aVBhNGUwL3c3QlBIY29uMHFUdmhUZS9YNmpKNE83L1BKZ3dER2QK"
}
}
}
}

0 comments on commit d340fd6

Please sign in to comment.