-
-
Notifications
You must be signed in to change notification settings - Fork 633
/
mrkl.go
155 lines (131 loc) · 4.08 KB
/
mrkl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package agents
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)
// _finalAnswerAction is the marker that parseOutput looks for in the model's
// output; when present, the text after its last occurrence is returned as the
// agent's final answer.
const _finalAnswerAction = "Final Answer:"

// _defaultOutputKey is the output key name "output"; presumably it is the
// default OutputKey applied by the option defaults defined elsewhere in the
// package — confirm against mrklDefaultOptions.
const _defaultOutputKey = "output"
// OneShotZeroAgent is an agent that, given a set of inputs and the steps taken
// so far, either decides which tool action to take next or returns the final
// output once the task is finished.
//
// This agent is intended to be driven by an LLM; see NewOneShotAgent, which
// wraps the model in an LLM chain with an MRKL-style prompt.
type OneShotZeroAgent struct {
// Chain is the chain used to call with the values. The chain should have an
// input called "agent_scratchpad" for the agent to put its thoughts in; Plan
// also supplies a "today" input holding the current date.
Chain chains.Chain
// Tools is a list of the tools the agent can use.
Tools []tools.Tool
// OutputKey is the key in the returned AgentFinish.ReturnValues under which
// the final answer is placed.
OutputKey string
// CallbacksHandler is the handler for callbacks. When non-nil, Plan forwards
// streamed output chunks to its HandleStreamingFunc; it may be nil.
CallbacksHandler callbacks.Handler
}
// Compile-time check that *OneShotZeroAgent implements the Agent interface.
var _ Agent = (*OneShotZeroAgent)(nil)
// NewOneShotAgent builds a OneShotZeroAgent around llm and tools.
//
// The supplied opts are applied, in order, on top of mrklDefaultOptions. The
// resulting options provide the MRKL prompt (rendered from tools), the output
// key, and the callbacks handler. That handler is attached both to the LLM
// chain the agent wraps and to the agent itself.
func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneShotZeroAgent {
	cfg := mrklDefaultOptions()
	for i := range opts {
		opts[i](&cfg)
	}

	prompt := cfg.getMrklPrompt(tools)
	llmChain := chains.NewLLMChain(llm, prompt, chains.WithCallback(cfg.callbacksHandler))

	return &OneShotZeroAgent{
		Chain:            llmChain,
		Tools:            tools,
		OutputKey:        cfg.outputKey,
		CallbacksHandler: cfg.callbacksHandler,
	}
}
// Plan decides what action to take next, or returns the final result.
//
// The chain is called with a copy of inputs plus two values computed here:
// "agent_scratchpad", the rendered history of intermediateSteps, and "today",
// the current date formatted as "January 02, 2006". "\nObservation:" and
// "\n\tObservation:" are passed as stop words. When CallbacksHandler is
// non-nil, streamed chunks are forwarded to its HandleStreamingFunc. The chain
// output is then parsed into tool actions or a final answer by parseOutput.
// Errors from the chain are returned unchanged.
func (a *OneShotZeroAgent) Plan(
	ctx context.Context,
	intermediateSteps []schema.AgentStep,
	inputs map[string]string,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
	chainInputs := make(map[string]any, len(inputs))
	for k, v := range inputs {
		chainInputs[k] = v
	}
	// Assigned after the copy, so these computed values replace any caller
	// input that uses the same key.
	chainInputs["agent_scratchpad"] = constructMrklScratchPad(intermediateSteps)
	chainInputs["today"] = time.Now().Format("January 02, 2006")

	// Left nil when no handler is configured.
	var onChunk func(ctx context.Context, chunk []byte) error
	if a.CallbacksHandler != nil {
		onChunk = func(streamCtx context.Context, chunk []byte) error {
			a.CallbacksHandler.HandleStreamingFunc(streamCtx, chunk)
			return nil
		}
	}

	stopWords := []string{"\nObservation:", "\n\tObservation:"}
	output, err := chains.Predict(
		ctx,
		a.Chain,
		chainInputs,
		chains.WithStopWords(stopWords),
		chains.WithStreamingFunc(onChunk),
	)
	if err != nil {
		return nil, nil, err
	}

	return a.parseOutput(output)
}
// GetInputKeys returns the wrapped chain's input keys, except for
// "agent_scratchpad" and "today". Plan computes those two values itself, so
// they are not requested from callers. The relative order of the remaining
// keys is preserved.
func (a *OneShotZeroAgent) GetInputKeys() []string {
	chainKeys := a.Chain.GetInputKeys()
	keys := make([]string, 0, len(chainKeys))
	for _, key := range chainKeys {
		switch key {
		case "agent_scratchpad", "today":
			// Filled in by Plan; skip.
		default:
			keys = append(keys, key)
		}
	}
	return keys
}
// GetOutputKeys returns a one-element slice holding the agent's OutputKey,
// which is the key under which parseOutput stores the final answer.
func (a *OneShotZeroAgent) GetOutputKeys() []string {
	keys := []string{a.OutputKey}
	return keys
}
// GetTools returns the tools the agent can use. The agent's Tools slice is
// returned directly rather than copied, so it shares its backing array.
func (a *OneShotZeroAgent) GetTools() []tools.Tool {
	toolset := a.Tools
	return toolset
}
// constructMrklScratchPad renders previously executed steps into the text that
// Plan passes to the chain as "agent_scratchpad".
//
// Each step contributes "\n" + step.Action.Log + "\nObservation: " +
// step.Observation + "\n", in order. An empty or nil steps slice yields "".
func constructMrklScratchPad(steps []schema.AgentStep) string {
	// strings.Builder avoids the quadratic copying of repeated string +=.
	var sb strings.Builder
	for _, step := range steps {
		sb.WriteString("\n")
		sb.WriteString(step.Action.Log)
		sb.WriteString("\nObservation: ")
		sb.WriteString(step.Observation)
		sb.WriteString("\n")
	}
	return sb.String()
}
// _mrklActionRegexp matches an MRKL tool invocation of the form
//
//	Action: <tool name>
//	Action Input: <tool input>
//
// Submatch 1 is the tool name and submatch 2 is the tool input. The inline
// "(?s)" flag enables dot-matches-newline for the rest of the pattern. As a
// result, the tool input may span several lines, while the tool name cannot
// cross a newline. The pattern is compiled once at package initialisation
// rather than on every parseOutput call.
var _mrklActionRegexp = regexp.MustCompile(`Action:\s*(.+)\s*Action Input:\s(?s)*(.+)`)

// parseOutput interprets the raw chain output produced during Plan.
//
// If output contains _finalAnswerAction ("Final Answer:"), it returns an
// AgentFinish. ReturnValues[a.OutputKey] is the text after the last occurrence
// of that marker (not trimmed), and Log is the full output. Otherwise it
// extracts a single tool action with _mrklActionRegexp and trims surrounding
// whitespace from the tool name and tool input. If neither form is present, it
// returns an error wrapping ErrUnableToParseOutput.
func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
	// A final answer takes precedence over any Action lines in the same output.
	// "Final Answer:" cannot overlap itself, so the text after its last index is
	// exactly the last element strings.Split would produce, without allocating.
	if idx := strings.LastIndex(output, _finalAnswerAction); idx >= 0 {
		return nil, &schema.AgentFinish{
			ReturnValues: map[string]any{
				a.OutputKey: output[idx+len(_finalAnswerAction):],
			},
			Log: output,
		}, nil
	}

	matches := _mrklActionRegexp.FindStringSubmatch(output)
	if len(matches) == 0 {
		return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
	}

	return []schema.AgentAction{
		{Tool: strings.TrimSpace(matches[1]), ToolInput: strings.TrimSpace(matches[2]), Log: output},
	}, nil, nil
}