fix: improve context relevancy metric #7964

Merged
22 commits merged on Jul 22, 2024
Changes from 18 commits

Commits (22)
68eab35
fixing tests
davidsbatista Jul 2, 2024
5c97fac
fixing tests
davidsbatista Jul 2, 2024
f0d7484
Merge branch 'main' into improve-context-relevancy-metric
davidsbatista Jul 2, 2024
5ea4bda
updating tests
davidsbatista Jul 2, 2024
3204739
updating tests
davidsbatista Jul 2, 2024
40ea113
updating docstring
davidsbatista Jul 2, 2024
a7b7b32
adding release notes
davidsbatista Jul 2, 2024
f238145
making the insufficient information more robust
davidsbatista Jul 4, 2024
61ad27f
updating docstring and release notes
davidsbatista Jul 5, 2024
93b309f
empty list instead of informative string
davidsbatista Jul 5, 2024
4facb67
Merge branch 'main' into improve-context-relevancy-metric
davidsbatista Jul 11, 2024
ed49292
Merge branch 'main' into improve-context-relevancy-metric
davidsbatista Jul 12, 2024
bf2181d
Merge branch 'main' into improve-context-relevancy-metric
davidsbatista Jul 19, 2024
af68087
Update haystack/components/evaluators/context_relevance.py
davidsbatista Jul 22, 2024
f15eb0f
Update haystack/components/evaluators/context_relevance.py
davidsbatista Jul 22, 2024
9c23ceb
fixing tests
davidsbatista Jul 22, 2024
e366b27
Update haystack/components/evaluators/context_relevance.py
davidsbatista Jul 22, 2024
64123b1
reverting commit
davidsbatista Jul 22, 2024
ef1ef4b
reverting again commit
davidsbatista Jul 22, 2024
c358baf
fixing docstrings
davidsbatista Jul 22, 2024
9c21997
removing deprecation warning
davidsbatista Jul 22, 2024
308967c
removing warning import
davidsbatista Jul 22, 2024
94 changes: 50 additions & 44 deletions haystack/components/evaluators/context_relevance.py
@@ -3,10 +3,9 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from statistics import mean
from typing import Any, Dict, List, Optional

from numpy import mean as np_mean

from haystack import component, default_from_dict, default_to_dict
from haystack.components.evaluators.llm_evaluator import LLMEvaluator
from haystack.utils import Secret, deserialize_secrets_inplace
@@ -16,12 +15,9 @@
{
"inputs": {
"questions": "What is the capital of Germany?",
"contexts": ["Berlin is the capital of Germany and was founded in 1244."],
},
"outputs": {
"statements": ["Berlin is the capital of Germany.", "Berlin was founded in 1244."],
"statement_scores": [1, 0],
"contexts": ["Berlin is the capital of Germany. Berlin and was founded in 1244."],
},
"outputs": {"relevant_statements": ["Berlin is the capital of Germany."]},
},
{
"inputs": {
@@ -32,19 +28,11 @@
"Madrid is the capital of Spain.",
],
},
"outputs": {
"statements": [
"Berlin is the capital of Germany.",
"Berlin was founded in 1244.",
"Europe is a continent with 44 countries.",
"Madrid is the capital of Spain.",
],
"statement_scores": [0, 0, 0, 0],
},
"outputs": {"relevant_statements": []},
},
{
"inputs": {"questions": "What is the capital of Italy?", "contexts": ["Rome is the capital of Italy."]},
"outputs": {"statements": ["Rome is the capital of Italy."], "statement_scores": [1]},
"outputs": {"relevant_statements": ["Rome is the capital of Italy."]},
},
]

@@ -54,36 +42,57 @@ class ContextRelevanceEvaluator(LLMEvaluator):
"""
Evaluator that checks if a provided context is relevant to the question.

An LLM breaks up the context into multiple statements and checks whether each statement
An LLM breaks up a context into multiple statements and checks whether each statement
is relevant for answering a question.
The final score for the context relevance is a number from 0.0 to 1.0. It represents the proportion of
statements that can be inferred from the provided contexts.
The score for each context is binary, either 1 or 0, where 1 indicates that the context is relevant
to the question and 0 indicates that it is not relevant.
The evaluator also provides the relevant statements from each context and an average score over all the provided
question-context pairs.

Usage example:
```python
from haystack.components.evaluators import ContextRelevanceEvaluator

questions = ["Who created the Python language?"]
questions = ["Who created the Python language?", "Why does Java needs a JVM?", "Is C++ better than Python?"]
contexts = [
[(
"Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
"language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
"programmers write clear, logical code for both small and large-scale software projects."
)],
[(
"Java is a high-level, class-based, object-oriented programming language that is designed to have as few "
"implementation dependencies as possible. The JVM has two primary functions: to allow Java programs to run
on any device or operating system (known as the "write once, run anywhere" principle), and to manage and
optimize program memory.
)],
[(
"C++ is a general-purpose programming language created by Bjarne Stroustrup as an extension of the C "
"programming language."
)],
]

evaluator = ContextRelevanceEvaluator()
result = evaluator.run(questions=questions, contexts=contexts)
print(result["score"])
# 1.0
# 0.67
print(result["individual_scores"])
# [1.0]
# [1,1,0]
print(result["results"])
# [{
# 'statements': ['Python, created by Guido van Rossum in the late 1980s.'],
# 'statement_scores': [1],
# 'relevant_statements': ['Python, created by Guido van Rossum in the late 1980s.'],
# 'score': 1.0
# },
# {
# 'relevant_statements': ['The JVM has two primary functions: to allow Java programs to run on any device or
# operating system (known as the "write once, run anywhere" principle), and to manage and
# optimize program memory'],
# 'score': 1.0
# }]
# },
# {
# 'relevant_statements': [],
# 'score': 0.0
# }]
```
"""

@@ -111,8 +120,8 @@ def __init__(
"questions": "What is the capital of Italy?", "contexts": ["Rome is the capital of Italy."],
},
"outputs": {
"statements": ["Rome is the capital of Italy."],
"statement_scores": [1],
"relevant_statements": ["Rome is the capital of Italy."],
"score": 1,
},
}]
:param progress_bar:
@@ -128,15 +137,14 @@
Whether to raise an exception if the API call fails.

"""

self.instructions = (
"Your task is to judge how relevant the provided context is for answering a question. "
"First, please extract statements from the provided context. "
"Second, calculate a relevance score for each statement in the context. "
"The score is 1 if the statement is relevant to answer the question or 0 if it is not relevant. "
"Each statement should be scored individually."
"Please extract only sentences from the provided context which are absolutely relevant and "
"required to answer the following question. If no relevant sentences are found, or if you "
"believe the question cannot be answered from the given context, return an empty list, example: []"
)
self.inputs = [("questions", List[str]), ("contexts", List[List[str]])]
self.outputs = ["statements", "statement_scores"]
self.outputs = ["relevant_statements"]
self.examples = examples or _DEFAULT_EXAMPLES
self.api = api
self.api_key = api_key
@@ -161,7 +169,7 @@ def __init__(
progress_bar=progress_bar,
)

@component.output_types(individual_scores=List[int], score=float, results=List[Dict[str, Any]])
@component.output_types(score=float, results=List[Dict[str, Any]])
def run(self, questions: List[str], contexts: List[List[str]]) -> Dict[str, Any]:
"""
Run the LLM evaluator.
@@ -173,24 +181,22 @@ def run(self, questions: List[str], contexts: List[List[str]]) -> Dict[str, Any]
:returns:
A dictionary with the following outputs:
- `score`: Mean context relevance score over all the provided input questions.
- `individual_scores`: A list of context relevance scores for each input question.
- `results`: A list of dictionaries with `statements` and `statement_scores` for each input context.
- `results`: A list of dictionaries with `relevant_statements` and `score` for each input context.
"""
result = super(ContextRelevanceEvaluator, self).run(questions=questions, contexts=contexts)

# calculate average statement relevance score per query
for idx, res in enumerate(result["results"]):
if res is None:
result["results"][idx] = {"statements": [], "statement_scores": [], "score": float("nan")}
result["results"][idx] = {"relevant_statements": [], "score": float("nan")}
continue
if not res["statements"]:
res["score"] = 0
if len(res["relevant_statements"]) > 0:
res["score"] = 1
else:
res["score"] = np_mean(res["statement_scores"])
res["score"] = 0

# calculate average context relevance score over all queries
result["score"] = np_mean([res["score"] for res in result["results"]])
result["individual_scores"] = [res["score"] for res in result["results"]]
result["score"] = mean([res["score"] for res in result["results"]])
result["individual_scores"] = [res["score"] for res in result["results"]] # useful for the EvaluationRunResult

return result

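Taken together, the hunks above reduce the metric to a simple rule: a context scores 1 if the LLM extracts at least one relevant statement, 0 otherwise, and the overall score is the mean over contexts. The following standalone sketch (not part of the diff) mirrors that aggregation; it omits the NaN handling used for failed LLM calls:

```python
from statistics import mean


def aggregate_context_relevance(results):
    """Mirror of the new scoring rule: 1 per context with any relevant statement, else 0."""
    for res in results:
        res["score"] = 1 if res["relevant_statements"] else 0
    individual_scores = [res["score"] for res in results]
    return {"score": mean(individual_scores), "individual_scores": individual_scores, "results": results}


out = aggregate_context_relevance(
    [{"relevant_statements": ["Rome is the capital of Italy."]}, {"relevant_statements": []}]
)
print(out["score"], out["individual_scores"])  # 0.5 [1, 0]
```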
1 change: 1 addition & 0 deletions haystack/evaluation/eval_run_result.py
@@ -131,4 +131,5 @@ def comparative_individual_scores_report(
pipe_a_df.columns = [f"{this_name}_{col}" if col not in ignore else col for col in pipe_a_df.columns] # type: ignore

results_df = pd_concat([pipe_a_df, pipe_b_df], axis=1)

return results_df
6 additions & 0 deletions (new release note file)
@@ -0,0 +1,6 @@
---
upgrade:
  - |
    The `ContextRelevanceEvaluator` now returns a list of relevant sentences for each context, instead of all the sentences in a context.
    Also, a score of 1 is now returned if a relevant sentence is found, and 0 otherwise.
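To make the upgrade note concrete, here is a minimal usage sketch of the new output shape (illustrative values; it assumes an OpenAI key is available via `OPENAI_API_KEY`, and the exact sentences returned depend on the LLM):

```python
from haystack.components.evaluators import ContextRelevanceEvaluator

questions = ["What is the capital of Italy?"]
contexts = [["Rome is the capital of Italy."]]

evaluator = ContextRelevanceEvaluator()  # reads the OpenAI key from OPENAI_API_KEY by default
result = evaluator.run(questions=questions, contexts=contexts)

# Each result now carries the relevant sentences plus a binary per-context score;
# "score" is the mean over all question-context pairs.
print(result["results"])  # e.g. [{'relevant_statements': ['Rome is the capital of Italy.'], 'score': 1}]
print(result["score"])    # e.g. 1.0
```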
104 changes: 27 additions & 77 deletions test/components/evaluators/test_context_relevance_evaluator.py
@@ -20,24 +20,19 @@ def test_init_default(self, monkeypatch):
assert component.api == "openai"
assert component.generator.client.api_key == "test-api-key"
assert component.instructions == (
"Your task is to judge how relevant the provided context is for answering a question. "
"First, please extract statements from the provided context. "
"Second, calculate a relevance score for each statement in the context. "
"The score is 1 if the statement is relevant to answer the question or 0 if it is not relevant. "
"Each statement should be scored individually."
"Please extract only sentences from the provided context which are absolutely relevant and "
"required to answer the following question. If no relevant sentences are found, or if you "
"believe the question cannot be answered from the given context, return an empty list, example: []"
)
assert component.inputs == [("questions", List[str]), ("contexts", List[List[str]])]
assert component.outputs == ["statements", "statement_scores"]
assert component.outputs == ["relevant_statements"]
assert component.examples == [
{
"inputs": {
"questions": "What is the capital of Germany?",
"contexts": ["Berlin is the capital of Germany and was founded in 1244."],
},
"outputs": {
"statements": ["Berlin is the capital of Germany.", "Berlin was founded in 1244."],
"statement_scores": [1, 0],
"contexts": ["Berlin is the capital of Germany. Berlin and was founded in 1244."],
},
"outputs": {"relevant_statements": ["Berlin is the capital of Germany."]},
},
{
"inputs": {
@@ -48,19 +43,11 @@ def test_init_default(self, monkeypatch):
"Madrid is the capital of Spain.",
],
},
"outputs": {
"statements": [
"Berlin is the capital of Germany.",
"Berlin was founded in 1244.",
"Europe is a continent with 44 countries.",
"Madrid is the capital of Spain.",
],
"statement_scores": [0, 0, 0, 0],
},
"outputs": {"relevant_statements": []},
},
{
"inputs": {"questions": "What is the capital of Italy?", "contexts": ["Rome is the capital of Italy."]},
"outputs": {"statements": ["Rome is the capital of Italy."], "statement_scores": [1]},
"outputs": {"relevant_statements": ["Rome is the capital of Italy."]},
},
]

@@ -133,9 +120,9 @@ def test_run_calculates_mean_score(self, monkeypatch):

def generator_run(self, *args, **kwargs):
if "Football" in kwargs["prompt"]:
return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']}
return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']}
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}
return {"replies": ['{"relevant_statements": [], "score": 0}']}

monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

Expand All @@ -148,20 +135,19 @@ def generator_run(self, *args, **kwargs):
"Messi, drawing a followership of more than 4 billion people."
],
[
"Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
"language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
"Python is design philosophy emphasizes code readability, and its language constructs aim to help "
"programmers write clear, logical code for both small and large-scale software projects."
],
]
results = component.run(questions=questions, contexts=contexts)

print(results)

assert results == {
"individual_scores": [0.5, 1],
"results": [
{"score": 0.5, "statement_scores": [1, 0], "statements": ["a", "b"]},
{"score": 1, "statement_scores": [1, 1], "statements": ["c", "d"]},
],
"score": 0.75,
"results": [{"score": 1, "relevant_statements": ["a", "b"]}, {"score": 0, "relevant_statements": []}],
"score": 0.5,
"meta": None,
"individual_scores": [1, 0],
}

def test_run_no_statements_extracted(self, monkeypatch):
Expand All @@ -170,9 +156,9 @@ def test_run_no_statements_extracted(self, monkeypatch):

def generator_run(self, *args, **kwargs):
if "Football" in kwargs["prompt"]:
return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']}
return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']}
else:
return {"replies": ['{"statements": [], "statement_scores": []}']}
return {"replies": ['{"relevant_statements": [], "score": 0}']}

monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

@@ -188,13 +174,10 @@ def generator_run(self, *args, **kwargs):
]
results = component.run(questions=questions, contexts=contexts)
assert results == {
"individual_scores": [0.5, 0],
"results": [
{"score": 0.5, "statement_scores": [1, 0], "statements": ["a", "b"]},
{"score": 0, "statement_scores": [], "statements": []},
],
"score": 0.25,
"results": [{"score": 1, "relevant_statements": ["a", "b"]}, {"score": 0, "relevant_statements": []}],
"score": 0.5,
"meta": None,
"individual_scores": [1, 0],
}

def test_run_missing_parameters(self, monkeypatch):
Expand All @@ -211,7 +194,7 @@ def generator_run(self, *args, **kwargs):
if "Python" in kwargs["prompt"]:
raise Exception("OpenAI API request failed.")
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}
return {"replies": ['{"relevant_statements": ["c", "d"], "score": 1}']}

monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

Expand All @@ -232,14 +215,8 @@ def generator_run(self, *args, **kwargs):
results = component.run(questions=questions, contexts=contexts)

assert math.isnan(results["score"])

assert results["individual_scores"][0] == 1.0
assert math.isnan(results["individual_scores"][1])

assert results["results"][0] == {"statements": ["c", "d"], "statement_scores": [1, 1], "score": 1.0}

assert results["results"][1]["statements"] == []
assert results["results"][1]["statement_scores"] == []
assert results["results"][0] == {"relevant_statements": ["c", "d"], "score": 1}
assert results["results"][1]["relevant_statements"] == []
assert math.isnan(results["results"][1]["score"])

@pytest.mark.skipif(
Expand All @@ -254,39 +231,12 @@ def test_live_run(self):
evaluator = ContextRelevanceEvaluator()
result = evaluator.run(questions=questions, contexts=contexts)

required_fields = {"individual_scores", "results", "score"}
required_fields = {"results"}
assert all(field in result for field in required_fields)
nested_required_fields = {"score", "statement_scores", "statements"}
nested_required_fields = {"score", "relevant_statements"}
assert all(field in result["results"][0] for field in nested_required_fields)

assert "meta" in result
assert "prompt_tokens" in result["meta"][0]["usage"]
assert "completion_tokens" in result["meta"][0]["usage"]
assert "total_tokens" in result["meta"][0]["usage"]

@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_all_statements_are_scored(self):
from haystack.components.evaluators import ContextRelevanceEvaluator

questions = ["Who created the Python language?"]
contexts = [
[
"Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
"language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
"programmers write clear, logical code for both small and large-scale software projects.",
"Java is a high-level, class-based, object-oriented programming language which allows you to write once, "
"run anywhere, meaning that compiled Java code can run on all platforms that support Java without the "
"need for recompilation.",
"Scala is a high-level, statically typed programming language.",
]
]

evaluator = ContextRelevanceEvaluator()
result = evaluator.run(questions=questions, contexts=contexts)

assert len(result["results"][0]["statements"]) == 4
assert len(result["results"][0]["statement_scores"]) == 4
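For completeness, here is a hedged sketch of how the new output contract could be unit-tested without network access, following the monkeypatching pattern used in the tests above (the test name, fixture setup, and fake key are assumptions for illustration, not part of this PR):

```python
from haystack.components.evaluators import ContextRelevanceEvaluator


def test_new_output_contract(monkeypatch):  # hypothetical test, not in the diff
    monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
    component = ContextRelevanceEvaluator()

    # Fake the LLM call so the test needs no network access, as in the mocked tests above.
    def generator_run(self, *args, **kwargs):
        return {"replies": ['{"relevant_statements": ["Rome is the capital of Italy."]}']}

    monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

    result = component.run(questions=["What is the capital of Italy?"], contexts=[["Rome is the capital of Italy."]])

    assert result["results"][0]["relevant_statements"] == ["Rome is the capital of Italy."]
    assert result["results"][0]["score"] == 1
    assert result["score"] == 1
```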