Add Selfcheckgpt evaluation to tasks #1080

Open · wants to merge 10 commits into main
Changes from 2 commits
30 changes: 29 additions & 1 deletion lm_eval/api/task.py
@@ -86,6 +86,8 @@ class TaskConfig(dict):
    metric_list: list = None
    output_type: str = "generate_until"
    generation_kwargs: dict = None
    generation_kwargs_sampling: dict = None
    generation_kwargs_sampling_number: int = 0
    repeats: int = 1
    filter_list: Union[str, list] = None
    should_decontaminate: bool = False
@@ -999,7 +1001,33 @@ def construct_requests(
            return request_list

        elif self.OUTPUT_TYPE == "generate_until":
            if self.config.generation_kwargs_sampling is not None:
                arguments = (ctx, self.config.generation_kwargs)
                request_list = [
                    Instance(
                        request_type=self.OUTPUT_TYPE,
                        doc=doc,
                        arguments=arguments,
                        idx=0,
                        **kwargs,
                    ),
                ]

                sampling_arguments = (ctx, self.config.generation_kwargs_sampling)
                request_list.extend(
                    [
                        Instance(
                            request_type=self.OUTPUT_TYPE,
                            doc=doc,
                            arguments=sampling_arguments,
                            idx=idx,
                            **kwargs,
                        )
                        for idx in range(1, self.config.generation_kwargs_sampling_number + 1)
                    ]
                )
                return request_list
            else:
                arguments = (ctx, self.config.generation_kwargs)

        return Instance(
            request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
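In effect, a task that sets `generation_kwargs_sampling` produces one deterministic request (`idx=0`) per document plus `generation_kwargs_sampling_number` sampled requests (`idx >= 1`), and `process_results` later receives the responses in that order. A minimal sketch of that intent (simplified; not the harness's actual `Instance` machinery):

```python
# Simplified sketch: how the new config keys expand into per-document requests.
def build_requests(ctx, config):
    requests = [(ctx, config["generation_kwargs"], 0)]  # idx=0 -> deterministic (temperature 0) response
    for idx in range(1, config["generation_kwargs_sampling_number"] + 1):
        requests.append((ctx, config["generation_kwargs_sampling"], idx))  # idx>=1 -> sampled responses
    return requests

reqs = build_requests(
    "Please generate a Wikipedia passage starting with: ...",
    {
        "generation_kwargs": {"temperature": 0.0, "do_sample": False},
        "generation_kwargs_sampling": {"temperature": 1.0, "do_sample": True},
        "generation_kwargs_sampling_number": 5,
    },
)
print(len(reqs))  # 6 requests: one greedy + five sampled
```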
201 changes: 201 additions & 0 deletions lm_eval/tasks/selfcheckgpt/README.md
@@ -0,0 +1,201 @@
```
pip install spacy
pip install selfcheckgpt
python -m spacy download en_core_web_sm
```

SelfCheckGPT
=====================================================
- Project page for our paper "[SelfCheckGPT: Zero-Resource Black-Box Hallucination Detection for Generative Large Language Models](https://arxiv.org/abs/2303.08896)"
- We investigated several variants of the selfcheck approach: BERTScore, Question-Answering, n-gram, NLI, and LLM-Prompting.
- [18/07/2023] SelfCheck-NLI is added. Additional experiments show that using an entailment classifier (e.g., DeBERTa-v3 fine-tuned on MNLI) performs well. SelfCheck-NLI method requires considerably less computation than SelfCheck-Prompt.
- [11/08/2023] Slides from ML Collective Talk: [Link to Slides](https://drive.google.com/file/d/13LUBPUm4y1nlKigZxXHn7Cl2lw5KuGbc/view).
- [11/10/2023] The paper is accepted and to appear at EMNLP 2023.

![](demo/selfcheck_qa_prompt.png)

## Code/Package

### Installation

pip install selfcheckgpt

### SelfCheckGPT Usage: BERTScore, QA, n-gram

There are three variants of SelfCheck scores in this package, as described in the paper: `SelfCheckBERTScore()`, `SelfCheckMQAG()`, `SelfCheckNgram()`. Each variant has a `predict()` method that outputs sentence-level scores with respect to the sampled passages. You can use a package such as spacy to split a passage into sentences. For reproducibility, you can set `torch.manual_seed` before calling this function. See more details in the Jupyter notebook [```demo/SelfCheck_demo1.ipynb```](demo/SelfCheck_demo1.ipynb).

```python
# Include necessary packages (torch, spacy, ...)
import torch
import spacy
from selfcheckgpt.modeling_selfcheck import SelfCheckMQAG, SelfCheckBERTScore, SelfCheckNgram

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nlp = spacy.load("en_core_web_sm")            # spacy pipeline used for sentence splitting
selfcheck_mqag = SelfCheckMQAG(device=device) # set device to 'cuda' if GPU is available
selfcheck_bertscore = SelfCheckBERTScore(rescale_with_baseline=True)
selfcheck_ngram = SelfCheckNgram(n=1)         # n=1 means Unigram, n=2 means Bigram, etc.

# LLM's text (e.g. GPT-3 response) to be evaluated at the sentence level; split it into sentences
passage = "Michael Alan Weiner (born March 31, 1942) is an American radio host. He is the host of The Savage Nation."
sentences = [sent.text.strip() for sent in nlp(passage).sents]  # spacy sentence tokenization
print(sentences)
# ['Michael Alan Weiner (born March 31, 1942) is an American radio host.', 'He is the host of The Savage Nation.']

# Other samples generated by the same LLM to perform self-check for consistency
sample1 = "Michael Alan Weiner (born March 31, 1942) is an American radio host. He is the host of The Savage Country."
sample2 = "Michael Alan Weiner (born January 13, 1960) is a Canadian radio host. He works at The New York Times."
sample3 = "Michael Alan Weiner (born March 31, 1942) is an American radio host. He obtained his PhD from MIT."

# --------------------------------------------------------------------------------------------------------------- #
# SelfCheck-MQAG: Score for each sentence where value is in [0.0, 1.0] and high value means non-factual
# Additional params for each scoring_method:
# -> counting: AT (answerability threshold, i.e. questions with answerability_score < AT are rejected)
# -> bayes: AT, beta1, beta2
# -> bayes_with_alpha: beta1, beta2
sent_scores_mqag = selfcheck_mqag.predict(
    sentences = sentences,               # list of sentences
    passage = passage,                   # passage (before sentence-split)
    sampled_passages = [sample1, sample2, sample3], # list of sampled passages
    num_questions_per_sent = 5,          # number of questions to be drawn
    scoring_method = 'bayes_with_alpha', # options = 'counting', 'bayes', 'bayes_with_alpha'
    beta1 = 0.8, beta2 = 0.8,            # additional params depending on scoring_method
)
print(sent_scores_mqag)
# [0.30990949 0.42376232]

# --------------------------------------------------------------------------------------------------------------- #
# SelfCheck-BERTScore: Score for each sentence where value is in [0.0, 1.0] and high value means non-factual
sent_scores_bertscore = selfcheck_bertscore.predict(
    sentences = sentences,                          # list of sentences
    sampled_passages = [sample1, sample2, sample3], # list of sampled passages
)
print(sent_scores_bertscore)
# [0.0695562 0.45590915]

# --------------------------------------------------------------------------------------------------------------- #
# SelfCheck-Ngram: Score at sentence- and document-level where value is in [0.0, +inf) and high value means non-factual
# as opposed to SelfCheck-MQAG and SelfCheck-BERTScore, SelfCheck-Ngram's score is not bounded
sent_scores_ngram = selfcheck_ngram.predict(
    sentences = sentences,
    passage = passage,
    sampled_passages = [sample1, sample2, sample3],
)
print(sent_scores_ngram)
# {'sent_level': {   # sentence-level score similar to MQAG and BERTScore variants
#      'avg_neg_logprob': [3.184312, 3.279774],
#      'max_neg_logprob': [3.476098, 4.574710]
#  },
#  'doc_level': {    # document-level score such that avg_neg_logprob is computed over all tokens
#      'avg_neg_logprob': 3.218678904916201,
#      'avg_max_neg_logprob': 4.025404834169327
#  }
# }
```

### SelfCheckGPT Usage: NLI

The entailment (or contradiction) score, with the input being a sentence and a sampled passage, can be used as the selfcheck score. We use DeBERTa-v3-large fine-tuned on MultiNLI, normalize the probabilities of the "entailment" and "contradiction" classes, and take P(contradiction) as the score.

```python
from selfcheckgpt.modeling_selfcheck import SelfCheckNLI
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
selfcheck_nli = SelfCheckNLI(device=device) # set device to 'cuda' if GPU is available

sent_scores_nli = selfcheck_nli.predict(
    sentences = sentences,                          # list of sentences
    sampled_passages = [sample1, sample2, sample3], # list of sampled passages
)
print(sent_scores_nli)
# [0.334014 0.975106 ] -- based on the example above
```

### SelfCheckGPT Usage: (LLM) Prompt

In addition, we've tried using LLMs to assess information consistency in a zero-shot setup. We query an LLM to assess whether the i-th sentence is supported by the sample (used as the context) via the following prompt.

```
Context: {}
Sentence: {}
Is the sentence supported by the context above?
Answer Yes or No:
```

Initial investigation showed that GPT-3 (text-davinci-003) outputs either Yes or No 98% of the time, and any remaining outputs can be treated as N/A. Each output is converted to a score: Yes -> 0.0, No -> 1.0, N/A -> 0.5. The inconsistency score for a sentence is then the average over the sampled passages.
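The scoring step is a simple mapping followed by an average over the sampled passages. A minimal sketch (here `query_llm` is a hypothetical stand-in for whatever LLM API is used; it is assumed to return a short string answer):

```python
PROMPT = "Context: {context}\nSentence: {sentence}\nIs the sentence supported by the context above?\nAnswer Yes or No:\n"

def prompt_inconsistency_score(sentence, sampled_passages, query_llm):
    # Map the LLM answer to a score: Yes -> 0.0, No -> 1.0, anything else (N/A) -> 0.5
    label_to_score = {"yes": 0.0, "no": 1.0}
    scores = []
    for context in sampled_passages:
        answer = query_llm(PROMPT.format(context=context, sentence=sentence)).strip().lower()
        scores.append(label_to_score.get(answer, 0.5))
    # Average over all sampled passages: higher means less supported, i.e. more likely hallucinated
    return sum(scores) / len(scores)
```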

## Dataset
The `wiki_bio_gpt3_hallucination` dataset currently consists of 238 annotated passages (`v3`). You can find more information in the paper or in our data card on HuggingFace: https://huggingface.co/datasets/potsawee/wiki_bio_gpt3_hallucination. To use this dataset, you can either load it through the HuggingFace datasets API or download it directly (in JSON format) via the links below.

### Update
We've annotated the GPT-3 WikiBio passages further, and the dataset now consists of 238 annotated passages. Here is [the link](https://drive.google.com/file/d/1N3_ZQmr9yBbsOP2JCpgiea9oiNIu78Xw/view?usp=sharing) to the IDs of the first 65 passages in `v1`.

### Option 1: HuggingFace

```python
from datasets import load_dataset
dataset = load_dataset("potsawee/wiki_bio_gpt3_hallucination")
```

### Option 2: Manual Download
Download from our [Google Drive](https://drive.google.com/file/d/1AyQ7u9nYlZgUZLm5JBDx6cFFWB__EsNv/view?usp=share_link), then load it in Python:

```python
import json
with open("dataset.json", "r") as f:
    content = f.read()
dataset = json.loads(content)
```

Each instance consists of the following fields (see the loading sketch after this list):
- `gpt3_text`: GPT-3 generated passage
- `wiki_bio_text`: Actual Wikipedia passage (first paragraph)
- `gpt3_sentences`: `gpt3_text` split into sentences using `spacy`
- `annotation`: human annotation at the sentence level
- `wiki_bio_test_idx`: ID of the concept/individual from the original wikibio dataset (testset)
- `gpt3_text_samples`: list of sampled passages (do_sample = True & temperature = 1.0)
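
A minimal sketch of loading the dataset and inspecting these fields (assuming the HuggingFace split is named `evaluation`, as in the task YAML below):

```python
from datasets import load_dataset

dataset = load_dataset("potsawee/wiki_bio_gpt3_hallucination", split="evaluation")
example = dataset[0]

print(example["gpt3_text"][:100])         # GPT-3 generated passage
print(example["wiki_bio_text"][:100])     # first paragraph of the actual Wikipedia article
print(example["gpt3_sentences"][0])       # gpt3_text split into sentences with spacy
print(example["annotation"][0])           # human label for the first sentence
print(example["wiki_bio_test_idx"])       # index into the original WikiBio test set
print(len(example["gpt3_text_samples"]))  # number of sampled passages (do_sample=True, temperature=1.0)
```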

## Experiments

### Probability-based baselines (e.g. GPT-3's probabilities)

As described in our paper, the probabilities (and generation entropies) of the generative LLM can be used to measure its confidence. See our example implementation of this approach in [```demo/experiments/probability-based-baselines.ipynb```](demo/experiments/probability-based-baselines.ipynb).
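
As a rough illustration (not the notebook's actual implementation), a sentence-level baseline score is just the average or maximum negative log-probability of the tokens in that sentence, with higher values indicating lower model confidence:

```python
import math

def neg_logprob_scores(token_logprobs):
    """Given per-token log-probabilities for one sentence, return (avg, max) negative log-probability."""
    neg = [-lp for lp in token_logprobs]
    return sum(neg) / len(neg), max(neg)

# Example: three tokens generated with probabilities 0.9, 0.5 and 0.1
avg_nlp, max_nlp = neg_logprob_scores([math.log(0.9), math.log(0.5), math.log(0.1)])
print(round(avg_nlp, 2), round(max_nlp, 2))  # 1.03 2.3
```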

### Experimental Results
- Full details can be found in our paper.
- Note that our new results show that LLMs such as GPT-3 (text-davinci-003) or ChatGPT (gpt-3.5-turbo) are good at assessing text inconsistency. Based on this finding, we try **SelfCheckGPT-Prompt**, where each sentence (to be evaluated) is compared against every sampled passage by prompting ChatGPT. SelfCheckGPT-Prompt is the best-performing method.

Results on the `wiki_bio_gpt3_hallucination` dataset.

| Method | NonFact (AUC-PR) | Factual (AUC-PR) | Ranking (PCC) |
|----------------------|:------------------:|:------------------:|:-----------------:|
| Random Guessing | 72.96 | 27.04 | - |
| GPT-3 Avg(-logP) | 83.21 | 53.97 | 57.04 |
| SelfCheck-BERTScore | 81.96 | 44.23 | 58.18 |
| SelfCheck-QA | 84.26 | 48.14 | 61.07 |
| SelfCheck-Unigram | 85.63 | 58.47 | 64.71 |
| SelfCheck-NLI | 92.50 | 66.08 | 74.14 |
| **SelfCheck-Prompt** | **93.42** | **67.09** | **78.32** |

## Miscellaneous
[MQAG (Multiple-choice Question Answering and Generation)](https://arxiv.org/abs/2301.12307) was proposed in our previous work. Our MQAG implementation is included in this package, which can be used to: (1) generate multiple-choice questions, (2) answer multiple-choice questions, (3) obtain MQAG score.

### MQAG Usage

```python
from selfcheckgpt.modeling_mqag import MQAG
mqag_model = MQAG()
```

It has three main functions: `generate()`, `answer()`, `score()`. We show an example usage in [```demo/MQAG_demo1.ipynb```](demo/MQAG_demo1.ipynb)

## Acknowledgements
This work is supported by Cambridge University Press & Assessment (CUP&A), a department of The Chancellor, Masters, and Scholars of the University of Cambridge, and the Cambridge Commonwealth, European & International Trust.

## Citation

```
@misc{manakul2023selfcheckgpt,
title={SelfCheckGPT: Zero-Resource Black-Box Hallucination Detection for Generative Large Language Models},
author={Potsawee Manakul and Adian Liusie and Mark J. F. Gales},
year={2023},
eprint={2303.08896},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
28 changes: 28 additions & 0 deletions lm_eval/tasks/selfcheckgpt/selfcheckgpt.yaml
@@ -0,0 +1,28 @@
task: selfcheckgpt
dataset_path: potsawee/wiki_bio_gpt3_hallucination
output_type: generate_until
training_split: null
validation_split: evaluation
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
generation_kwargs:
  # until:
  #   - "\n"
  temperature: 0.0
  do_sample: false
generation_kwargs_sampling_number: 5
generation_kwargs_sampling:
  # until:
  #   - "\n"
  temperature: 1.0
  do_sample: true
metric_list:
  - metric: avg
    aggregation: mean
    higher_is_better: true
  - metric: max
    aggregation: mean
    higher_is_better: true
metadata:
  version: 2.0
80 changes: 80 additions & 0 deletions lm_eval/tasks/selfcheckgpt/utils.py
@@ -0,0 +1,80 @@
import os
import spacy
import torch
from selfcheckgpt.modeling_selfcheck import SelfCheckMQAG, SelfCheckNLI, SelfCheckBERTScore, SelfCheckNgram

# pip install spacy
# pip install selfcheckgpt
# python -m spacy download en_core_web_sm

selfcheckgpt_type = os.environ.get('SELFCHECKGPTTYPE', 'SelfCheckNgram')
selfcheckgpt_device = os.environ.get('SELFCHECKGPTDEVICE', 'cpu')
selfcheckgpt_nlp = spacy.load("en_core_web_sm")

if selfcheckgpt_type == 'SelfCheckNgram':
    selfcheckgpt = SelfCheckNgram(n=1)
elif selfcheckgpt_type == 'SelfCheckBERTScore':
    selfcheckgpt = SelfCheckBERTScore(rescale_with_baseline=True)
elif selfcheckgpt_type == 'SelfCheckMQAG':
    selfcheckgpt = SelfCheckMQAG(device=selfcheckgpt_device)
elif selfcheckgpt_type == 'SelfCheckNLI':
    selfcheckgpt = SelfCheckNLI(device=selfcheckgpt_device)
else:
    raise ValueError(f"Wrong SELFCHECKGPTTYPE environment variable: {selfcheckgpt_type}")

print("Loaded selfcheckgpt successfully")


def doc_to_text(doc):
    # Prompt the model with the first five words of the reference Wikipedia passage
    doc_text = doc["wiki_bio_text"]
    doc_text = doc_text.split()
    doc_text = " ".join(doc_text[:5])
    doc_text = f"Please generate a Wikipedia passage starting with: {doc_text}\n"
    return doc_text


def doc_to_target(doc):
    answer = doc['wiki_bio_text']
    return answer

def process_results(doc, results, threshold=0.6):
    # results[0] is the deterministic (temperature 0.0) response; results[1:] are the sampled responses
    response_temperature_0 = results[0]
    other_responses = results[1:]
    passage = doc_to_target(doc)

    sentences = selfcheckgpt_nlp(response_temperature_0)
    sentences = [sent.text.strip() for sent in sentences.sents]

    if selfcheckgpt_type == 'SelfCheckNgram':
        selfcheckgpt_scores = selfcheckgpt.predict(
            sentences=sentences,
            passage=passage,
            sampled_passages=other_responses,
        )
        return {
            "avg": selfcheckgpt_scores["doc_level"]["avg_neg_logprob"],
            "max": selfcheckgpt_scores["doc_level"]["avg_max_neg_logprob"],
        }
    elif selfcheckgpt_type == 'SelfCheckBERTScore':
        selfcheckgpt_scores = selfcheckgpt.predict(
            sentences=sentences,
            sampled_passages=other_responses,
        )
    elif selfcheckgpt_type == 'SelfCheckMQAG':
        selfcheckgpt_scores = selfcheckgpt.predict(
            sentences=sentences,
            passage=passage,
            sampled_passages=other_responses,
            num_questions_per_sent=5,           # number of questions to be drawn
            scoring_method='bayes_with_alpha',  # options = 'counting', 'bayes', 'bayes_with_alpha'
            beta1=0.8, beta2=0.8,               # additional params depending on scoring_method
        )
    elif selfcheckgpt_type == 'SelfCheckNLI':
        selfcheckgpt_scores = selfcheckgpt.predict(
            sentences=sentences,
            sampled_passages=other_responses,
        )

    # Sentence-level scores: report their mean and max over the passage
    selfcheckgpt_scores_avg = sum(selfcheckgpt_scores) / len(selfcheckgpt_scores) if len(selfcheckgpt_scores) > 0 else 0
    selfcheckgpt_scores_max = max(selfcheckgpt_scores)

    return {'avg': selfcheckgpt_scores_avg, 'max': selfcheckgpt_scores_max}
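
Putting it together, a sketch of how this task might be run end to end. The environment variables below are read by `utils.py`; the harness entry point, flags, and model name are shown only as an example and may differ across harness versions:

```
# Pick the SelfCheckGPT variant and device (defaults in utils.py: SelfCheckNgram, cpu)
export SELFCHECKGPTTYPE=SelfCheckNLI      # or SelfCheckNgram / SelfCheckBERTScore / SelfCheckMQAG
export SELFCHECKGPTDEVICE=cuda

# Example invocation (check your harness version for the exact CLI)
lm_eval --model hf --model_args pretrained=gpt2 --tasks selfcheckgpt --device cuda:0
```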