-
Notifications
You must be signed in to change notification settings - Fork 18
/
context.go
116 lines (103 loc) · 3.29 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package chatgpt
import (
"fmt"
gogpt "github.com/sashabaranov/go-gpt3"
"strings"
)
// DefaultAiRole and DefaultHumanRole are the speaker names that NewContext
// assigns to a fresh ChatContext.
var (
	DefaultAiRole    = "AI"
	DefaultHumanRole = "Human"
)

// DefaultCharacter lists the personality traits that NewContext joins with
// ", " (plus a trailing ".") and substitutes into DefaultBackground.
var DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}

// DefaultBackground is a format string; its single %s verb receives the
// joined DefaultCharacter traits.
var DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"

// DefaultPreset is the scripted opening exchange placed after the background.
// Its first %s receives the human role name and its second the AI role name.
// The Chinese text reads roughly: "Hello, let's start a pleasant
// conversation!" / "I am the AI assistant, what is your question?"
var DefaultPreset = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
type (
	// ChatContext holds the state of one multi-turn conversation: the fixed
	// text placed at the top of every prompt, the two speaker roles, the
	// recorded history and the round counter.
	ChatContext struct {
		// background is the conversation background; preset is the scripted
		// opening dialogue. Both precede the history in every prompt.
		background, preset string
		// maxSeqTimes is the maximum number of question/answer rounds.
		maxSeqTimes int
		// aiRole and humanRole are the two speakers. History entries point at
		// these exact values, and turns are told apart by pointer identity.
		aiRole, humanRole *role
		// old is the recorded history of earlier turns, oldest first.
		old []conversation
		// restartSeq ("\n<human>: ") marks the start of a new human turn;
		// startSeq ("\n<ai>: ") marks the start of the AI's reply.
		restartSeq, startSeq string
		// seqTimes counts the rounds completed so far.
		seqTimes int
	}

	// conversation is one recorded turn: who spoke and what was said.
	conversation struct {
		role   *role
		prompt string
	}

	// role names a conversation participant.
	role struct {
		name string
	}
)
// SetHumanRole renames the human speaker. The new name also becomes the
// "\n<name>: " prefix written before each new question, and ChatWithContext
// uses it as one of the completion stop sequences.
//
// NOTE(review): the preset dialogue built by NewContext still uses the
// default role names and is not regenerated here.
func (c *ChatContext) SetHumanRole(role string) {
	human := c.humanRole
	human.name = role
	c.restartSeq = "\n" + role + ": "
}

// SetAiRole renames the AI speaker. The new name also becomes the
// "\n<name>: " suffix appended to every prompt so the model answers in this
// role, and ChatWithContext uses it as one of the completion stop sequences.
//
// NOTE(review): the preset dialogue built by NewContext still uses the
// default role names and is not regenerated here.
func (c *ChatContext) SetAiRole(role string) {
	ai := c.aiRole
	ai.name = role
	c.startSeq = "\n" + role + ": "
}

// SetMaxSeqTimes sets how many question/answer rounds may be recorded before
// ChatWithContext rejects further questions with OverMaxSequenceTimes.
func (c *ChatContext) SetMaxSeqTimes(times int) {
	c.maxSeqTimes = times
}
// NewContext returns a ChatContext initialised from the package defaults.
// It uses DefaultAiRole and DefaultHumanRole as speaker names, and builds the
// background from DefaultBackground with the comma-joined DefaultCharacter
// traits. The preset comes from DefaultPreset (human name first, then AI
// name). The history starts empty (non-nil) and the limit is 10 rounds.
//
// The Default* variables are read at call time, so changing them only
// affects contexts created afterwards.
func NewContext() *ChatContext {
	traits := strings.Join(DefaultCharacter, ", ") + "."
	humanPrefix := "\n" + DefaultHumanRole + ": "
	aiPrefix := "\n" + DefaultAiRole + ": "

	cc := &ChatContext{
		background:  fmt.Sprintf(DefaultBackground, traits),
		preset:      fmt.Sprintf(DefaultPreset, DefaultHumanRole, DefaultAiRole),
		maxSeqTimes: 10,
		aiRole:      &role{name: DefaultAiRole},
		humanRole:   &role{name: DefaultHumanRole},
		old:         make([]conversation, 0),
		restartSeq:  humanPrefix,
		startSeq:    aiPrefix,
		// seqTimes starts at its zero value.
	}
	return cc
}
// ChatWithContext sends question to the completion API together with the
// conversation history kept in c.ChatContext and returns the model's reply.
//
// The prompt is assembled in this order:
//   - the background,
//   - the preset,
//   - every earlier question/answer turn,
//   - the new question,
//   - the AI role prefix, so the model continues speaking as the AI.
//
// Both role names (followed by ":") are passed as stop sequences so the model
// does not go on to write the next turn itself.
//
// It returns OverMaxSequenceTimes once maxSeqTimes rounds have been recorded.
// It returns OverMaxTextLength when the prompt exceeds maxText-maxAnswerLen
// bytes. It also returns any error from the API call, and an error if the
// response contains no choices. Only on success is the exchange appended to
// the history and the round counter incremented.
func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
	if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
		return "", OverMaxSequenceTimes
	}

	// Background + preset + new question, plus one entry per recorded turn.
	promptTable := make([]string, 0, len(c.ChatContext.old)+3)
	promptTable = append(promptTable, c.ChatContext.background, c.ChatContext.preset)
	for _, v := range c.ChatContext.old {
		if v.role == c.ChatContext.humanRole {
			// Human turns carry an extra leading newline; AI turns do not.
			promptTable = append(promptTable, "\n"+v.role.name+": "+v.prompt)
		} else {
			promptTable = append(promptTable, v.role.name+": "+v.prompt)
		}
	}
	promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question+".")
	prompt := strings.Join(promptTable, "\n") + c.ChatContext.startSeq

	// NOTE(review): len counts bytes, not API tokens, so this is only a rough
	// guard (multi-byte text such as Chinese is over-counted). Confirm against
	// how maxText and maxAnswerLen are meant to be measured.
	if len(prompt) > c.maxText-c.maxAnswerLen {
		return "", OverMaxTextLength
	}

	req := gogpt.CompletionRequest{
		Model:            gogpt.GPT3TextDavinci003,
		MaxTokens:        c.maxAnswerLen,
		Prompt:           prompt,
		Temperature:      0.9,
		TopP:             1,
		N:                1,
		FrequencyPenalty: 0,
		PresencePenalty:  0.5,
		User:             c.userId,
		// Stop before the model starts writing either speaker's next turn.
		Stop: []string{c.ChatContext.aiRole.name + ":", c.ChatContext.humanRole.name + ":"},
	}
	resp, err := c.client.CreateCompletion(c.ctx, req)
	if err != nil {
		return "", err
	}
	// Guard the index below; an empty Choices slice would otherwise panic.
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("chatgpt: completion response contained no choices")
	}

	// formatAnswer (defined elsewhere in the package) returns the
	// post-processed text. Go strings are immutable, so its result must be
	// used. Previously it was discarded and the raw text leaked through.
	answer = formatAnswer(resp.Choices[0].Text)

	c.ChatContext.old = append(c.ChatContext.old,
		conversation{role: c.ChatContext.humanRole, prompt: question},
		conversation{role: c.ChatContext.aiRole, prompt: answer},
	)
	c.ChatContext.seqTimes++
	return answer, nil
}