Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 支持新绘画模型及模型判断逻辑收敛 #291

Merged
merged 3 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: 支持新绘画模型及模型判断逻辑收敛
1. 增加绘画模型配置项,支持 dall-e-2 和 dall-e-3
2. 更正 genereateImage 拼写
3. 兼容支持 dall-e-3 返回 WebP 格式图片
4. 非 legacy 模型判断逻辑收敛,避免多处维护
  • Loading branch information
FrankCheungDev committed Nov 16, 2023
commit 75cfeba8aa5ad4494d77d9407903a566903cb545
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type Configuration struct {
BaseURL string `yaml:"base_url"`
// 使用模型
Model string `yaml:"model"`
// 使用绘画模型
ImageModel string `yaml:"image_model"`
// 会话超时时间
SessionTimeout time.Duration `yaml:"session_timeout"`
// 最大问题长度
Expand Down
61 changes: 39 additions & 22 deletions pkg/chatgpt/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ import (
"encoding/gob"
"errors"
"fmt"

"github.com/chai2010/webp"
"image"
_ "image/gif"
_ "image/jpeg"
"image/png"

"os"
"strings"
"time"
Expand Down Expand Up @@ -137,6 +143,22 @@ func (c *ChatContext) SetPreset(preset string) {
c.preset = preset
}

// 通过 base64 编码字符串开头字符判断图像类型
func getImageTypeFromBase64(base64Str string) string {
switch {
case strings.HasPrefix(base64Str, "/9j/"):
return "JPEG"
case strings.HasPrefix(base64Str, "iVBOR"):
return "PNG"
case strings.HasPrefix(base64Str, "R0lG"):
return "GIF"
case strings.HasPrefix(base64Str, "UklG"):
return "WebP"
default:
return "Unknown"
}
}

func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
question = question + "."
if tokenizer.MustCalToken(question) > c.maxQuestionLen {
Expand Down Expand Up @@ -181,20 +203,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
if public.Config.AzureOn {
userId = ""
}
if model == openai.GPT432K0613 ||
model == openai.GPT432K0314 ||
model == openai.GPT432K ||
model == openai.GPT40613 ||
model == openai.GPT40314 ||
model == openai.GPT4TurboPreview ||
model == openai.GPT4VisionPreview ||
model == openai.GPT4 ||
model == openai.GPT3Dot5Turbo1106 ||
model == openai.GPT3Dot5Turbo0613 ||
model == openai.GPT3Dot5Turbo0301 ||
model == openai.GPT3Dot5Turbo16K ||
model == openai.GPT3Dot5Turbo16K0613 ||
model == openai.GPT3Dot5Turbo {
if isModelSupportedChatCompletions(model) {
req := openai.ChatCompletionRequest{
Model: model,
Messages: []openai.ChatCompletionMessage{
Expand Down Expand Up @@ -248,14 +257,13 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
return resp.Choices[0].Text, nil
}
}
func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
func (c *ChatGPT) GenerateImage(ctx context.Context, prompt string) (string, error) {
model := public.Config.Model
if model == openai.GPT3Dot5Turbo || model == openai.GPT3Dot5Turbo0301 || model == openai.GPT3Dot5Turbo0613 ||
model == openai.GPT3Dot5Turbo16K || model == openai.GPT3Dot5Turbo16K0613 ||
model == openai.GPT4 || model == openai.GPT40314 || model == openai.GPT40613 ||
model == openai.GPT432K || model == openai.GPT432K0314 || model == openai.GPT432K0613 {
imageModel := public.Config.ImageModel
if isModelSupportedChatCompletions(model) {
req := openai.ImageRequest{
Prompt: prompt,
Model: imageModel,
Size: openai.CreateImageSize1024x1024,
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
N: 1,
Expand All @@ -271,9 +279,18 @@ func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, err
}

r := bytes.NewReader(imgBytes)
imgData, err := png.Decode(r)
if err != nil {
return "", err

// dall-e-3 返回的是 WebP 格式的图片,需要判断处理
imgType := getImageTypeFromBase64(respBase64.Data[0].B64JSON)
var imgData image.Image
var imgErr error
if imgType == "WebP" {
imgData, imgErr = webp.Decode(r)
} else {
imgData, _, imgErr = image.Decode(r)
}
if imgErr != nil {
return "", imgErr
}

imageName := time.Now().Format("20060102-150405") + ".png"
Expand Down
2 changes: 1 addition & 1 deletion pkg/chatgpt/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func ImageQa(ctx context.Context, question, userId string) (answer string, err e
// 使用重试策略进行重试
err = retry.Do(
func() error {
answer, err = chat.GenreateImage(ctx, question)
answer, err = chat.GenerateImage(ctx, question)
if err != nil {
return err
}
Expand Down
29 changes: 29 additions & 0 deletions pkg/chatgpt/models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package chatgpt

import openai "github.com/sashabaranov/go-openai"

var ModelsSupportChatCompletions = []string{
openai.GPT432K0613,
openai.GPT432K0314,
openai.GPT432K,
openai.GPT40613,
openai.GPT40314,
openai.GPT4TurboPreview,
openai.GPT4VisionPreview,
openai.GPT4,
openai.GPT3Dot5Turbo1106,
openai.GPT3Dot5Turbo0613,
openai.GPT3Dot5Turbo0301,
openai.GPT3Dot5Turbo16K,
openai.GPT3Dot5Turbo16K0613,
openai.GPT3Dot5Turbo,
}

func isModelSupportedChatCompletions(model string) bool {
for _, m := range ModelsSupportChatCompletions {
if m == model {
return true
}
}
return false
}