Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Selfcheckgpt evaluation to tasks #1080

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion lm_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class TaskConfig(dict):
metric_list: list = None
output_type: str = "generate_until"
generation_kwargs: dict = None
generation_kwargs_sampling: dict = None
generation_kwargs_sampling_number: int = 0
repeats: int = 1
filter_list: Union[str, list] = None
should_decontaminate: bool = False
Expand Down Expand Up @@ -999,7 +1001,33 @@ def construct_requests(
return request_list

elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, self.config.generation_kwargs)
if hasattr(self.config, 'generation_kwargs_sampling'):
arguments = (ctx, self.config.generation_kwargs)
request_list = [
Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
idx=0,
**kwargs
),
]

sampling_arguments = (ctx, self.config.generation_kwargs_sampling)
StellaAthena marked this conversation as resolved.
Show resolved Hide resolved
request_list.extend([
Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=sampling_arguments,
idx=idx,
**kwargs
)
for idx in range(1, self.config.generation_kwargs_sampling_number+1)
]
)
return request_list
else:
arguments = (ctx, self.config.generation_kwargs)

return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
Expand Down
20 changes: 20 additions & 0 deletions lm_eval/tasks/selfcheckgpt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
```
pip install spacy
pip install selfcheckgpt
python -m spacy download en
```

SelfCheckGPT

## Citation

```
@misc{manakul2023selfcheckgpt,
title={SelfCheckGPT: Zero-Resource Black-Box Hallucination Detection for Generative Large Language Models},
author={Potsawee Manakul and Adian Liusie and Mark J. F. Gales},
year={2023},
eprint={2303.08896},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
28 changes: 28 additions & 0 deletions lm_eval/tasks/selfcheckgpt/selfcheckgpt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
task: selfcheckgpt
dataset_path: potsawee/wiki_bio_gpt3_hallucination
output_type: generate_until
training_split: null
validation_split: evaluation
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
generation_kwargs:
# until:
# - "\n"
temperature: 0.0
do_sample: false
generation_kwargs_sampling_number: 5
StellaAthena marked this conversation as resolved.
Show resolved Hide resolved
generation_kwargs_sampling:
StellaAthena marked this conversation as resolved.
Show resolved Hide resolved
# until:
# - "\n"
temperature: 1.0
do_sample: false
metric_list:
- metric: avg
aggregation: mean
higher_is_better: true
- metric: max
aggregation: mean
higher_is_better: true
metadata:
- version: 2.0
StellaAthena marked this conversation as resolved.
Show resolved Hide resolved
80 changes: 80 additions & 0 deletions lm_eval/tasks/selfcheckgpt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
import spacy
import torch
from selfcheckgpt.modeling_selfcheck import SelfCheckMQAG, SelfCheckNLI, SelfCheckBERTScore, SelfCheckNgram

# pip install spacy
# pip install selfcheckgpt
# python -m spacy download en

selfcheckgpt_type = os.environ.get('SELFCHECKGPTTYPE', 'SelfCheckNgram')
selfcheckgpt_device = os.environ.get('SELFCHECKGPTDEVICE', 'cpu')
selfcheckgpt_nlp = spacy.load("en_core_web_sm")

if selfcheckgpt_type == 'SelfCheckNgram':
selfcheckgpt = SelfCheckNgram(n=1)
elif selfcheckgpt_type == 'SelfCheckBERTScore':
selfcheckgpt = SelfCheckBERTScore(rescale_with_baseline=True)
elif selfcheckgpt_type == 'SelfCheckMQAG':
selfcheckgpt = SelfCheckMQAG(device=selfcheckgpt_device)
elif selfcheckgpt_type == 'SelfCheckNLI':
selfcheckgpt = SelfCheckNLI(device=selfcheckgpt_device)
else:
raise ValueError(f"Wrong SELFCHECKGPTTYPE environment variable: {selfcheckgpt_type}")
StellaAthena marked this conversation as resolved.
Show resolved Hide resolved

print("Load selfcheckgpt successfully")


def doc_to_text(doc):
doc_text = doc["wiki_bio_text"]
doc_text = doc_text.split()
doc_text = " ".join(doc_text[:5])
doc_text = f"Please generating a Wikipedia passage starting with: {doc_text}\n"
return doc_text


def doc_to_target(doc):
answer = doc['wiki_bio_text']
return answer

def process_results(doc, results, threshold=0.6):

response_temperature_0 = results[0]
other_responses = results[1:]
passage = doc_to_target(doc)

sentences = selfcheckgpt_nlp(response_temperature_0)
sentences = [sent.text.strip() for sent in sentences.sents]
if selfcheckgpt_type == 'SelfCheckNgram':
selfcheckgpt_scores = selfcheckgpt.predict(
sentences = sentences,
passage = passage,
sampled_passages = other_responses,
)
return {"avg": selfcheckgpt_scores["doc_level"]["avg_neg_logprob"],
"max": selfcheckgpt_scores["doc_level"]["avg_max_neg_logprob"]}

elif selfcheckgpt_type == 'SelfCheckBERTScore':
selfcheckgpt_scores = selfcheckgpt.predict(
sentences = sentences,
sampled_passages = other_responses,
)
elif selfcheckgpt_type == 'SelfCheckMQAG':
selfcheckgpt_scores = selfcheckgpt.predict(
sentences = sentences,
sampled_passages = other_responses,
)
elif selfcheckgpt_type == 'SelfCheckNLI':
selfcheckgpt_scores = selfcheckgpt.predict(
sentences = sentences,
passage = passage,
sampled_passages = other_responses,
num_questions_per_sent = 5, # number of questions to be drawn
scoring_method = 'bayes_with_alpha', # options = 'counting', 'bayes', 'bayes_with_alpha'
beta1 = 0.8, beta2 = 0.8, # additional params depending on scoring_method
)

selfcheckgpt_scores_avg = sum(selfcheckgpt_scores) / len(selfcheckgpt_scores) if len(selfcheckgpt_scores) > 0 else 0
selfcheckgpt_scores_max = max(selfcheckgpt_scores)

return {'avg': selfcheckgpt_scores_avg, 'max': selfcheckgpt_scores_max}