Skip to content

Commit

Permalink
feat(context): conversation context struct
Browse files Browse the repository at this point in the history
  • Loading branch information
solywsh committed Dec 9, 2022
1 parent 74479e0 commit 07415f4
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 27 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ import (
)

func main() {
chat := chatgpt.New("YOUR_API_KEY")
chat := New("", "", 10*time.Second)
defer chat.Close()
question := "你认为2022年世界杯的冠军是谁?"
select {
case <-chat.GetTimeOutChan():
fmt.Println("time out")
}
question := "中国在欧洲\n"
fmt.Printf("Q: %s\n", question)
answer, err := chat.Chat(question)
if err != nil {
Expand Down
68 changes: 43 additions & 25 deletions chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,50 @@ import (
)

type ChatGPT struct {
client *gogpt.Client
ctx context.Context
userId string
maxToken int
timeOut time.Duration
timeOutChan chan struct{}
cancel func()
client *gogpt.Client
ctx context.Context
userId string
maxQuestionLen int
maxText int
maxAnswerLen int
timeOut time.Duration // 超时时间, 0表示不超时
timeOutChan chan struct {
}
cancel func()
stopFlag []string // 如果不设置机器会联想并返回联想结束语句
maxStopFlagLen int

ChatContext *ChatContext
}

func New(ApiKey, UserId string, timeOut time.Duration) *ChatGPT {
ctx, cancel := context.WithTimeout(context.Background(), timeOut)
var ctx context.Context
var cancel func()
if timeOut == 0 {
ctx, cancel = context.WithCancel(context.Background())
} else {
ctx, cancel = context.WithTimeout(context.Background(), timeOut)
}
timeOutChan := make(chan struct{}, 1)
go func() {
<-ctx.Done()
timeOutChan <- struct{}{}
timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
}()
return &ChatGPT{
client: gogpt.NewClient(ApiKey),
ctx: ctx,
userId: UserId,
maxToken: 1024,
timeOut: timeOut,
timeOutChan: timeOutChan,
client: gogpt.NewClient(ApiKey),
ctx: ctx,
userId: UserId,
maxQuestionLen: 1024, // 最大问题长度
maxAnswerLen: 1024, // 最大答案长度
maxText: 4096, // 最大文本 = 问题 + 回答
timeOut: timeOut,
timeOutChan: timeOutChan,
cancel: func() {
cancel()
},
stopFlag: []string{"."},
maxStopFlagLen: 1,
ChatContext: NewContext(),
}
}
func (c *ChatGPT) Close() {
Expand All @@ -43,36 +61,36 @@ func (c *ChatGPT) GetTimeOutChan() chan struct{} {
return c.timeOutChan
}

func (c *ChatGPT) SetMaxToken(maxToken int) {
if maxToken > 4096 {
maxToken = 4096
return
func (c *ChatGPT) SetMaxQuestionLen(maxQuestionLen int) {
if maxQuestionLen > c.maxText-c.maxAnswerLen {
maxQuestionLen = c.maxText - c.maxAnswerLen
}
c.maxToken = maxToken
c.maxQuestionLen = maxQuestionLen
}

func (c *ChatGPT) Chat(question string) (answer string, err error) {
if len(question)+c.maxToken > 4096 {
question = question[:4096-c.maxToken]
if len(question)+c.maxAnswerLen+c.maxStopFlagLen > c.maxText {
question = question[:c.maxText-c.maxAnswerLen]
}
req := gogpt.CompletionRequest{
Model: gogpt.GPT3TextDavinci003,
MaxTokens: c.maxToken,
Prompt: question,
MaxTokens: c.maxAnswerLen,
Prompt: question + ".", // 加"."提示AI结束
Temperature: 0.9,
TopP: 1,
N: 1,
FrequencyPenalty: 0,
PresencePenalty: 0.5,
User: c.userId,
Stop: c.stopFlag,
}
resp, err := c.client.CreateCompletion(c.ctx, req)
if err != nil {
return "", err
}
answer = resp.Choices[0].Text
for len(answer) > 0 {
if answer[0] == '\n' || answer[0] == ' ' {
if answer[:1] == "\n" || answer[0] == ' ' {
answer = answer[1:]
} else {
break
Expand Down
59 changes: 59 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package chatgpt

import "strings"

var (
DefaultHead = "\nHuman: 你好,让我们开始愉快的谈话!\nAI: 我是 AI assistant ,请问你有什么问题?"
DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}
DefaultRole = "The following is a conversation with Ai assistant. The assistant is" + strings.Join(DefaultCharacter, ",") + "."
)

type ChatContext struct {
Head string
Character []string
Role string
Old []conversation
New []conversation
MaxSequence int
RestartSequence string
StartSequence string
}

type conversation struct {
role string
prompt string
}

func NewContext() *ChatContext {
return &ChatContext{
Role: DefaultRole,
Head: DefaultHead,
Character: DefaultCharacter,
Old: []conversation{},
New: []conversation{},
MaxSequence: 10,
RestartSequence: "\nHuman: ",
StartSequence: "\nAI: ",
}
}

func (c *ChatGPT) ChatWithContext(question string) {
//promptTable := strings.Builder{}
//promptTable.WriteString(c.ChatContext.Role)
//promptTable.WriteString("\n")
//promptTable.WriteString(c.ChatContext.Head)
//promptTable.WriteString("\n")
// 性能去他妈
var promptTable []string
promptTable = append(promptTable, c.ChatContext.Role)
promptTable = append(promptTable, c.ChatContext.Head)
for _, v := range c.ChatContext.Old {
promptTable = append(promptTable, v.role+" "+v.prompt)
}
promptTable = append(promptTable, c.ChatContext.RestartSequence+question+".")

}

//func (c *ChatGPT) cutPrompt(prompt []string) []string {
// //extra := len(prompt) + c.ChatContext.MaxSequence + ""
//}

0 comments on commit 07415f4

Please sign in to comment.