Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wangcx #11

Closed
wants to merge 31 commits into TablewareBox:main from boliqq07:wangcx
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3fb6f80
add data by wangcx.
Linmj-Judy Jan 31, 2024
86f64f6
add material evalset
TablewareBox Feb 4, 2024
e701285
Merge branch 'main' into pr/8
TablewareBox Feb 4, 2024
c079b08
cot
Naplessss Feb 6, 2024
2f47cdd
cot
Naplessss Feb 6, 2024
a6a0209
debug
Naplessss Feb 7, 2024
c8b9fe6
Merge branch 'main' into wangcx
Naplessss Feb 7, 2024
2482aa4
data
Naplessss Feb 7, 2024
1155ac1
wang
Naplessss Feb 7, 2024
36b8417
Merge branch 'main' of https://github.com/TablewareBox/evals into main
Naplessss Feb 7, 2024
7ba2083
Merge branch 'main' into wangcx
Naplessss Feb 7, 2024
91a9282
wang bug
Naplessss Feb 7, 2024
58d4fb2
wang bug
Naplessss Feb 7, 2024
24443e7
Merge branch 'main' of https://github.com/TablewareBox/evals into main
Naplessss Feb 7, 2024
ed849e8
Merge branch 'wangcx' of https://github.com/boliqq07/evals into wangcx
Naplessss Feb 7, 2024
68bc5be
Merge branch 'wangcx' into main
Naplessss Feb 7, 2024
eaaf51f
change data by wang
Naplessss Feb 8, 2024
4727f70
Merge branch 'main' of https://github.com/TablewareBox/evals into main
Naplessss Mar 5, 2024
fe0f2a5
add data
Naplessss Mar 5, 2024
b173327
Merge branch 'main' into wangcx
Naplessss Mar 5, 2024
14d82cb
Merge branch 'main' of https://github.com/TablewareBox/evals into wangcx
Naplessss Mar 6, 2024
9d51d1f
gemini change
Naplessss Mar 7, 2024
812654e
gemini change
Naplessss Mar 7, 2024
10b550c
more
Naplessss Mar 10, 2024
fd66819
Merge branch 'main' of https://github.com/TablewareBox/evals into main
Naplessss Mar 10, 2024
32ef94d
Merge branch 'main' into wangcx
Naplessss Mar 10, 2024
7f44214
merge main to wangcx
Naplessss Mar 10, 2024
5a502e0
merge main to wangcx
Naplessss Mar 10, 2024
0ca7d50
Merge branch 'main' of https://github.com/TablewareBox/evals into main
Naplessss Mar 11, 2024
ef1e1d2
by wang
Naplessss Mar 11, 2024
d27e623
Merge branch 'TablewareBox:main' into wangcx
boliqq07 Mar 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion evals/completion_fns/uni_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
"query": prompt,
'api_key': self.api_key
}
response = requests.post(url, json=payload, timeout=1200)
response = requests.post(url, json=payload, timeout=300)
try:
answer = response.json()['answer']
except:
Expand Down
168 changes: 90 additions & 78 deletions evals/elsuite/rag_table_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,97 +91,109 @@ def __init__(
self.instructions = instructions

def eval_sample(self, sample, rng):
assert isinstance(sample, FileSample)

prompt = \
self.instructions
# + f"\nThe fields should at least contain {sample.compare_fields}"
result = self.completion_fn(
prompt=prompt,
temperature=0.0,
max_tokens=5,
file_name=sample.file_name,
file_link=sample.file_link
)
sampled = result.get_completions()[0]

compare_fields_types = [type(x) for x in sample.compare_fields]
header_rows = [0, 1] if tuple in compare_fields_types else [0]

correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=header_rows).astype(str), compare_fields=sample.compare_fields)
correct_answer.to_csv("temp.csv", index=False)
correct_str = open("temp.csv", 'r').read()

if sample.index not in correct_answer.columns:
if len(header_rows)>1:
correct_answer.columns = pd.MultiIndex.from_tuples([sample.index] + list(correct_answer.columns)[1:])
else:
correct_answer.columns = [sample.index] + list(correct_answer.columns)[1:]

try:
if re.search(outlink_pattern, sampled) is not None:
code = re.search(outlink_pattern, sampled).group()
link = re.sub(outlink_pattern, r"\1", code)

fname = f"/tmp/LLMEvals_{uuid.uuid4()}.csv"
os.system(f"wget {link} -O {fname}")
table = pd.read_csv(fname)
if pd.isna(table.iloc[0, 0]):
table = pd.read_csv(fname, header=header_rows)
elif "csv" in prompt:
code = re.search(csv_pattern, sampled).group()
code_content = re.sub(csv_pattern, r"\1", code)
code_content_processed = parse_csv_text(code_content)
# table = pd.read_csv(StringIO(code_content_processed), header=header_rows)
table = pd.read_csv(StringIO(code_content_processed))
if pd.isna(table.iloc[0, 0]):
table = pd.read_csv(StringIO(code_content_processed), header=header_rows)

elif "json" in prompt:
code = re.search(json_pattern, sampled).group()
code_content = re.sub(json_pattern, r"\1", code).replace("\"", "")
table = pd.DataFrame(json.loads(code_content))
else:
table = pd.DataFrame()

table = parse_table_multiindex(table, compare_fields=sample.compare_fields)

if sample.index not in table.columns:
assert isinstance(sample, FileSample)

prompt = \
self.instructions
# + f"\nThe fields should at least contain {sample.compare_fields}"
result = self.completion_fn(
prompt=prompt,
temperature=0.0,
max_tokens=5,
file_name=sample.file_name,
file_link=sample.file_link
)
sampled = result.get_completions()[0]

compare_fields_types = [type(x) for x in sample.compare_fields]
header_rows = [0, 1] if tuple in compare_fields_types else [0]

correct_answer = parse_table_multiindex(pd.read_csv(sample.answerfile_name, header=header_rows).astype(str), compare_fields=sample.compare_fields)
correct_answer.to_csv("temp.csv", index=False)
correct_str = open("temp.csv", 'r').read()

if sample.index not in correct_answer.columns:
if len(header_rows)>1:
table.columns = pd.MultiIndex.from_tuples([sample.index] + list(table.columns)[1:])
correct_answer.columns = pd.MultiIndex.from_tuples([sample.index] + list(correct_answer.columns)[1:])
else:
table.columns =[sample.index] + list(table.columns)[1:]

print(table)
answerfile_out = sample.answerfile_name.replace(".csv", "_output.csv")
table.to_csv(answerfile_out, index=False)
picked_str = open(answerfile_out, 'r').read()
except:
print(Path(sample.file_name).stem)
traceback.print_exc()
correct_answer.columns = [sample.index] + list(correct_answer.columns)[1:]

try:
if re.search(outlink_pattern, sampled) is not None:
code = re.search(outlink_pattern, sampled).group()
link = re.sub(outlink_pattern, r"\1", code)

fname = f"/tmp/LLMEvals_{uuid.uuid4()}.csv"
os.system(f"wget {link} -O {fname}")
table = pd.read_csv(fname)
if pd.isna(table.iloc[0, 0]):
table = pd.read_csv(fname, header=header_rows)
elif "```csv" in prompt:
code = re.search(csv_pattern, sampled).group()
code_content = re.sub(csv_pattern, r"\1", code)
code_content_processed = parse_csv_text(code_content)
# table = pd.read_csv(StringIO(code_content_processed), header=header_rows)
table = pd.read_csv(StringIO(code_content_processed))
if pd.isna(table.iloc[0, 0]):
table = pd.read_csv(StringIO(code_content_processed), header=header_rows)

elif "```json" in prompt:
code = re.search(json_pattern, sampled).group()
code_content = re.sub(json_pattern, r"\1", code).replace("\"", "")
table = pd.DataFrame(json.loads(code_content))
else:
table = pd.DataFrame()

table = parse_table_multiindex(table, compare_fields=sample.compare_fields)

if sample.index not in table.columns:
if len(header_rows)>1:
table.columns = pd.MultiIndex.from_tuples([sample.index] + list(table.columns)[1:])
else:
table.columns =[sample.index] + list(table.columns)[1:]

print(table)
print(correct_answer)
answerfile_out = sample.answerfile_name.replace(".csv", "_output.csv")
table.to_csv(answerfile_out, index=False)
picked_str = open(answerfile_out, 'r').read()
except:
print(Path(sample.file_name).stem)
traceback.print_exc()
record_match(
prompt=prompt,
correct=False,
expected=correct_str,
picked=sampled,
file_name=sample.file_name,
jobtype="match_all"
)
table = None
picked_str = "Failed to parse"

metrics = tableMatching(correct_answer, table, index=sample.index, compare_fields=sample.compare_fields,
record=False, file_name=sample.file_name)
record_match(
prompt=prompt,
correct=False,
correct=(metrics["recall_field"] == 1.0 and metrics["recall_index"] == 1.0 and metrics["recall_value"] == 1.0),
expected=correct_str,
picked=sampled,
picked=picked_str,
file_name=sample.file_name,
jobtype="match_all"
)
return metrics
except:
print(Path(sample.file_name).stem)
traceback.print_exc()

table = None
picked_str = "Failed to parse"

metrics = tableMatching(correct_answer, table, index=sample.index, compare_fields=sample.compare_fields,
record=False, file_name=sample.file_name)
record_match(
prompt=prompt,
correct=(metrics["recall_field"] == 1.0 and metrics["recall_index"] == 1.0 and metrics["recall_value"] == 1.0),
expected=correct_str,
picked=picked_str,
file_name=sample.file_name,
jobtype="match_all"
)
return metrics
return metrics
metrics = {"recall_field": 0.0, "recall_index": 0.0, "recall_value": 0.0, "recall_value_strict": 0.0,
"accuracy_value": 0.0, "accuracy_value_strict": 0.0, "recall_SMILES": 0.0}
return metrics

def run(self, recorder: RecorderBase):
raw_samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
Expand Down
7 changes: 0 additions & 7 deletions evals/elsuite/temp.csv

This file was deleted.

25 changes: 12 additions & 13 deletions evals/elsuite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,11 @@ def is_float(str):
return True
elif fuzzy_normalize_value(a) == fuzzy_normalize_value(b):
return True
# elif ((a[-2:] in unit_str or a[-1] in unit_str or a.split()[-1] in unit_str) and
# (b[-2:] in unit_str or b[-1] in unit_str or b.split()[-1] in unit_str)):
# a = standardize_unit(a)
# b = standardize_unit(b)
# return a == b
elif ((a[-2:] in unit_str or a[-1] in unit_str or a.split()[-1] in unit_str) and
(b[-2:] in unit_str or b[-1] in unit_str or b.split()[-1] in unit_str)):
a = standardize_unit(a)
b = standardize_unit(b)
return a == b
elif a.lower() in nan_str and b.lower() in nan_str:
return True
if ((a.lower().startswith(b.lower()) or a.lower().endswith(b.lower())) or
Expand Down Expand Up @@ -271,9 +271,9 @@ def fuzzy_normalize_name(s):
s = re.sub(r'[^\w\s.\-\(\)]', '', s)
if s in synonyms:
s = synonyms[s]
if "+" in s:
s = s.replace("+","")

if "°" in s:
s = s.replace("°","")

# 分割字符串为单词列表
words = s.split()
Expand Down Expand Up @@ -348,7 +348,7 @@ def match_indices(ind0, ind1, threshold=0.9) -> dict:
Match the indices of two dataframes.
"""
renames = {}
name2query = lambda name: name if type(name) != tuple else name[0] if name[1] == "" else name[1]
name2query = lambda name: name if type(name) != tuple else name[0] if len(name)==1 or name[1] == "" else name[1]
similarities = np.array(np.ones([len(ind0) + 15, len(ind1) + 15]), dtype=np.float64)
querys0 = [name2query(name) for name in ind0]
querys1 = [name2query(name) for name in ind1]
Expand Down Expand Up @@ -423,12 +423,11 @@ def match_indices(ind0, ind1, threshold=0.9) -> dict:
for idx in df_ref.index:
_total_matching = 1.0
for col in compare_fields_:
gt = df_ref.loc[idx, col]
gt = str(gt[0]) if type(gt) == pd.Series else str(gt)
try:
p = df_prompt.loc[idx, col]
p = str(p[0]) if type(p) == pd.Series else str(p)
gt = str(df_ref.loc[idx, col])
p = str(df_prompt.loc[idx, col])
except:
gt = 'error'
p = 'not found'

_is_matching = fuzzy_compare_name(gt, p, compare_value=True) if col != "SMILES" else compare_molecule(gt, p)
Expand Down

This file was deleted.

4 changes: 2 additions & 2 deletions evals/registry/data/01_alloycomposition/composition2.jsonl
Git LFS file not shown