-
Notifications
You must be signed in to change notification settings - Fork 20
/
keyword.py
124 lines (105 loc) · 3.95 KB
/
keyword.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
from collections import Counter
import requests
import torch
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Module-level shared state: mutated by keyword_tool.file() and read by
# search_keyword(), which the LLM calls as a tool after registration.
ebd_model = ""  # never read in this chunk — presumably an embedding model slot set elsewhere; TODO confirm
bge_embeddings = ""  # never read in this chunk — presumably BGE embeddings set elsewhere; TODO confirm
files_load = ""  # raw text of the most recently loaded file
c_size = 200  # chunk_size forwarded to RecursiveCharacterTextSplitter
c_overlap = 50  # chunk_overlap forwarded to RecursiveCharacterTextSplitter
knowledge_base = ""  # "" until built by keyword_tool.file(); then a list of text chunks
k_setting = 5  # number of top chunks search_keyword() returns
def search_keyword(question):
    """Return the knowledge-base chunks most relevant to *question*.

    Ranks every chunk in the module-global ``knowledge_base`` (built by
    ``keyword_tool.file``) by the number of occurrences of *question*,
    keeps at most ``k_setting`` of them, and joins them into one reply
    string prefixed with a Chinese header.

    Args:
        question: The keyword the LLM wants to look up.

    Returns:
        str: Header line followed by the matching chunks, newline-joined.
    """
    global knowledge_base, k_setting
    # Not yet built (still the "" sentinel) — same output the original
    # produced for this state, returned without iterating.
    if not knowledge_base:
        return "文件中的相关信息如下:\n"
    # dict keyed by chunk text: identical chunks collapse, as in the original
    # dict(zip(...)) construction.
    occurrence = Counter({chunk: chunk.count(question) for chunk in knowledge_base})
    # Bug fix: most_common(k) used to pad the result with zero-hit chunks
    # whenever fewer than k chunks contained the keyword; keep only real hits.
    top_chunks_text = [chunk for chunk, hits in occurrence.most_common(k_setting) if hits > 0]
    return "文件中的相关信息如下:\n" + "\n".join(top_chunks_text)
class keyword_tool:
    """ComfyUI node that builds the shared chunk knowledge base from a file
    and emits the JSON tool spec for ``search_keyword``.

    The node stores its settings in module globals so that the plain
    ``search_keyword`` function (invoked later by the LLM) can see them.
    """

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-input declaration: widget names, types and defaults.
        return {
            "required": {
                "is_enable": (["enable", "disable"], {"default": "enable"}),
                "file_content": ("STRING", {"forceInput": True}),
                "k": ("INT", {"default": 5}),
                "chunk_size": ("INT", {"default": 200}),
                "chunk_overlap": ("INT", {"default": 50}),
            },
            "optional": {},
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("tool",)
    FUNCTION = "file"
    # OUTPUT_NODE = False
    CATEGORY = "大模型派对(llm_party)/工具(tools)"

    def file(self, file_content, k, chunk_size, chunk_overlap, is_enable="enable"):
        """Record splitter settings, (re)build the knowledge base, and return
        the ``search_keyword`` tool definition as a JSON string.

        Args:
            file_content: Full text of the uploaded file.
            k: Number of top chunks ``search_keyword`` should return.
            chunk_size: Splitter chunk size in characters.
            chunk_overlap: Overlap between consecutive chunks.
            is_enable: "enable"/"disable" toggle from the UI.

        Returns:
            tuple: ``(json_tool_spec,)`` — or ``(None,)`` when disabled.
        """
        if is_enable == "disable":
            return (None,)
        global files_load, c_size, c_overlap, knowledge_base, k_setting
        # Bug fix: the original only split the text while knowledge_base was
        # still the "" sentinel, so a new file or new chunk settings on later
        # executions were silently ignored (stale index). Detect changes
        # before overwriting the globals.
        inputs_changed = (
            file_content != files_load
            or chunk_size != c_size
            or chunk_overlap != c_overlap
        )
        k_setting = k
        c_size = chunk_size
        c_overlap = chunk_overlap
        files_load = file_content
        if knowledge_base == "" or inputs_changed:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=c_size,
                chunk_overlap=c_overlap,
            )
            knowledge_base = text_splitter.split_text(files_load)
        # OpenAI-style function-calling tool schema handed back to the party.
        output = [
            {
                "type": "function",
                "function": {
                    "name": "search_keyword",
                    "description": "查询用户上传的文件中与用户提问相关的信息。",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "question": {"type": "string", "description": "要查询的关键词"},
                        },
                        "required": ["question"],
                    },
                },
            }
        ]
        out = json.dumps(output, ensure_ascii=False)
        return (out,)
class load_keyword:
    """ComfyUI loader node: one-shot keyword search over a file.

    Splits *file_content* into chunks, ranks the chunks by occurrences of
    *question*, and returns the best matches as a single string. Unlike
    ``keyword_tool`` it keeps no module-level state.
    """

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-input declaration: widget names, types and defaults.
        return {
            "required": {
                "question": ("STRING", {"default": "question"}),
                "is_enable": ("BOOLEAN", {"default": True}),
                "file_content": ("STRING", {"forceInput": True}),
                "k": ("INT", {"default": 5}),
                "chunk_size": ("INT", {"default": 200}),
                "chunk_overlap": ("INT", {"default": 50}),
            },
            "optional": {},
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("ebd_response",)
    FUNCTION = "file"
    # OUTPUT_NODE = False
    CATEGORY = "大模型派对(llm_party)/加载器(loader)"

    def file(self, question, file_content, k, chunk_size, chunk_overlap, is_enable=True):
        """Split *file_content* and return the top-*k* chunks containing *question*.

        Args:
            question: Keyword to search for.
            file_content: Full text of the uploaded file.
            k: Maximum number of chunks to return.
            chunk_size: Splitter chunk size in characters.
            chunk_overlap: Overlap between consecutive chunks.
            is_enable: UI toggle; when falsy the node is a no-op.

        Returns:
            tuple: ``(result_text,)`` — or ``(None,)`` when disabled.
        """
        # Idiom fix: `is_enable == False` replaced with a truthiness test.
        if not is_enable:
            return (None,)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        chunks = text_splitter.split_text(file_content)
        # dict keyed by chunk text: identical chunks collapse, matching the
        # original dict(zip(...)) construction.
        occurrence = Counter({chunk: chunk.count(question) for chunk in chunks})
        # Bug fix: most_common(k) used to pad the result with zero-hit chunks
        # whenever fewer than k chunks contained the keyword; keep only hits.
        top_chunks_text = [chunk for chunk, hits in occurrence.most_common(k) if hits > 0]
        text = "\n".join(top_chunks_text)
        output = "文件中的相关信息如下:\n" + text
        return (output,)