-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
229 lines (194 loc) · 7.16 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
from PIL import Image
from dotenv import load_dotenv
import pandas as pd
import shutil
import openai
import os
import streamlit as st
import sys
# Swap the stdlib sqlite3 for pysqlite3 BEFORE anything imports sqlite3 —
# presumably a workaround for hosts whose system SQLite is too old for a
# vector-db dependency (common Streamlit Cloud fix); confirm which package needs it.
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
# Pull OPENAI_API_KEY (and any other secrets) from a local .env, if present.
load_dotenv()
# Short alias for Streamlit's per-user session state.
ss = st.session_state
# Widen the sidebar so the long widget labels fit without wrapping.
# BUGFIX: the injected CSS block was missing its closing </style> tag, which
# can cause the browser to treat subsequent page markup as CSS.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
min-width: 450px;
max-width: 600px;
}
</style>
""",
    unsafe_allow_html=True,
)
def on_api_key_change():
    """Sync the OpenAI API key into the environment and the openai client.

    Registered as the ``on_change`` callback of the API-key text input; reads
    the widget value from session state under the ``"api_key"`` key, falling
    back to an already-set ``OPENAI_API_KEY`` environment variable.
    """
    api_key = ss.get("api_key") or os.getenv("OPENAI_API_KEY")
    # BUGFIX: guard against None — os.environ[...] = None raises TypeError
    # when the widget is cleared and no env var is set.
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
        openai.api_key = api_key
def save_uploadfile(uploadedfile):
    """Persist one uploaded file under ``data/lit_dir`` and return its path.

    Args:
        uploadedfile: object exposing ``.name`` and ``.getbuffer()``
            (Streamlit's ``UploadedFile`` interface).

    Returns:
        str: the path the file was written to (original returned None; the
        return value is additive and backward-compatible).
    """
    dirpath = os.path.join("data", "lit_dir")
    # BUGFIX: do NOT rmtree the directory on every call. This function is
    # invoked once per uploaded file in a loop, and wiping the directory each
    # time deleted every earlier upload, leaving only the last file on disk.
    os.makedirs(dirpath, exist_ok=True)
    filepath = os.path.join(dirpath, uploadedfile.name)
    with open(filepath, "wb") as f:
        f.write(uploadedfile.getbuffer())
    return filepath
# Page title and a short description of what the app does.
st.write(
"## Xpert AI: Extract human interpretable structure-property relationships from raw data"
)
st.write(
"""XpertAI trains a surrogate model to your dataset and extracts impactful features from your dataset using XAI tools.
Currently, GPT-4 model is used to generate natural language explanations."""
)
def run_autofill():
    """``on_click`` callback for the "Test Run" button: pre-fill session
    state with the bundled toxicity demo inputs.

    Streamlit automatically reruns the script after an ``on_click`` callback,
    so the previous explicit ``st.experimental_rerun()`` call — deprecated and
    removed in recent Streamlit releases — was redundant and is dropped.
    """
    st.session_state.auto_target = "toxicity of small molecules"
    # NOTE(review): the test-run branch below hard-codes this same path, so
    # "auto_df" appears unread; kept for backward compatibility.
    st.session_state.auto_df = "tests/toxicity_sample_data.csv"
# Values pre-filled by the "Test Run" autofill callback; None otherwise.
# NOTE(review): nothing ever sets "auto_arxiv" in session state, so
# auto_arxiv is always None — confirm whether autofill should set it.
auto_target = st.session_state.get("auto_target", None)
auto_arxiv = st.session_state.get("auto_arxiv", None)
with st.sidebar:
    logo = Image.open("assets/logo_2.png")
    st.image(logo)
    # st.markdown('# Setup your inputs!')
    # Input OpenAI api key
    st.markdown("### First input your OpenAI API key :key:")
    api_key = st.text_input(
        "OpenAI API key",
        type="password",
        key="api_key",
        on_change=on_api_key_change,
        label_visibility="hidden",
    )
    st.markdown("### Now upload your input dataset")
    input_file = st.file_uploader(
        "Dataset with featurized inputs & labels. Must have .csv extension AND the label column must be the last column of your dataset!"
    )
    st.markdown("### What is the target property you want to explain?")
    observation = st.text_input("eg.: Toxicity of small molecules", value=auto_target)
    st.markdown("### Set up XAI workflow")
    mode_type = st.radio(
        "1. Select the model type",
        [
            "Classifier",
            "Regressor",
        ],
        # BUGFIX: user-facing typo — "discreet" (secretive) -> "discrete".
        captions=["For predicting discrete labels", "For predicting continuous values"],
    )
    XAI_tool = st.radio("2. What's your favorite XAI method?", ["SHAP", "LIME", "Both"])
    top_k = st.slider(
        "3. Select the max number of features to be explained!", 0, 10, value=2
    )
    st.markdown(
        "### Provide literature to generate scientific explanations! \nIf you don't provide literature, you will receive an explanation based on XAI tools."
    )
    lit_files = st.file_uploader(
        "Upload your literature library here (Suggested):", accept_multiple_files=True
    )
    arxiv_keywords = st.text_input(
        "If you want to scrape arxiv, provide keywords for arxiv scraping:",
        help="organic molecules, solubility of small molecules",
        value=auto_arxiv,
    )
    max_papers = st.number_input(
        "Maximum number of papers to download from arxiv.org", value=15
    )
    button = st.button("Generate Explanation", type="primary")
    st.markdown(
        "## Not sure what to do? Run a test case - explaining toxicity of small molecules!"
    )
    st.markdown(
        """**Make sure to add your OpenAPI key**.
You can download the input dataset after the explanation is generated.
Literature parsing is not used here."""
    )
    auto_button = st.button("Test Run", on_click=run_autofill)
# Main page
# Set up the XAI inputs: decide where the dataset comes from — the bundled
# demo CSV on a Test Run, or the user's upload on "Generate Explanation" —
# then build the argument dict for the surrogate-model/XAI step.
if auto_button:
    input_file = "./tests/toxicity_sample_data.csv"
    selected_model = "Classifier"  # the demo dataset is a classification task
elif input_file and button:
    selected_model = mode_type
else:
    selected_model = None

if selected_model is None:
    arg_dict_xai = None
else:
    df_init = pd.read_csv(input_file, header=0)
    arg_dict_xai = {
        "df_init": df_init,
        "model_type": selected_model,
        "top_k": top_k,
        "XAI_tool": XAI_tool,
    }
if button or auto_button:
    # --- validate api key --------------------------------------------------
    # BUGFIX: api_key can be None (widget untouched, no env var), so guard
    # before calling .startswith to avoid an AttributeError.
    if api_key and api_key.startswith("sk-"):
        # Heavy project imports are deferred until a run is actually requested.
        from xpertai.tools.explain_model import get_modelsummary
        from xpertai.tools.scrape_arxiv import scrape_arxiv
        from xpertai.tools.generate_nle import gen_nle
        from xpertai.tools.utils import vector_db
    else:
        st.warning("Please enter a valid OpenAI API key!")
        st.stop()
    if arg_dict_xai is None:
        st.warning("Please upload a dataset!")
        st.stop()

    # Train the surrogate model and run the XAI analysis; plots are written
    # under ./data/figs by get_modelsummary (inferred from the reads below —
    # confirm against xpertai.tools.explain_model).
    explanation = get_modelsummary(arg_dict_xai)
    st.markdown("### XAI Analysis:")
    xg_plot = Image.open("./data/figs/xgbmodel_error.png")  # no f-prefix needed
    st.image(xg_plot)
    if XAI_tool in ["SHAP", "LIME"]:
        st.image(Image.open(f"./data/figs/{XAI_tool.lower()}_bar.png"))
    else:  # "Both": show one bar plot per method
        st.image(Image.open("./data/figs/shap_bar.png"))
        st.image(Image.open("./data/figs/lime_bar.png"))
    if auto_button:
        # Ship the pre-computed demo artifacts into the downloadable folder.
        shutil.copytree("./paper/datasets", "./data/figs", dirs_exist_ok=True)

    with st.spinner("Please wait...:computer: :speech_balloon:"):
        # --- collect literature -------------------------------------------
        if lit_files:
            # read literature uploaded by the user
            for file in lit_files:
                save_uploadfile(file)
                try:
                    vector_db(
                        lit_file=os.path.join("./data/lit_dir", file.name),
                        try_meta_data=True,
                    )
                except BaseException:
                    # best-effort: a failed embedding shouldn't kill the run
                    st.write("vectordb failed!!")
        elif arxiv_keywords:
            # scrape arxiv.org
            arg_dict_arxiv = {"key_words": arxiv_keywords, "max_papers": max_papers}
            scrape_arxiv(arg_dict_arxiv)

        # --- generate the explanation -------------------------------------
        # BUGFIX: in the original if/elif chain the gen_nle branch was
        # unreachable and `nle` was left undefined (NameError) precisely when
        # literature files or arxiv keywords WERE provided. Generate the
        # evidence-based explanation in that case; otherwise fall back to the
        # XAI-only summary.
        if lit_files or arxiv_keywords:
            nle = gen_nle(
                {
                    "observation": observation,
                    "top_k": top_k,
                    "XAI_tool": XAI_tool,
                }
            )
        else:
            st.markdown(
                f"""### Literature is not provided to make an informed explanation. Based on XAI analysis, the following explanation can be given:
\n{explanation}"""
            )
            nle = explanation
        st.write(
            "### The structure function relationship based on XAI analysis and literature, the following explanation can be given:\n",
            nle,
        )

        # Persist the explanation next to the figures (BUGFIX: use `with`
        # so the handle is closed even if a write fails), then zip the whole
        # output folder for download.
        with open("./data/figs/structure_function_relationship.txt", "w+") as f:
            f.write(f"Understanding {observation}\n:")
            f.write(nle)
            f.write(
                "\n\nExplanation generated with XpertAI. https://github.com/geemi725/XpertAI"
            )
        shutil.make_archive("./data/figs", "zip", "./data/figs/")
        with open("./data/figs.zip", "rb") as f:
            st.download_button(
                "Download the outputs!", f, file_name="XpertAI_output.zip"
            )