Update evaluate.py #1930

Open
natedem30 wants to merge 1 commit into base: main
Update evaluate.py
natedem30 committed May 15, 2024
commit 8e3a67915424c0f32bb5a49912b3e5e33f012f6d
45 changes: 31 additions & 14 deletions giskard/rag/evaluate.py
@@ -50,14 +50,30 @@ def evaluate(
     RAGReport
         The report of the evaluation.
     """
-    if testset is None and knowledge_base is None:
-        raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
-
-    if testset is None and not isinstance(answer_fn, Sequence):
-        raise ValueError(
-            "If the testset is not provided, the answer_fn must be a list of answers to ensure the matching between questions and answers."
-        )
+    validate_inputs(answer_fn, knowledge_base, testset)
+    testset = testset or generate_testset(knowledge_base)
+    answers = retrieve_answers(answer_fn, testset)
+    llm_client = llm_client or get_default_client()
+    metrics = get_metrics(metrics, llm_client, agent_description)
+    metrics_results = compute_metrics(metrics, testset, answers)
+    report = get_report(testset, answers, metrics_results, knowledge_base)
+    add_recommendation(report, llm_client, metrics)
+    track_analytics(report, testset, knowledge_base, agent_description, metrics)
+
+    return report
+
+def validate_inputs(answer_fn, knowledge_base, testset):
+    if testset is None:
+        if knowledge_base is None:
+            raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
+        if not isinstance(answer_fn, Sequence):
+            raise ValueError(
+                "If the testset is not provided, the answer_fn must be a list of answers to ensure the matching between questions and answers."
+            )
+
+        testset = generate_testset(knowledge_base)
 
     # Check basic types, in case the user passed the params in the wrong order
     if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
         raise ValueError(
@@ -69,21 +85,19 @@ def evaluate(
             f"testset must be a QATestset object (got {type(testset)} instead). Are you sure you passed the parameters in the right order?"
         )
 
-    if testset is None:
-        testset = generate_testset(knowledge_base)
-
-    answers = answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)
-
-    llm_client = llm_client or get_default_client()
+def retrieve_answers(answer_fn, testset):
+    return answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)
+
+# @TODO: improve this
+def get_metrics(metrics, llm_client, agent_description):
     metrics = list(metrics) if metrics is not None else []
     if not any(isinstance(metric, CorrectnessMetric) for metric in metrics):
         # By default only correctness is computed as it is required to build the report
         metrics.insert(
             0, CorrectnessMetric(name="correctness", llm_client=llm_client, agent_description=agent_description)
         )
+    return metrics
+
+def compute_metrics(metrics, testset, answers):
     metrics_results = defaultdict(dict)
 
     for metric in metrics:
@@ -97,8 +111,12 @@ def evaluate(
                 total=len(answers),
             ):
                 metrics_results[sample["id"]].update(metric(sample, answer))
+    return metrics_results
 
-    report = RAGReport(testset, answers, metrics_results, knowledge_base)
+def get_report(testset, answers, metrics_results, knowledge_base):
+    return RAGReport(testset, answers, metrics_results, knowledge_base)
+
+def add_recommendation(report, llm_client, metrics):
     recommendation = get_rag_recommendation(
         report.topics,
         report.correctness_by_question_type().to_dict()[metrics[0].name],
@@ -107,6 +125,7 @@ def evaluate(
     )
     report._recommendation = recommendation
 
+def track_analytics(report, testset, knowledge_base, agent_description, metrics):
     analytics.track(
         "raget:evaluation",
         {
@@ -117,8 +136,6 @@ def evaluate(
             "correctness": report.correctness,
         },
     )
-    return report
-
 
 def _compute_answers(answer_fn, testset):
     answers = []
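
For reference, a minimal usage sketch of the refactored entry point, assuming the public giskard.rag API (KnowledgeBase, evaluate) exposed elsewhere in the package; the DataFrame contents and the answer_fn stub are hypothetical placeholders for a real knowledge base and RAG pipeline:

import pandas as pd

from giskard.rag import KnowledgeBase, evaluate

# Hypothetical knowledge base: a DataFrame of the documents the RAG agent
# retrieves from (check the giskard docs for the exact constructor options).
documents = pd.DataFrame({"text": ["First document...", "Second document..."]})
knowledge_base = KnowledgeBase(documents)

def answer_fn(question, history=None):
    # Placeholder: call the RAG pipeline under test here.
    return "stub answer"

# With no testset passed, evaluate() now runs validate_inputs() and then
# builds one via generate_testset(knowledge_base) before computing metrics.
report = evaluate(answer_fn, knowledge_base=knowledge_base)

As validate_inputs() enforces, answer_fn may alternatively be a Sequence of precomputed answers, but only when a testset is also provided, so that answers can be matched to questions.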