feat: add generationConfig of Google Gemini & fix: Gemini abort bug
LyuLumos committed Jan 29, 2024
1 parent 1c019ea commit b03d041
Showing 4 changed files with 58 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/logics/conversation.ts
@@ -133,8 +133,10 @@ export const callProviderHandler = async(providerId: string, payload: HandlerPayload
baseUrl: payload.globalSettings?.baseUrl,
model: payload.globalSettings?.model,
maxTokens: payload.globalSettings?.maxTokens,
maxOutputTokens: payload.globalSettings?.maxOutputTokens,
temperature: payload.globalSettings?.temperature,
top_p: payload.globalSettings?.top_p,
topP: payload.globalSettings?.topP,
topK: payload.globalSettings?.topK,
},
botSettings: payload.botSettings,
})
2 changes: 2 additions & 0 deletions src/providers/google/api.ts
@@ -3,6 +3,7 @@ export interface GoogleFetchPayload {
stream: boolean
body: Record<string, any>
model?: string
signal?: AbortSignal
}

export const fetchChatCompletion = async(payload: GoogleFetchPayload) => {
@@ -11,6 +12,7 @@ export const fetchChatCompletion = async(payload: GoogleFetchPayload) => {
headers: { 'Content-Type': 'application/json' },
method: 'POST',
body: JSON.stringify({ ...body }),
signal: payload.signal,
}
return fetch(`https://generativelanguage.googleapis.com/v1beta/models/${model}:streamGenerateContent?${stream ? 'alt=sse&' : ''}key=${apiKey}`, initOptions)
}
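
For context on the abort fix: the new signal field simply forwards an AbortSignal into the underlying fetch call, so aborting the controller now actually cancels the in-flight request. A minimal usage sketch (the caller, the timeout, and the apiKey handling below are illustrative assumptions, not code from this repository):

import { fetchChatCompletion } from './api'

// Illustrative only: cancel a streaming Gemini request after 10 seconds.
// Assumes the payload's apiKey field feeds the `key=` query parameter, as in api.ts.
async function demoAbortableRequest(apiKey: string) {
  const controller = new AbortController()
  const timer = setTimeout(() => controller.abort(), 10_000)
  try {
    // fetch rejects with an AbortError once controller.abort() fires.
    return await fetchChatCompletion({
      apiKey,
      stream: true,
      model: 'gemini-pro',
      body: { contents: [{ role: 'user', parts: [{ text: 'Hello' }] }] },
      signal: controller.signal,
    })
  } finally {
    clearTimeout(timer)
  }
}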
12 changes: 12 additions & 0 deletions src/providers/google/handler.ts
@@ -18,6 +18,11 @@ export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings
globalSettings: {
...globalSettings,
model: 'gemini-pro',
temperature: 0.4,
maxTokens: 10240,
maxOutputTokens: 1024,
topP: 0.8,
topK: 1,
},
botSettings: {},
prompt,
@@ -57,7 +62,14 @@ export const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal
stream,
body: {
contents: parseMessageList(messages),
generationConfig: {
temperature: payload.globalSettings.temperature as number,
maxOutputTokens: payload.globalSettings.maxOutputTokens as number,
topP: payload.globalSettings.topP as number,
topK: payload.globalSettings.topK as number,
}
},
signal,
model: payload.globalSettings.model as string,
})

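For reference, the generationConfig block above is nested inside the JSON body that api.ts posts to the streamGenerateContent endpoint, which is the shape the Gemini REST API expects for sampling parameters. A rough sketch of the resulting body (the values are the defaults used by handleRapidPrompt; the contents entry is an illustrative example of what parseMessageList might produce):

// Approximate request body sent to
// https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent
const exampleBody = {
  contents: [
    { role: 'user', parts: [{ text: 'Hello' }] },
  ],
  generationConfig: {
    temperature: 0.4,      // globalSettings.temperature
    maxOutputTokens: 1024, // globalSettings.maxOutputTokens
    topP: 0.8,             // globalSettings.topP
    topK: 1,               // globalSettings.topK
  },
}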
42 changes: 41 additions & 1 deletion src/providers/google/index.ts
@@ -32,7 +32,47 @@ const providerGoogle = () => {
type: 'slider',
min: 0,
max: 32768,
default: 2048,
default: 10240,
step: 1,
},
{
key: 'maxOutputTokens',
name: 'Max Output Tokens',
description: 'Specifies the maximum number of tokens that can be generated in the response. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.',
type: 'slider',
min: 0,
max: 4096,
default: 1024,
step: 1,
},
{
key: 'temperature',
name: 'Temperature',
description: 'The temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a more deterministic or less open-ended response.',
type: 'slider',
min: 0,
max: 1,
default: 0.4,
step: 0.01,
},
{
key: 'topP',
name: 'Top P',
description: 'An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.',
type: 'slider',
min: 0,
max: 1,
default: 0.95,
step: 0.01,
},
{
key: 'topK',
name: 'Top K',
description: 'Top K sampling chooses from the K most likely tokens.',
type: 'slider',
min: 0,
max: 32768,
default: 1,
step: 1,
},
{
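
Each slider's key has to match the property read from payload.globalSettings in handler.ts. A simplified sketch of that mapping (the interface and helper below are assumptions for illustration, not types defined in the repository):

// Assumed shape of the per-provider settings persisted by the sliders above.
interface GoogleGenerationSettings {
  maxOutputTokens?: number // slider key 'maxOutputTokens', default 1024
  temperature?: number     // slider key 'temperature',     default 0.4
  topP?: number            // slider key 'topP',            default 0.95
  topK?: number            // slider key 'topK',            default 1
}

// handler.ts copies these one-to-one into the Gemini generationConfig.
const toGenerationConfig = (s: GoogleGenerationSettings) => ({
  temperature: s.temperature,
  maxOutputTokens: s.maxOutputTokens,
  topP: s.topP,
  topK: s.topK,
})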
