
fixing evaluation for SQuAD dataset
davidsbatista committed May 10, 2024
1 parent 75b94f6 commit a39e671
Showing 6 changed files with 216 additions and 130 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -158,3 +158,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# MacOS
.DS_Store
*/.DS_Store
7 changes: 4 additions & 3 deletions datasets/datasets.md
@@ -2,24 +2,25 @@

ToDo:
- at least one should be financial or legal, and the raw data needs to be in structured PDFs
- at least one should be about a support/help centre
- there should be one that has been used in other benchmarks (maybe based on Wikipedia)
- they should all have a set of labels so that we can get performance metrics from them


## SQuAD

- domain: wikipedia
- labels: answer, documents
- data type: text files
- source: https://huggingface.co/datasets/squad
- paper: [SQuAD: 100,000+ Questions for Machine Comprehension of Text](https://arxiv.org/abs/1606.05250)
- website: [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)
- evaluation: [ContextRelevance](), [Faithfulness](), [Semantic Answer Similarity](), [DocumentMRR](), [DocumentMAP](), [DocumentRecall]() (see the sketch below)
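
A minimal sketch of how one of these evaluators can be run on its own (assumes `haystack-ai` 2.x is installed; the answer strings below are made-up examples, not taken from the dataset):

```python
from haystack.components.evaluators import SASEvaluator

# Semantic Answer Similarity: compares predicted answers against ground-truth answers
# with a sentence-transformers model (here the same model used by the evaluation scripts).
sas = SASEvaluator(model="sentence-transformers/all-MiniLM-L6-v2")
sas.warm_up()  # loads the embedding model

result = sas.run(
    ground_truth_answers=["Saint Bernadette Soubirous"],
    predicted_answers=["The figure said to have appeared was Saint Bernadette Soubirous."],
)
print(result["score"])              # average similarity over all answer pairs
print(result["individual_scores"])  # one score per answer pair
```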


## ARAGOG
- domain: a collection of AI/LLM-ArXiv papers
- labels: answer
- data type: PDF files
- source: https://github.com/predlico/ARAGOG
- paper: [ARAGOG: Advanced RAG Output Grading](https://arxiv.org/pdf/2404.01037)
- evaluation: [ContextRelevance](), [Faithfulness](), [Semantic Answer Similarity]()
8 changes: 4 additions & 4 deletions arago_evaluation.py → evaluations/arago_evaluation.py
@@ -17,7 +17,7 @@
from architectures.hyde_rag import rag_with_hyde

embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
files_path = "datasets/ARAGOG/papers_for_questions"
files_path = "../datasets/ARAGOG/papers_for_questions"


def indexing():
@@ -39,7 +39,7 @@ def indexing():


def read_question_answers():
with open("datasets/ARAGOG/eval_questions.json", "r") as f:
with open("../datasets/ARAGOG/eval_questions.json", "r") as f:
data = json.load(f)
questions = data["questions"]
answers = data["ground_truths"]
@@ -71,7 +71,7 @@ def run_basic_rag(doc_store, sample_questions, sample_answers):
"sas": sas.run(predicted_answers, sample_answers),
'predicted_answers': predicted_answers,
}
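    # true and predicted answers are (presumably) added to `inputs` so that they appear
    # alongside the scores in the EvaluationRunResult score reports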
    inputs = {'questions': sample_questions}
    inputs = {'questions': sample_questions, "true_answers": sample_answers, "predicted_answers": predicted_answers}

    return EvaluationRunResult(run_name="basic_rag", inputs=inputs, results=results)

@@ -98,7 +98,7 @@ def run_hyde_rag(doc_store, sample_questions, sample_answers):
"faithfulness": faithfulness.run(sample_questions, retrieved_contexts, predicted_answers),
"sas": sas.run(predicted_answers, sample_answers)
}
inputs = {'questions': sample_questions}
inputs = {'questions': sample_questions, "true_answers": sample_answers, "predicted_answers": predicted_answers}

return EvaluationRunResult(run_name="hyde_rag", inputs=inputs, results=results)

Empty file added evaluations/evaluation.md
Empty file.
204 changes: 204 additions & 0 deletions evaluations/squad_evaluation.py
@@ -0,0 +1,204 @@
import json
import os
import random
from typing import List

from haystack import Pipeline, Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.evaluators import (
    DocumentMRREvaluator,
    DocumentMAPEvaluator,
    DocumentRecallEvaluator,
    FaithfulnessEvaluator,
    SASEvaluator
)
from haystack.components.evaluators.document_recall import RecallMode
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.evaluation import EvaluationRunResult
from tqdm import tqdm

from architectures.basic_rag import basic_rag
from architectures.hyde_rag import rag_with_hyde

embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
base_path = "../datasets/SQuAD-2.0/transformed_squad/"


def load_transformed_squad():
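    """Loads questions.jsonl (one JSON object per question, with its answers and the name of
    its source article) and turns every line of the article text files under base_path into a Document."""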
    with open(base_path+"questions.jsonl", "r") as f:
        questions = [json.loads(x) for x in f.readlines()]
    for idx, question in enumerate(questions):
        question["query_id"] = f"query_{idx}"

    def create_document(text: str, name: str):
        return Document(content=text, meta={"name": name})

    # walk through the files in the directory and transform each line of each text file into a Document
    documents = []
    for root, dirs, files in os.walk(base_path):
        for article in files:
            with open(f"{root}/{article}", "r") as f:
                raw_texts = f.read().split("\n")
                for text in raw_texts:
                    documents.append(create_document(text, article.replace(".txt", "")))

    return questions, documents


def indexing(documents: List[Document]):
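    """Embeds the Documents with the sentence-transformers model defined above and writes
    them into an InMemoryDocumentStore, skipping duplicates."""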
    document_store = InMemoryDocumentStore()
    doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
    doc_embedder = SentenceTransformersDocumentEmbedder(model=embedding_model)
    ingestion_pipe = Pipeline()
    ingestion_pipe.add_component(instance=doc_embedder, name="doc_embedder")
    ingestion_pipe.add_component(instance=doc_writer, name="doc_writer")
    ingestion_pipe.connect("doc_embedder.documents", "doc_writer.documents")
    ingestion_pipe.run({"doc_embedder": {"documents": documents}})

    return document_store


def run_basic_rag(doc_store, samples):
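    """Runs the basic RAG pipeline over the sampled questions, collects ground-truth and
    predicted documents/answers, and scores them with document MRR, MAP, recall,
    faithfulness and semantic answer similarity."""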

    rag = basic_rag(document_store=doc_store, embedding_model=embedding_model, top_k=3)

    # ground truth data
    questions = []
    ground_truth_docs = []
    ground_truth_answers = []

    # predicted data
    retrieved_docs = []
    predicted_contexts = []
    predicted_answers = []

    for sample in tqdm(samples):
        q = sample["question"]
        answer = sample["answers"]["text"]
        ground_truth_documents = [doc for doc in doc_store.storage.values() if doc.meta["name"] == sample["document"]]
        response = rag.run(
            data={"query_embedder": {"text": q}, "prompt_builder": {"question": q}, "answer_builder": {"query": q}}
        )

        # gather ground truth data
        ground_truth_docs.append(ground_truth_documents)
        ground_truth_answers.append(answer[0])
        questions.append(q)

        # gather response data
        retrieved_docs.append(response["answer_builder"]["answers"][0].documents)
        predicted_contexts.append([doc.content for doc in response["answer_builder"]["answers"][0].documents])
        predicted_answers.append(response["answer_builder"]["answers"][0].data)

    eval_pipeline = Pipeline()
    eval_pipeline.add_component("doc_mrr", DocumentMRREvaluator())
    eval_pipeline.add_component("doc_map", DocumentMAPEvaluator())
    eval_pipeline.add_component("doc_recall_single_hit", DocumentRecallEvaluator(mode=RecallMode.SINGLE_HIT))
    eval_pipeline.add_component("doc_recall_multi_hit", DocumentRecallEvaluator(mode=RecallMode.MULTI_HIT))
    eval_pipeline.add_component("faithfulness", FaithfulnessEvaluator())
    eval_pipeline.add_component("sas", SASEvaluator(model=embedding_model))

    eval_pipeline_results = eval_pipeline.run(
        {
            "doc_mrr": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs},
            "faithfulness": {"questions": questions, "contexts": predicted_contexts, "predicted_answers": predicted_answers},
            "sas": {"predicted_answers": predicted_answers, "ground_truth_answers": ground_truth_answers},
            "doc_map": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs},
            "doc_recall_single_hit": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs},
            "doc_recall_multi_hit": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs}
        }
    )

    results = {
        "doc_mrr": eval_pipeline_results['doc_mrr'],
        "faithfulness": eval_pipeline_results['faithfulness'],
        "sas": eval_pipeline_results['sas'],
        "doc_map": eval_pipeline_results['doc_map'],
        "doc_recall_single_hit": eval_pipeline_results['doc_recall_single_hit'],
        "doc_recall_multi_hit": eval_pipeline_results['doc_recall_multi_hit']
    }

    inputs = {'questions': questions, 'true_answers': ground_truth_answers, 'predicted_answers': predicted_answers}

    return EvaluationRunResult(run_name="basic_rag", inputs=inputs, results=results)


def run_hyde_rag(doc_store, samples):
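    """Same evaluation loop as run_basic_rag, but the answers come from the HyDE RAG
    pipeline built by rag_with_hyde."""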

    hyde_rag = rag_with_hyde(document_store=doc_store, embedding_model=embedding_model, top_k=3)

    # ground truth data
    questions = []
    ground_truth_docs = []
    ground_truth_answers = []

    # predicted data
    retrieved_docs = []
    predicted_contexts = []
    predicted_answers = []

    for sample in tqdm(samples):
        q = sample["question"]
        answer = sample["answers"]["text"]
        ground_truth_documents = [doc for doc in doc_store.storage.values() if doc.meta["name"] == sample["document"]]
        response = hyde_rag.run(
            data={"hyde": {"query": q}, "prompt_builder": {"question": q}, "answer_builder": {"query": q}}
        )

        # gather ground truth data
        ground_truth_docs.append(ground_truth_documents)
        ground_truth_answers.append(answer[0])
        questions.append(q)

        # gather response data
        retrieved_docs.append(response["answer_builder"]["answers"][0].documents)
        predicted_contexts.append([doc.content for doc in response["answer_builder"]["answers"][0].documents])
        predicted_answers.append(response["answer_builder"]["answers"][0].data)

    eval_pipeline = Pipeline()
    eval_pipeline.add_component("doc_mrr", DocumentMRREvaluator())
    eval_pipeline.add_component("doc_map", DocumentMAPEvaluator())
    eval_pipeline.add_component("doc_recall_single_hit", DocumentRecallEvaluator(mode=RecallMode.SINGLE_HIT))
    eval_pipeline.add_component("doc_recall_multi_hit", DocumentRecallEvaluator(mode=RecallMode.MULTI_HIT))
    eval_pipeline.add_component("faithfulness", FaithfulnessEvaluator())
    eval_pipeline.add_component("sas", SASEvaluator(model=embedding_model))

    eval_pipeline_results = eval_pipeline.run(
        {
            "doc_mrr": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs},
            "faithfulness": {"questions": questions, "contexts": predicted_contexts, "predicted_answers": predicted_answers},
            "sas": {"predicted_answers": predicted_answers, "ground_truth_answers": ground_truth_answers},
            "doc_map": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs},
            "doc_recall_single_hit": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs},
            "doc_recall_multi_hit": {"ground_truth_documents": ground_truth_docs, "retrieved_documents": retrieved_docs}
        }
    )

    results = {
        "doc_mrr": eval_pipeline_results['doc_mrr'],
        "faithfulness": eval_pipeline_results['faithfulness'],
        "sas": eval_pipeline_results['sas'],
        "doc_map": eval_pipeline_results['doc_map'],
        "doc_recall_single_hit": eval_pipeline_results['doc_recall_single_hit'],
        "doc_recall_multi_hit": eval_pipeline_results['doc_recall_multi_hit']
    }

    inputs = {'questions': questions, 'true_answers': ground_truth_answers, 'predicted_answers': predicted_answers}

    return EvaluationRunResult(run_name="hyde_rag", inputs=inputs, results=results)


def main():
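    """Indexes the corpus, draws a small random sample of questions and compares the
    basic and HyDE RAG pipelines on the same sample."""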

    all_questions, documents = load_transformed_squad()
    doc_store = indexing(documents)

    limit = 10
    samples = random.sample(all_questions, limit)

    basic_rag_results = run_basic_rag(doc_store, samples)
    hyde_rag_results = run_hyde_rag(doc_store, samples)

    comparative_df = basic_rag_results.comparative_individual_scores_report(hyde_rag_results)
