-
Notifications
You must be signed in to change notification settings - Fork 20
/
wikipedia.py
146 lines (127 loc) · 4.96 KB
/
wikipedia.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import json
import torch
import wikipedia
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Module-level state shared between the tool node and get_wikipedia().
# NOTE(review): "" is used as a "not yet initialised" sentinel throughout this
# file; None would be more idiomatic, but the `== ""` comparisons elsewhere
# depend on the empty string, so it must stay.
ebd_model = ""  # placeholder for an embedding-model handle; never reassigned in this chunk
bge_embeddings = ""  # HuggingFaceBgeEmbeddings instance once configured, else the "" sentinel
files_load = ""  # unused in this chunk — presumably set elsewhere; TODO confirm
c_size = 200  # default chunk size for RecursiveCharacterTextSplitter
c_overlap = 50  # default chunk overlap for RecursiveCharacterTextSplitter
knowledge_base = ""  # unused in this chunk — presumably set elsewhere; TODO confirm
def get_wikipedia(query):
    """Look up *query* on Chinese Wikipedia and return the relevant text.

    When no embedding model is configured (module global ``bge_embeddings``
    is still the ``""`` sentinel), the first 1000 characters of the page are
    returned.  Otherwise the page is chunked, indexed into a FAISS store and
    the 5 chunks most similar to the query are concatenated and returned.

    Args:
        query: Page title / keyword to look up.

    Returns:
        str: A Chinese preamble followed by the excerpt or retrieved chunks.

    Raises:
        wikipedia.exceptions.PageError / DisambiguationError: propagated from
        ``wikipedia.page`` for missing or ambiguous titles.
    """
    global bge_embeddings, c_size, c_overlap
    # Chinese Wikipedia; fetch the page once for both code paths.
    wikipedia.set_lang("zh")
    py_page = wikipedia.page(query)
    if bge_embeddings == "":
        # No embedding model configured: plain truncated excerpt.
        return "维基百科上的相关信息为:\n" + py_page.content[:1000]
    res = py_page.content
    # Split the article into overlapping chunks for retrieval.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=c_size,
        chunk_overlap=c_overlap,
    )
    chunks = text_splitter.split_text(res)
    # BUG FIX: split_text() returns plain strings, so the index must be built
    # with from_texts(); from_documents() expects Document objects and fails.
    vectorstore = FAISS.from_texts(chunks, bge_embeddings)
    # BUG FIX: similarity_search() expects a query *string*; searching with a
    # pre-computed embedding requires similarity_search_by_vector().
    query_embedding = bge_embeddings.embed_query(query)
    similar_documents = vectorstore.similarity_search_by_vector(query_embedding, k=5)
    # Concatenate the retrieved chunks into one answer body.
    merged_text = "\n".join(doc.page_content for doc in similar_documents)
    return "维基百科上的相关信息为:\n" + merged_text
class wikipedia_tool:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"query": ("STRING", {"default": "query"}),
"is_enable": ("BOOLEAN", {"default": True}),
"chunk_size": ("INT", {"default": 200}),
"chunk_overlap": ("INT", {"default": 50}),
"device": (
["auto", "cuda", "mps", "cpu"],
{"default": ("auto")},
),
},
"optional": {
"embedding_path": ("STRING", {"default": None}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("tool",)
FUNCTION = "wikipedia"
# OUTPUT_NODE = False
CATEGORY = "大模型派对(llm_party)/工具(tools)"
def wikipedia(self, query, embedding_path, chunk_size, chunk_overlap, device, is_enable="enable"):
if is_enable == "disable":
return (None,)
global ebd_model, files_load, bge_embeddings, c_size, c_overlap, knowledge_base
c_size = chunk_size
c_overlap = chunk_overlap
if device == "auto":
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
if ebd_model == "":
model_kwargs = {"device": device}
encode_kwargs = {"normalize_embeddings": True} # 设置为 True 以计算余弦相似度
if bge_embeddings == "" and embedding_path is not None and embedding_path != "":
bge_embeddings = HuggingFaceBgeEmbeddings(
model_name=embedding_path, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)
output = [
{
"type": "function",
"function": {
"name": "get_wikipedia",
"description": "用于查询维基百科上的相关内容",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "需要查询的关键词,例如:python",
"default": str(query),
}
},
"required": ["query"],
},
},
}
]
out = json.dumps(output, ensure_ascii=False)
return (out,)
class load_wikipedia:
    """ComfyUI loader node that fetches the full text of a Chinese
    Wikipedia page for the given query."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "query": ("STRING", {"default": "query"}),
                "is_enable": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("file_content",)
    FUNCTION = "wikipedia"
    # OUTPUT_NODE = False
    CATEGORY = "大模型派对(llm_party)/加载器(loader)"

    def wikipedia(self, query, is_enable=True):
        """Return ``(page_content,)`` for *query*, or ``(None,)`` when disabled.

        Raises wikipedia.exceptions.PageError / DisambiguationError for
        missing or ambiguous titles (propagated from ``wikipedia.page``).
        """
        # Idiom fix: truthiness test instead of `== False` comparison.
        if not is_enable:
            return (None,)
        # Chinese Wikipedia.
        wikipedia.set_lang("zh")
        page = wikipedia.page(query)
        return (page.content,)