Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wangcx #21

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions evals/elsuite/choice_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import os
from pathlib import Path
from typing import Any

import oss2
from oss2.credentials import EnvironmentVariableCredentialsProvider

import evals
import evals.metrics
from evals.api import CompletionFn
from evals.prompt.base import is_chat_prompt


def init_oss():
"""
Initialize OSS client.
"""
# Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables.
auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())

# 设置 Endpoint
endpoint = 'https://oss-cn-beijing.aliyuncs.com'

# 设置 Bucket
bucket_name = 'dp-filetrans-bj'
bucket = oss2.Bucket(auth, endpoint, bucket_name)

return bucket


def get_rag_dataset(samples_jsonl: str) -> list[dict]:
bucket = init_oss()
raw_samples = evals.get_jsonl(samples_jsonl)

for raw_sample in raw_samples:
for ftype in ["", "answer"]:
if f"{ftype}file_name" not in raw_sample and f"{ftype}file_link" not in raw_sample:
continue
if f"{ftype}file_name" in raw_sample:
oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"])
raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file

exists = bucket.object_exists(oss_file)
if exists:
print(f"文件 {oss_file} 已存在于 OSS 中。")
else:
# 上传文件
bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"])
print(f"文件 {oss_file} 已上传到 OSS。")
if f"{ftype}file_link" in raw_sample:
local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else \
os.path.basename(raw_sample[f"{ftype}file_link"])
oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"])
if not os.path.exists(local_file):
if bucket.object_exists(oss_file):
# 从 OSS 下载文件
Path(local_file).parent.mkdir(parents=True, exist_ok=True)
bucket.get_object_to_file(oss_file, local_file)
print(f"文件 {oss_file} 已下载到本地。")
return raw_samples


class RAGMatch(evals.Eval):
def __init__(
self,
completion_fns: list[CompletionFn],
samples_jsonl: str,
*args,
max_tokens: int = 500,
num_few_shot: int = 0,
few_shot_jsonl: str = None,
**kwargs,
):
super().__init__(completion_fns, *args, **kwargs)
assert len(completion_fns) == 1, "Match only supports one completion fn"
self.max_tokens = max_tokens
self.samples_jsonl = samples_jsonl
self.num_few_shot = num_few_shot
if self.num_few_shot > 0:
assert few_shot_jsonl is not None, "few shot requires few shot sample dataset"
self.few_shot_jsonl = few_shot_jsonl
self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl))

def eval_sample(self, sample: Any, *_):
assert isinstance(sample, dict), "sample must be a dict"
assert "input" in sample, "sample must have an 'input' key"
assert "ideal" in sample, "sample must have an 'ideal' key"
assert isinstance(sample["ideal"], str) or isinstance(
sample["ideal"], list
), "sample['ideal'] must be a string or list of strings"

prompt = sample["input"]
if self.num_few_shot > 0:
assert is_chat_prompt(sample["input"]), "few shot requires chat prompt"
prompt = sample["input"][:-1]
for s in self.few_shot[: self.num_few_shot]:
prompt += s["sample"]
prompt += sample["input"][-1:]

result = self.completion_fn(
prompt=prompt,
temperature=0.0,
**{k: v for k, v in sample.items() if k not in ["input", "ideal"]}
)
sampled = result.get_completions()[0]

extras = {}
if hasattr(result, "extras"):
if "extracted_answer" in result.extras:
sampled = result.extras["extracted_answer"].rstrip(".")
extras = result.extras
print(sampled)
sampled = sampled.split("\n")
for i in range(len(sampled)-1, -1, -1):
if i == 0:
sampled = sampled[0]
elif sampled[i] == "":
continue
else:
sampled = sampled[i]
break
for i in ["a)", "b)", "c)", "d)"]:
if i in sample["ideal"] and i in sampled:
continue
elif i not in sample["ideal"] and i not in sampled:
continue
else:
sampled = ""
break
if sampled != "":
sampled = sample["ideal"]
print("compare", sampled, sample["ideal"])
return evals.record_and_check_match(
prompt=prompt,
sampled=sampled,
expected=sample["ideal"],
file_name=sample["file_name"],
**extras
)

def run(self, recorder):
samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
self.eval_all_samples(recorder, samples)
events = recorder.get_events("match")
return {
"accuracy": evals.metrics.get_accuracy(events),
"boostrap_std": evals.metrics.get_bootstrap_accuracy_std(events),
}
1 change: 0 additions & 1 deletion evals/elsuite/rag_table_extract_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ def eval_sample(self, sample, rng):
result = self.completion_fn(
prompt=prompt,
temperature=0.0,
max_tokens=5,
file_name=sample.file_name,
file_link=sample.file_link
)
Expand Down
2 changes: 1 addition & 1 deletion evals/elsuite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def tableMatching(df_ref, df_prompt, index='Compound', compare_fields=[], record
return {"recall_field": 0.0, "recall_index": 0.0, "recall_value": 0.0, "recall_value_strict": 0.0,
"accuracy_value": 0.0, "accuracy_value_strict": 0.0, "recall_SMILES": 0.0}
metrics = {}
index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate"]
index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate","AlloyName"]

if index not in [None, ""]:
df_ref[index] = df_ref[index].astype(str)
Expand Down
3 changes: 3 additions & 0 deletions evals/registry/data/01_alloychart/samples.jsonl
Git LFS file not shown
4 changes: 2 additions & 2 deletions evals/registry/data/01_alloycomposition/composition.jsonl
Git LFS file not shown
4 changes: 2 additions & 2 deletions evals/registry/data/01_alloynum/alloy_number.jsonl
Git LFS file not shown
4 changes: 2 additions & 2 deletions evals/registry/data/01_alloysort/sort.jsonl
Git LFS file not shown