-
Notifications
You must be signed in to change notification settings - Fork 7
/
app_conv.py
148 lines (118 loc) · 4.83 KB
/
app_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import streamlit as st
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
from src import CFG, logger
from src.retrieval_qa import (
build_rerank_retriever,
build_condense_question_chain,
build_question_answer_chain,
)
from src.vectordb import build_vectordb, delete_vectordb, load_faiss, load_chroma
from streamlit_app.utils import perform, load_base_embeddings, load_llm, load_reranker
TITLE = "Conversational RAG"

# Must be the first Streamlit call in the script.
st.set_page_config(page_title=TITLE)

# Loaded at module import; Streamlit reruns this script on every interaction,
# so these loaders presumably cache internally — verify in streamlit_app.utils.
LLM = load_llm()
BASE_EMBEDDINGS = load_base_embeddings()
RERANKER = load_reranker()

# Path of the first configured vector store.
VECTORDB_PATH = CFG.VECTORDB[0].PATH

# Chain that rewrites a follow-up question into a standalone one, and the
# chain that answers from retrieved context.
CONDENSE_QUESTION_CHAIN = build_condense_question_chain(LLM)
QA_CHAIN = build_question_answer_chain(LLM)
@st.cache_resource
def load_vectordb():
    """Load the configured vector store (FAISS or Chroma), cached across reruns.

    Raises NotImplementedError for an unrecognized ``CFG.VECTORDB_TYPE``.
    """
    loaders = {"faiss": load_faiss, "chroma": load_chroma}
    loader = loaders.get(CFG.VECTORDB_TYPE)
    if loader is None:
        raise NotImplementedError
    return loader(BASE_EMBEDDINGS, VECTORDB_PATH)
def init_chat_history():
    """Reset chat state on first load or when the sidebar Clear button is pressed."""
    reset_requested = st.sidebar.button("Clear Chat", key="clear")
    first_run = "chat_history" not in st.session_state
    if reset_requested or first_run:
        st.session_state["chat_history"] = []
        # display_history rows are (question, answer, source_documents);
        # the greeting row uses an empty question and no sources.
        st.session_state["display_history"] = [("", "Hello! How can I help you?", None)]
def print_docs(source_documents):
    """Render each retrieved document, preceded by its page number when present."""
    for doc in source_documents:
        page_number = doc.metadata.get("page_number")
        if page_number:
            st.write(f"**Page {page_number}**")
        st.info(_format_text(doc.page_content))
def _format_text(text):
return text.replace("$", r"\$")
def doc_conv_qa():
    """Render the conversational RAG page.

    Sidebar: shows the configured models, lets the user upload a PDF and
    (re)build the vector store, then loads the retrieval chain. Main area:
    replays the chat history and answers new queries by condensing the
    follow-up question, retrieving reranked documents, and running the QA
    chain with streaming callbacks.
    """
    with st.sidebar:
        st.title(TITLE)
        with st.expander("Models used"):
            st.info(f"LLM: `{CFG.LLM_PATH}`")
            st.info(f"Embeddings: `{CFG.EMBEDDINGS_PATH}`")
            st.info(f"Reranker: `{CFG.RERANKER_PATH}`")

        uploaded_file = st.file_uploader("Upload a PDF and build VectorDB", type=["pdf"])
        if st.button("Build VectorDB"):
            if uploaded_file is None:
                st.error("No PDF uploaded")
                st.stop()

            # Rebuild from scratch: drop any existing index first.
            if os.path.exists(VECTORDB_PATH):
                st.warning("Deleting existing VectorDB")
                delete_vectordb(VECTORDB_PATH, CFG.VECTORDB_TYPE)

            with st.spinner("Building VectorDB..."):
                perform(
                    build_vectordb,
                    uploaded_file.read(),
                    embedding_function=BASE_EMBEDDINGS,
                )
            # Invalidate the cached store so the fresh index is loaded below.
            load_vectordb.clear()

        if not os.path.exists(VECTORDB_PATH):
            st.info("Please build VectorDB first.")
            st.stop()

        try:
            with st.status("Load retrieval chain", expanded=False) as status:
                st.write("Loading retrieval chain...")
                vectordb = load_vectordb()
                retriever = build_rerank_retriever(vectordb, RERANKER)
                status.update(label="Loading complete!", state="complete", expanded=False)
            st.success("Reading from existing VectorDB")
        except Exception as e:
            st.error(e)
            st.stop()

    st.sidebar.write("---")
    init_chat_history()

    # Replay prior turns; the greeting row has an empty question and no sources.
    for question, answer, source_documents in st.session_state.display_history:
        if question != "":
            with st.chat_message("user"):
                st.markdown(question)
        with st.chat_message("assistant"):
            st.markdown(answer)
            if source_documents is not None:
                with st.expander("Sources"):
                    print_docs(source_documents)

    # The walrus guard already ensures a non-empty query, so the original's
    # nested `if user_query is not None` check was dead code and is removed.
    if user_query := st.chat_input("Your query"):
        with st.chat_message("user"):
            st.markdown(user_query)

        with st.chat_message("assistant"):
            st_callback = StreamlitCallbackHandler(
                parent_container=st.container(),
                expand_new_thoughts=True,
                collapse_completed_thoughts=True,
            )
            # Rewrite the follow-up into a standalone question using history.
            question = CONDENSE_QUESTION_CHAIN.invoke(
                {
                    "question": user_query,
                    "chat_history": st.session_state.chat_history,
                },
            )
            logger.info(question)

            source_documents = retriever.invoke(question)
            answer = QA_CHAIN.invoke(
                {
                    "context": source_documents,
                    "question": question,
                },
                config={"callbacks": [st_callback]},
            )

            st.success(_format_text(answer))
            with st.expander("Sources"):
                print_docs(source_documents)

        st.session_state.chat_history.append((user_query, answer))
        st.session_state.display_history.append((user_query, answer, source_documents))
# Run the app when executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    doc_conv_qa()