Adding Domain Expert Evaluator (#5)
* typo on openai file name

* typo on openai file name

* call super on document evaluators

* added domain_expert document prompt

* added domain_expert document prompt

* fixing rdnam answer parsing

* fixing rdnam answer parsing

* fixing rdnam answer parsing

* fixing rdnam answer parsing

* move openai, better progress bars for doc evaluator

* domain_expert fixes

* linting

* doc evaluator now implements a single function per qid-did pair

* doc evaluator now implements a single function per qid-did pair

* flake8

* less specific prompt for domain expert
ArthurCamara authored Oct 25, 2023
1 parent bdd29e0 commit c4c785a
Showing 9 changed files with 316 additions and 97 deletions.
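
The central change in this commit is that the document evaluator now handles one query-document (qid-did) pair per call instead of doing all the work inside get_answers(). A rough usage sketch, based only on what is visible in the diff below; the factory class name, import path, and constructor keyword names are assumptions, not a documented API:

# Illustrative sketch, not part of this commit. Assumes the registry
# classmethod `create` shown at the end of the diff lives on a factory
# class (called DocumentEvaluatorFactory here) and that the constructor
# accepts the keyword names used in __init__ below.
from ragelo.doc_evaluators import DocumentEvaluatorFactory  # hypothetical import path

evaluator = DocumentEvaluatorFactory.create(
    "domain_expert",                 # evaluator added in this commit
    query_path="queries.csv",        # qid,query rows (assumed CSV layout)
    documents_path="documents.csv",  # qid,did,text rows (assumed CSV layout)
    output_file="doc_answers.csv",
    model_name="gpt-3.5-turbo",
)
answers = evaluator.get_answers()    # {qid: {did: answer}}, also appended to output_file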
2 changes: 1 addition & 1 deletion README.md
@@ -124,10 +124,10 @@ python -m build

### ✅ TODO
- [ ] Add option to few-shot examples
- [x] Publish on PyPi
- [ ] Add custom types
- [ ] Testing!
- [ ] Add CI/CD for publishing
- [x] Publish on PyPi
- [x] Add more document evaluators (Microsoft)
- [x] Split Elo evaluator
- [x] Install as standalone CLI
21 changes: 18 additions & 3 deletions pyproject.toml
@@ -24,11 +24,20 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: System :: Benchmark",
]
dependencies = ["openai", "tenacity", "typer"]
dependencies = ["openai", "tenacity", "typer", "numpy"]

[project.optional-dependencies]
cli = ["typer[all]"]
dev = ["bandit==1.7.5", "black==23.10.0", "isort==5.12.0", "flake8==6.1.0", "flake8-black==0.3.6", "flake8-isort==6.1.0", "mypy==1.6.1"]
dev = [
"bandit==1.7.5",
"black==23.10.0",
"isort==5.12.0",
"flake8==6.1.0",
"flake8-black==0.3.6",
"flake8-isort==6.1.0",
"mypy==1.6.1",
"Flake8-pyproject==1.2.3",
]

[project.scripts]
ragelo = "ragelo.cli:app"
@@ -45,9 +54,15 @@ profile = "black"

[tool.mypy]
python_version = "3.11"
ignore_missing_imports = true
show_column_numbers = true
namespace_packages = true
exclude = ["build/", "dist/", "venv/"]

[tool.flake8]
ignore = ['E501', "W503"]
per-file-ignores = ['__init__.py:F401,F403']
exclude = ["build/", "dist/", "venv/"]


[tool.setuptools-git-versioning]
enabled = true
2 changes: 1 addition & 1 deletion ragelo/answer_evaluators/base_answer_evaluator.py
@@ -6,7 +6,7 @@
from typing import Dict, Set, Type

from ragelo.logger import logger
from ragelo.opeanai_client import OpenAiClient, set_credentials_from_file
from ragelo.utils.openai_client import OpenAiClient, set_credentials_from_file


class AnswerEvaluator:
1 change: 1 addition & 0 deletions ragelo/doc_evaluators/__init__.py
@@ -1,2 +1,3 @@
from ragelo.doc_evaluators.domain_expert import *
from ragelo.doc_evaluators.rdnam_evaluator import *
from ragelo.doc_evaluators.reasoner_evaluator import *
173 changes: 114 additions & 59 deletions ragelo/doc_evaluators/base_doc_evaluator.py
@@ -3,12 +3,14 @@
import os
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Type
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Set, Tuple, Type

from tenacity import RetryError

from ragelo.logger import logger
from ragelo.opeanai_client import OpenAiClient, set_credentials_from_file
from ragelo.utils.openai_client import OpenAiClient, set_credentials_from_file


class DocumentEvaluator:
@@ -29,69 +31,80 @@ def __init__(
self.output_file = output_file
self.queries = self._load_queries(query_path)
self.documents = self._load_documents(documents_path)
if verbose:
logger.setLevel("INFO")

if credentials_file:
set_credentials_from_file(credentials_file)

self.openai_client = OpenAiClient(model=model_name)
self.progress_bar: Callable = nullcontext
try:
from rich.progress import Progress

def get_answers(self):
self.progress_bar = partial(Progress, transient=True)
self.rich = True
except ImportError:
self.rich = False

def get_answers(self) -> Dict[str, Dict[str, Any]]:
"""Runs the evaluator and saves the results to a file"""
skip_docs = set()
if os.path.isfile(self.output_file) and not self.force:
for line in csv.reader(open(self.output_file)):
qid, did, answer = line
skip_docs.add((qid, did))
if self.force and os.path.isfile(self.output_file):
logger.warning(f"Removing existing {self.output_file}!")
os.remove(self.output_file)
if len(skip_docs) > 0:
logger.warning(
f"Skipping {len(skip_docs)} documents already annotated! "
"If you want to reannotate them, please use the --force flag"

use_bar = self.verbose and self.rich
skip_docs = self._get_skip_docs()
answers: Dict[str, Dict[str, Any]] = defaultdict(lambda: dict())
with self.progress_bar() as progress:
# If we are using rich's progress bar, initialize a task for the queries
q_progress = (
progress.add_task(
"[bold blue]Annotating Documents", total=len(self.queries)
)
if use_bar and progress
else None
)
q_iterator = self.queries
if self.verbose:
try:
from rich.progress import track

q_iterator = track(self.queries, description="Annotating Documents")
except ImportError:
pass
for qid in q_iterator:
for did in self.documents[qid]:
if (qid, did) in skip_docs:
logger.debug(f"Skipping {qid} {did}")
continue
message = self._build_message(qid, did)
try:
answer = self.openai_client(message)
answer = self._process_answer(answer)
except RetryError:
logger.warning(f"Failed to fetch answers for document {qid} {did}")
continue
except ValueError:
logger.warning(f"Failed to parse answer for document {qid} {did}")
continue
if self.verbose:
logger.info(
"[bold cyan]Query [/bold cyan]: "
f"[not bold cyan]{self.queries[qid]}[/not bold cyan]"
)
logger.info(f"[bold cyan]Document ID [/bold cyan]: {did}")
logger.info(
"[bold cyan]Evaluation [/bold cyan]: "
f"[not bold]{answer}[/not bold]"
for qid in self.queries:
d_progress = (
progress.add_task(
f"[bold white]{qid}", total=len(self.documents[qid])
)
logger.info("")
if not os.path.isfile(self.output_file):
with open(self.output_file, "w") as f:
writer = csv.writer(f)
writer.writerow(["query_id", "did", "answer"])

with open(self.output_file, "a") as f:
writer = csv.writer(f)
writer.writerow([qid, did, answer])
if use_bar and progress
else None
)
for did in self.documents[qid]:
if (qid, did) in skip_docs:
logger.debug(f"Skipping {qid} {did}")
continue

try:
answer = self._process_single_answer(qid, did)
except (RetryError, ValueError):
continue
self._print_response(qid, did, answer)
self._dump_response(qid, did, answer)
answers[qid][did] = answer
if progress and d_progress:
progress.update(d_progress, advance=1, refresh=True)
if progress and q_progress:
if d_progress:
progress.stop_task(d_progress)
progress.update(q_progress, advance=1, refresh=True)
return answers

def _process_single_answer(self, qid: str, did: str) -> str:
"""Submites a single query-document pair to the LLM and returns the answer.
Override this method to implement a custom evaluator (e.g., two-shot)
"""
message = self._build_message(qid, did)
try:
answer = self.openai_client(message)
answer = self._process_answer(answer)
except RetryError as e:
logger.warning(f"Failed to FETCH answers for {qid} {did}")
raise e
except ValueError as e:
logger.warning(f"Failed to PARSE answer for {qid} {did}")
raise e
return answer

@abstractmethod
def _build_message(self, qid: str, did: str) -> str:
@@ -153,7 +166,49 @@ def _load_documents(self, documents_path: str) -> Dict[str, Dict[str, str]]:
logger.info(f"Loaded {len(rows)} documents")
return rows

def __load_from_csv(self, file_path: str) -> Dict[str, str]:
def _get_skip_docs(self) -> Set[Tuple[str, str]]:
skip_docs = set()
if os.path.isfile(self.output_file) and not self.force:
for line in csv.reader(open(self.output_file)):
qid, did, answer = line
skip_docs.add((qid, did))
if self.force and os.path.isfile(self.output_file):
logger.warning(f"Removing existing {self.output_file}!")
os.remove(self.output_file)
if len(skip_docs) > 0:
logger.warning(
f"Skipping {len(skip_docs)} documents already annotated! "
"If you want to reannotate them, please use the --force flag"
)
return skip_docs

def _print_response(self, qid: str, did: str, answer: str) -> None:
logger.info(
"[bold cyan]Query [/bold cyan]: "
f"[not bold cyan]{self.queries[qid]}[/not bold cyan]"
)
logger.info(f"[bold cyan]Document ID [/bold cyan]: {did}")
logger.info(
"[bold cyan]Evaluation [/bold cyan]: " f"[not bold]{answer}[/not bold]"
)
logger.info("")

def _dump_response(
self, qid: str, did: str, answer: str | List[str], file: str | None = None
) -> None:
output_file = file if file else self.output_file
if not os.path.isfile(output_file):
with open(output_file, "w") as f:
writer = csv.writer(f)
writer.writerow(["query_id", "did", "answer"])

with open(output_file, "a") as f:
writer = csv.writer(f)
if isinstance(answer, List):
answer = "\n".join(answer)
writer.writerow([qid, did, answer])

def _load_from_csv(self, file_path: str) -> Dict[str, str]:
"""extra content from a CSV file"""
contents = {}
for line in csv.reader(open(file_path, "r")):
@@ -182,7 +237,7 @@ def inner_wrapper(
return inner_wrapper

@classmethod
def create(cls, evaluator_name: str, **kwargs) -> DocumentEvaluator:
def create(cls, evaluator_name: str, *args, **kwargs) -> DocumentEvaluator:
if evaluator_name not in cls.registry:
raise ValueError(f"Unknown evaluator {evaluator_name}")
return cls.registry[evaluator_name](prompt_name=evaluator_name, **kwargs)
return cls.registry[evaluator_name](prompt_name=evaluator_name, *args, **kwargs)
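
Because get_answers() now delegates to _process_single_answer(), a custom per-pair evaluator only needs to override that one method, as its docstring suggests. A minimal sketch of such a subclass; the class name, prompt wording, and the retry-with-reminder behaviour are illustrative assumptions, not code from this commit:

# Illustrative sketch, not part of this commit.
from ragelo.doc_evaluators.base_doc_evaluator import DocumentEvaluator


class TwoPassEvaluator(DocumentEvaluator):
    """Asks again with a stricter instruction when the first answer cannot be parsed."""

    def _build_message(self, qid: str, did: str) -> str:
        query = self.queries[qid]
        document = self.documents[qid][did]
        return f"Query: {query}\nDocument: {document}\nIs this document relevant? Answer yes or no."

    def _process_answer(self, answer: str) -> str:
        label = answer.strip().lower()
        if label not in ("yes", "no"):
            raise ValueError(f"Unparsable answer: {answer!r}")
        return label

    def _process_single_answer(self, qid: str, did: str) -> str:
        message = self._build_message(qid, did)
        try:
            return self._process_answer(self.openai_client(message))
        except ValueError:
            # Second pass: remind the model of the expected output format.
            retry_message = message + "\nReply with a single word: yes or no."
            return self._process_answer(self.openai_client(retry_message))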