-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Evals] Add retrieval completion fn and example (#656)
- Loading branch information
1 parent
c6412b0
commit c0dbe9d
Showing
2 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
""" | ||
Extending Completion Functions with Embeddings-based retrieval from a fetched dataset | ||
""" | ||
from ast import literal_eval | ||
from typing import Any, Optional, Union | ||
|
||
import numpy as np | ||
import openai | ||
import pandas as pd | ||
|
||
from evals.api import CompletionFn, CompletionResult | ||
from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt | ||
from evals.record import record_sampling | ||
from evals.registry import Registry | ||
|
||
|
||
def load_embeddings(embeddings_and_text_path: str): | ||
df = pd.read_csv(embeddings_and_text_path, converters={"embedding": literal_eval}) | ||
assert ( | ||
"text" in df.columns and "embedding" in df.columns | ||
), "The embeddings file must have columns named 'text' and 'embedding'" | ||
return df | ||
|
||
|
||
def find_top_k_closest_embeddings(embedded_prompt: list[float], embs: list[list[float]], k: int): | ||
# Normalize the embeddings | ||
norm_embedded_prompt = embedded_prompt / np.linalg.norm(embedded_prompt) | ||
norm_embs = embs / np.linalg.norm(embs, axis=1)[:, np.newaxis] | ||
|
||
# Calculate cosine similarity | ||
cosine_similarities = np.dot(norm_embs, norm_embedded_prompt) | ||
|
||
# Get the indices of the top k closest embeddings | ||
top_k_indices = np.argsort(cosine_similarities)[-k:] | ||
|
||
return top_k_indices[::-1] | ||
|
||
|
||
DEFAULT_RETRIEVAL_TEMPLATE = "Use the provided context to answer the question. " | ||
|
||
|
||
class RetrievalCompletionResult(CompletionResult): | ||
def __init__(self, response: str) -> None: | ||
self.response = response | ||
|
||
def get_completions(self) -> list[str]: | ||
return [self.response.strip()] | ||
|
||
|
||
class RetrievalCompletionFn(CompletionFn): | ||
""" | ||
This Completion Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
completion_fn: str, | ||
embeddings_and_text_path: str, | ||
retrieval_template: str = DEFAULT_RETRIEVAL_TEMPLATE, | ||
k: int = 4, | ||
embedding_model: str = "text-embedding-ada-002", | ||
registry: Optional[Registry] = None, | ||
registry_path: Optional[str] = None, | ||
**_kwargs: Any | ||
) -> None: | ||
""" | ||
Args: | ||
retrieval_template: The template to use for the retrieval. The task prompt will be added to the end of this template. | ||
k: The number of docs to retrieve from the dataset. | ||
completion_fn: The completion function to use for the retrieval. | ||
embeddings_and_text_path: The path to a CSV containing "text" and "embedding" columns. | ||
registry: Upstream callers may pass in a registry to use. | ||
registry_path: The path to a registry file to add to default registry. | ||
_kwargs: Additional arguments to pass to the completion function instantiation. | ||
""" | ||
registry = Registry() if not registry else registry | ||
if registry_path: | ||
registry.add_registry_paths(registry_path) | ||
|
||
self.embeddings_df = load_embeddings(embeddings_and_text_path) | ||
|
||
self.embedding_model = embedding_model | ||
self.k = k | ||
|
||
self.retrieval_template = retrieval_template | ||
self.completion_fn_instance = registry.make_completion_fn(completion_fn) | ||
|
||
def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCompletionResult: | ||
""" | ||
Args: | ||
prompt: The prompt to complete, in either text string or Chat format. | ||
kwargs: Additional arguments to pass to the completion function call method. | ||
""" | ||
# Embed the prompt | ||
embedded_prompt = openai.Embedding.create( | ||
model=self.embedding_model, input=CompletionPrompt(prompt).to_formatted_prompt() | ||
)["data"][0]["embedding"] | ||
|
||
embs = self.embeddings_df["embedding"].to_list() | ||
|
||
# Compute the cosine similarity between the prompt and the embeddings | ||
topk = " ".join( | ||
self.embeddings_df.iloc[ | ||
find_top_k_closest_embeddings(embedded_prompt, embs, k=self.k) | ||
].text.values | ||
) | ||
|
||
prompt = ChatCompletionPrompt(prompt).to_formatted_prompt() | ||
retrieval_prompt = [{"role": "system", "content": self.retrieval_template + topk}] + prompt | ||
|
||
answer = self.completion_fn_instance(prompt=retrieval_prompt, **kwargs).get_completions()[0] | ||
record_sampling(prompt=retrieval_prompt, sampled=answer) | ||
return RetrievalCompletionResult(answer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\"\"\"\n", | ||
"We show here how to use the retrieval completion function to add context from documents to any OpenAI Evals task\n", | ||
"The toy example here will be to augment our Born-First task with a dataset of presidential birthdays\n", | ||
"\"\"\"\n", | ||
"\n", | ||
"# Download the dataset manually, or use curl\n", | ||
"!curl -O https://people.math.sc.edu/Burkardt/datasets/presidents/president_birthdays.csv" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import openai\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"df = pd.read_csv(\"president_birthdays.csv\").rename(columns={\" \\\"Name\\\"\": \"Name\", \" \\\"Month\\\"\": \"Month\", \" \\\"Day\\\"\": \"Day\", \" \\\"Year\\\"\": \"Year\"}).set_index(\"Index\")\n", | ||
"df[\"text\"] = df.apply(lambda r: f\"{r['Name']} was born on {r['Month']}/{r['Day']}/{r['Year']}\", axis=1)\n", | ||
"display(df.head())\n", | ||
"\n", | ||
"def embed(text):\n", | ||
" return openai.Embedding.create(\n", | ||
" model=\"text-embedding-ada-002\",\n", | ||
" input=text\n", | ||
" )[\"data\"][0][\"embedding\"]\n", | ||
"\n", | ||
"df[\"embedding\"] = df['text'].apply(embed)\n", | ||
"df[[\"text\", \"embedding\"]].to_csv(\"presidents_embeddings.csv\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\"\"\"\n", | ||
"We create a registry entry here in code. Notice we set number of retrieved documents k=2.\n", | ||
"\"\"\"\n", | ||
"\n", | ||
"registry_yaml = f\"\"\"\n", | ||
"retrieval/presidents/gpt-3.5-turbo:\n", | ||
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n", | ||
" args:\n", | ||
" completion_fn: gpt-3.5-turbo\n", | ||
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n", | ||
" k: 2\n", | ||
"\n", | ||
"retrieval/presidents/cot/gpt-3.5-turbo:\n", | ||
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n", | ||
" args:\n", | ||
" completion_fn: cot/gpt-3.5-turbo\n", | ||
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n", | ||
" k: 2\n", | ||
"\"\"\".strip()\n", | ||
"\n", | ||
"# Replace with path to your registry\n", | ||
"os.makedirs(\"completion_fns\", exist_ok=True)\n", | ||
"with open(\"completion_fns/retrieval.yaml\", \"w\") as f:\n", | ||
" f.write(registry_yaml)\n", | ||
"\n", | ||
"# GPT-3.5-turbo base: accuracy 0.7\n", | ||
"!oaieval gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n", | ||
"\n", | ||
"# GPT-3.5-turbo with retrieval: accuracy 0.9 -> The failure mode here is the retrieved president is incorrect: Andrew Johnson vs Andrew Jackson\n", | ||
"!oaieval retrieval/presidents/gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n", | ||
"\n", | ||
"# GPT-3.5-turbo with retrieval and chain-of-thought: accuracy 1.0\n", | ||
"!oaieval retrieval/presidents/cot/gpt-3.5-turbo born-first --max_samples 10 --registry_path ." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.9" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |