Commit fe25ad7 (1 parent: a39e671)
Showing 7 changed files with 211 additions and 8 deletions.
@@ -1,2 +1,18 @@
# haystack-evaluation

Using Haystack to benchmark different RAG architectures over different datasets.

Use evaluation on the selected datasets to optimise some parameters commonly tweaked in RAG pipelines (a sweep sketch follows at the end of this file):

- top_k
- chunk_size
- embedding model

Goal 1 is to give users practical guidance on which techniques to try out on their dataset/use case.

Goal 2 is to show that there is no "silver bullet" solution: what works best depends on the dataset and use case, but Haystack can support them all.

Goal 3 is to showcase the advanced evaluation/experimentation API (the most advanced compared to competitors).

This is not a research paper, so it should not be too "academic" (i.e. not too restricted in terms of the metrics or datasets used, and not meant to be peer-reviewed or submitted to an academic conference).
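A minimal sketch of what such a parameter sweep could look like. This is an illustration, not this repo's fixed API: the loop body is deliberately left as placeholder comments, since the actual indexing and pipeline helpers live in the benchmark scripts.

from itertools import product

# Hypothetical grid over the parameters listed above.
top_k_values = [1, 2, 5]
chunk_sizes = [64, 128, 256]
embedding_models = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/all-mpnet-base-v2",
]

for top_k, chunk_size, model in product(top_k_values, chunk_sizes, embedding_models):
    # Placeholder: index the corpus with this chunk size and embedding model,
    # build the RAG pipeline with this top_k, then run the evaluators and
    # record the scores for comparison across configurations.
    ...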
@@ -0,0 +1,88 @@
from copy import deepcopy
from typing import List, Tuple, Dict

from haystack import component, Document
from more_itertools import windowed


@component
class CustomDocumentSplitter:
    """
    Splits documents on newlines into chunks of `split_length` lines, with
    `split_overlap` lines shared between consecutive chunks. Page numbers are
    tracked via form-feed ("\f") characters in the text.
    """

    def __init__(
        self,
        split_length: int = 200,
        split_overlap: int = 0,
    ):
        self.split_by = "\n"
        if split_length <= 0:
            raise ValueError("split_length must be greater than 0.")
        self.split_length = split_length
        if split_overlap < 0:
            raise ValueError("split_overlap must be greater than or equal to 0.")
        self.split_overlap = split_overlap

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError("DocumentSplitter expects a List of Documents as input.")

        split_docs = []
        for doc in documents:
            if doc.content is None:
                raise ValueError(
                    f"DocumentSplitter only works with text documents but document.content for document ID {doc.id} is None."
                )
            units = self._split_into_units(doc.content)
            text_splits, splits_pages = self._concatenate_units(units, self.split_length, self.split_overlap)
            metadata = deepcopy(doc.meta)
            metadata["source_id"] = doc.id
            split_docs += self._create_docs_from_splits(
                text_splits=text_splits, splits_pages=splits_pages, meta=metadata
            )
        return {"documents": split_docs}

    def _split_into_units(self, text: str) -> List[str]:
        split_at = "\n"
        units = text.split(split_at)
        # Add the delimiter back to all units except the last one
        for i in range(len(units) - 1):
            units[i] += split_at
        return units

    def _concatenate_units(
        self, elements: List[str], split_length: int, split_overlap: int
    ) -> Tuple[List[str], List[int]]:
        text_splits = []
        splits_pages = []
        cur_page = 1
        segments = windowed(elements, n=split_length, step=split_length - split_overlap)
        for seg in segments:
            current_units = [unit for unit in seg if unit is not None]
            txt = "".join(current_units)
            if len(txt) > 0:
                text_splits.append(txt)
                splits_pages.append(cur_page)
                processed_units = current_units[: split_length - split_overlap]
                if self.split_by == "page":
                    num_page_breaks = len(processed_units)
                else:
                    num_page_breaks = sum(processed_unit.count("\f") for processed_unit in processed_units)
                cur_page += num_page_breaks
        return text_splits, splits_pages

    @staticmethod
    def _create_docs_from_splits(text_splits: List[str], splits_pages: List[int], meta: Dict) -> List[Document]:
        """
        Creates Document objects from text splits, enriching them with the page number and the metadata of the original document.
        """
        documents: List[Document] = []

        for i, txt in enumerate(text_splits):
            meta = deepcopy(meta)
            doc = Document(content=txt, meta=meta)
            doc.meta["page_number"] = splits_pages[i]
            documents.append(doc)
        return documents
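A quick usage sketch for the splitter above, assuming it is importable from this module; the example document and parameter values are invented for illustration:

from haystack import Document

splitter = CustomDocumentSplitter(split_length=3, split_overlap=1)
docs = [Document(content="line 1\nline 2\nline 3\nline 4\nline 5")]
out = splitter.run(documents=docs)
for d in out["documents"]:
    # Each split carries the page number and the source document's ID.
    print(repr(d.content), d.meta["page_number"], d.meta["source_id"])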
Empty file.
@@ -0,0 +1,95 @@
import json
import os
import random

from haystack import Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.converters import PyPDFToDocument
from haystack.components.evaluators import ContextRelevanceEvaluator, FaithfulnessEvaluator, SASEvaluator
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.writers import DocumentWriter
from haystack.document_stores.types import DuplicatePolicy
from haystack.evaluation import EvaluationRunResult
from tqdm import tqdm

from architectures.basic_rag import basic_rag

embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
files_path = "datasets/MiniESGBench/"

def indexing():
    """Convert, clean, split, embed, and index the MiniESGBench source PDFs."""
    document_store = InMemoryDocumentStore()
    pipeline = Pipeline()
    pipeline.add_component("converter", PyPDFToDocument())
    pipeline.add_component("cleaner", DocumentCleaner())
    pipeline.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=128))
    pipeline.add_component("writer", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP))
    pipeline.add_component("embedder", SentenceTransformersDocumentEmbedder(model=embedding_model))
    pipeline.connect("converter", "cleaner")
    pipeline.connect("cleaner", "splitter")
    pipeline.connect("splitter", "embedder")
    pipeline.connect("embedder", "writer")
    pdf_files = [files_path + "source_files/" + f_name for f_name in os.listdir(files_path + "source_files/")]
    pipeline.run({"converter": {"sources": pdf_files}})

    return document_store

def read_question_answers():
    with open(files_path + "/rag_dataset.json", "r") as f:
        data = json.load(f)
    questions = []
    contexts = []
    answers = []
    for entry in data["examples"]:
        questions.append(entry["query"])
        contexts.append(entry["reference_contexts"])
        answers.append(entry["reference_answer"])

    return questions, contexts, answers

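The fields read above imply a rag_dataset.json layout roughly like the following; the values here are invented placeholders, only the keys come from the parsing code:

# Placeholder illustrating the expected structure of rag_dataset.json.
example_rag_dataset = {
    "examples": [
        {
            "query": "Example question about the report?",
            "reference_contexts": ["An excerpt the answer should be grounded in."],
            "reference_answer": "Example reference answer.",
        },
    ]
}
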
def run_basic_rag(doc_store, questions_sample, answers_sample, contexts_sample):
    """
    Runs the basic RAG pipeline over the sampled questions and evaluates the answers.
    Note: contexts_sample is accepted for symmetry with the dataset but is not used yet.
    """
    rag = basic_rag(document_store=doc_store, embedding_model=embedding_model, top_k=2)

    predicted_answers = []
    retrieved_contexts = []
    for q in tqdm(questions_sample):
        response = rag.run(
            data={"query_embedder": {"text": q}, "prompt_builder": {"question": q}, "answer_builder": {"query": q}})
        predicted_answers.append(response["answer_builder"]["answers"][0].data)
        retrieved_contexts.append([d.content for d in response["answer_builder"]["answers"][0].documents])

    # ContextRelevanceEvaluator and FaithfulnessEvaluator are LLM-based and use
    # OpenAI by default, so OPENAI_API_KEY must be set in the environment.
    context_relevance = ContextRelevanceEvaluator()
    faithfulness = FaithfulnessEvaluator()
    sas = SASEvaluator(model=embedding_model)
    sas.warm_up()
    results = {
        "context_relevance": context_relevance.run(questions_sample, retrieved_contexts),
        "faithfulness": faithfulness.run(questions_sample, retrieved_contexts, predicted_answers),
        # SASEvaluator.run expects the ground-truth answers first.
        "sas": sas.run(ground_truth_answers=answers_sample, predicted_answers=predicted_answers),
    }
    inputs = {"questions": questions_sample, "true_answers": answers_sample, "predicted_answers": predicted_answers}

    return EvaluationRunResult(run_name="basic_rag", inputs=inputs, results=results)

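For inspecting the EvaluationRunResult returned above, something like the following should work; score_report() and to_pandas() are the aggregate and per-question views in current Haystack releases, though the exact method set may vary by version:

result = run_basic_rag(doc_store, questions_sample, answers_sample, contexts_sample)
print(result.score_report())  # aggregate score per metric
df = result.to_pandas()       # per-question scores as a pandas DataFrame
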
def main():
    doc_store = indexing()
    questions, contexts, answers = read_question_answers()

    # Sample by index so that questions, contexts, and answers stay aligned;
    # sampling each list independently would pair unrelated items.
    limit = 5
    sampled_ids = random.sample(range(len(questions)), limit)
    questions_sample = [questions[i] for i in sampled_ids]
    contexts_sample = [contexts[i] for i in sampled_ids]
    answers_sample = [answers[i] for i in sampled_ids]

    basic_rag_results = run_basic_rag(doc_store, questions_sample, answers_sample, contexts_sample)
    print(basic_rag_results.score_report())  # aggregate score per metric


if __name__ == "__main__":
    main()
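
The architectures.basic_rag.basic_rag helper imported by this script is not part of this diff. A plausible reconstruction, inferred only from how the script calls it (the component names query_embedder/prompt_builder/answer_builder and the document_store/embedding_model/top_k parameters), might look like this:

from haystack import Pipeline
from haystack.components.builders import AnswerBuilder, PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever


def basic_rag(document_store, embedding_model, top_k=2):
    # Hypothetical sketch; the real helper may differ in prompt and generator.
    template = """Answer the question using the given context.
Context:
{% for document in documents %}{{ document.content }}{% endfor %}
Question: {{ question }}
Answer:"""

    pipeline = Pipeline()
    pipeline.add_component("query_embedder", SentenceTransformersTextEmbedder(model=embedding_model))
    pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=top_k))
    pipeline.add_component("prompt_builder", PromptBuilder(template=template))
    pipeline.add_component("llm", OpenAIGenerator())  # needs OPENAI_API_KEY
    pipeline.add_component("answer_builder", AnswerBuilder())

    pipeline.connect("query_embedder.embedding", "retriever.query_embedding")
    pipeline.connect("retriever", "prompt_builder.documents")
    pipeline.connect("prompt_builder", "llm")
    pipeline.connect("llm.replies", "answer_builder.replies")
    # Also feed the retrieved documents to the AnswerBuilder, since the script
    # reads answers[0].documents to collect the retrieved contexts.
    pipeline.connect("retriever", "answer_builder.documents")
    return pipeline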