-
Notifications
You must be signed in to change notification settings - Fork 94
/
tool_combine.py
148 lines (127 loc) · 4.9 KB
/
tool_combine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import json
def remove_duplicates(dicts):
    """Deduplicate tool definitions by ``tool["function"]["name"]``.

    Entries without a ``"function"`` key, or whose ``"function"`` has no
    ``"name"``, are dropped. When several entries share a name the LAST one
    wins, but it keeps the position where that name first appeared (dicts
    preserve first-insertion order while the stored value is overwritten).
    """
    name_to_dict = {}
    for d in dicts:
        if "function" in d and "name" in d["function"]:
            name_to_dict[d["function"]["name"]] = d
    return list(name_to_dict.values())


# Name of the tool whose variants compete with each other during a merge.
_DATA_BASE_ADVANCE = "data_base_advance"


def _optional_tool_inputs(count):
    """Build the optional-input spec: sockets tool1..tool<count>, then is_enable."""
    spec = {f"tool{i}": ("STRING", {"forceInput": True}) for i in range(1, count + 1)}
    spec["is_enable"] = ("BOOLEAN", {"default": True})
    return spec


def _tool_name(tool):
    """Return ``tool["function"]["name"]``, or None when that path is absent."""
    function = tool.get("function")
    return function.get("name") if isinstance(function, dict) else None


def _dict_keys_count(tool):
    """Count "dict_keys" in ``function.parameters.properties.file_name.description``.

    Returns 0 when that path is missing or the description cannot be counted.
    NOTE(review): presumably the description embeds one ``dict_keys(...)``
    listing per loaded file, so a higher count means a richer variant — confirm
    against the data_base_advance tool.
    """
    try:
        description = tool["function"]["parameters"]["properties"]["file_name"]["description"]
        return description.count("dict_keys")
    except (KeyError, TypeError, AttributeError):
        return 0


def _parse_tool_inputs(raw_inputs):
    """Decode each non-empty JSON input and flatten the results into one list.

    Each input should be a JSON array of tool-definition objects; a single JSON
    object is accepted as a one-element array. Empty/None inputs are skipped
    silently. Invalid JSON, or JSON that is neither an array nor an object, is
    reported on stdout and skipped so one bad connection does not abort the
    whole merge. Non-object array items are ignored.
    """
    tools = []
    for raw in raw_inputs:
        if not raw:
            continue
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError: {e}")
            continue
        if isinstance(parsed, dict):
            parsed = [parsed]
        if not isinstance(parsed, list):
            print(f"tool_combine: ignoring non-list tool input of type {type(parsed).__name__}")
            continue
        tools.extend(item for item in parsed if isinstance(item, dict))
    return tools


def _merge_tools(raw_inputs):
    """Merge several tool-definition JSON strings into one JSON array string.

    1. Parse and flatten every input (see ``_parse_tool_inputs``).
    2. Among data_base_advance tools keep only those with the highest
       dict_keys count; every other tool is kept untouched.
    3. Deduplicate by function name with ``remove_duplicates`` (last wins).

    Returns the merged JSON (non-ASCII kept as-is), or None if no tool survives.
    """
    tools = _parse_tool_inputs(raw_inputs)
    # Only data_base_advance variants compete; earlier code reset the whole list
    # here and silently dropped unrelated tools.
    best = max(
        (_dict_keys_count(t) for t in tools if _tool_name(t) == _DATA_BASE_ADVANCE),
        default=0,
    )
    kept = [
        t
        for t in tools
        if _tool_name(t) != _DATA_BASE_ADVANCE or _dict_keys_count(t) == best
    ]
    kept = remove_duplicates(kept)
    return json.dumps(kept, ensure_ascii=False) if kept else None


class tool_combine:
    """Node that merges up to three tool-definition JSON strings into one.

    Declares its interface through the INPUT_TYPES / RETURN_TYPES /
    RETURN_NAMES / FUNCTION / CATEGORY attributes.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {},
            "optional": _optional_tool_inputs(3),
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("tools",)
    FUNCTION = "combine"
    CATEGORY = "大模型派对(llm_party)/组合(combine)"

    def combine(self, is_enable=True, tool1=None, tool2=None, tool3=None):
        """Merge the connected tool lists.

        Args:
            is_enable: when falsy, the node is bypassed and outputs None.
            tool1..tool3: JSON strings holding tool definitions (or None).

        Returns:
            A 1-tuple with the merged JSON array string, or (None,) when the
            node is disabled or no tools were supplied.
        """
        if not is_enable:
            return (None,)
        return (_merge_tools([tool1, tool2, tool3]),)


class tool_combine_plus:
    """Node that merges up to ten tool-definition JSON strings into one.

    Same merge rules as ``tool_combine``, with more input sockets.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {},
            "optional": _optional_tool_inputs(10),
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("tools",)
    FUNCTION = "combine"
    CATEGORY = "大模型派对(llm_party)/组合(combine)"

    def combine(
        self,
        is_enable=True,
        tool1=None,
        tool2=None,
        tool3=None,
        tool4=None,
        tool5=None,
        tool6=None,
        tool7=None,
        tool8=None,
        tool9=None,
        tool10=None,
    ):
        """Merge the connected tool lists.

        Args:
            is_enable: when falsy, the node is bypassed and outputs None.
            tool1..tool10: JSON strings holding tool definitions (or None).

        Returns:
            A 1-tuple with the merged JSON array string, or (None,) when the
            node is disabled or no tools were supplied.
        """
        if not is_enable:
            return (None,)
        raw_inputs = [tool1, tool2, tool3, tool4, tool5, tool6, tool7, tool8, tool9, tool10]
        return (_merge_tools(raw_inputs),)