-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Evals] Add retrieval completion fn and example #656
Merged
Merged
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
""" | ||
Extending Completion Functions with Embeddings-based retrieval from a fetched dataset | ||
""" | ||
from ast import literal_eval | ||
from typing import Any, Optional, Union | ||
|
||
import numpy as np | ||
import openai | ||
import pandas as pd | ||
|
||
from evals.api import CompletionFn, CompletionResult | ||
from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt | ||
from evals.record import record_sampling | ||
from evals.registry import Registry | ||
|
||
|
||
def load_embeddings(embeddings_and_text_path: str): | ||
df = pd.read_csv(embeddings_and_text_path, converters={"embedding": literal_eval}) | ||
assert ( | ||
"text" in df.columns and "embedding" in df.columns | ||
), "The embeddings file must have columns named 'text' and 'embedding'" | ||
return df | ||
|
||
|
||
def find_top_k_closest_embeddings(embedded_prompt: list[float], embs: list[list[float]], k: int): | ||
# Normalize the embeddings | ||
norm_embedded_prompt = embedded_prompt / np.linalg.norm(embedded_prompt) | ||
norm_embs = embs / np.linalg.norm(embs, axis=1)[:, np.newaxis] | ||
|
||
# Calculate cosine similarity | ||
cosine_similarities = np.dot(norm_embs, norm_embedded_prompt) | ||
|
||
# Get the indices of the top k closest embeddings | ||
top_k_indices = np.argsort(cosine_similarities)[-k:] | ||
|
||
return top_k_indices[::-1] | ||
|
||
|
||
DEFAULT_RETRIEVAL_TEMPLATE = "Use the provided context to answer the question. " | ||
|
||
|
||
class RetrievalCompletionResult(CompletionResult): | ||
def __init__(self, response: str) -> None: | ||
self.response = response | ||
|
||
def get_completions(self) -> list[str]: | ||
return [self.response.strip()] | ||
|
||
|
||
class RetrievalCompletionFn(CompletionFn): | ||
""" | ||
This Completion Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
completion_fn: str, | ||
embeddings_and_text_path: str, | ||
retrieval_template: str = DEFAULT_RETRIEVAL_TEMPLATE, | ||
k: int = 4, | ||
registry: Optional[Registry] = None, | ||
registry_path: Optional[str] = None, | ||
**_kwargs: Any | ||
) -> None: | ||
""" | ||
Args: | ||
retrieval_template: The template to use for the retrieval. The task prompt will be added to the end of this template. | ||
k: The number of docs to retrieve from the dataset. | ||
completion_fn: The completion function to use for the retrieval. | ||
embeddings_and_text_path: The path to a CSV containing "text" and "embedding" columns. | ||
registry: Upstream callers may pass in a registry to use. | ||
registry_path: The path to a registry file to add to default registry. | ||
_kwargs: Additional arguments to pass to the completion function instantiation. | ||
""" | ||
registry = Registry() if not registry else registry | ||
if registry_path: | ||
registry.add_registry_paths(registry_path) | ||
|
||
self.embeddings_df = load_embeddings(embeddings_and_text_path) | ||
|
||
self.k = k | ||
|
||
self.retrieval_template = retrieval_template | ||
self.completion_fn_instance = registry.make_completion_fn(completion_fn) | ||
|
||
def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCompletionResult: | ||
""" | ||
Args: | ||
prompt: The prompt to complete, in either text string or Chat format. | ||
kwargs: Additional arguments to pass to the completion function call method. | ||
""" | ||
# Embed the prompt | ||
embedded_prompt = openai.Embedding.create( | ||
model="text-embedding-ada-002", input=CompletionPrompt(prompt).to_formatted_prompt() | ||
)["data"][0]["embedding"] | ||
|
||
embs = self.embeddings_df["embedding"].to_list() | ||
|
||
# Compute the cosine similarity between the prompt and the embeddings | ||
topk = " ".join( | ||
self.embeddings_df.iloc[ | ||
find_top_k_closest_embeddings(embedded_prompt, embs, k=self.k) | ||
].text.values | ||
) | ||
|
||
prompt = ChatCompletionPrompt(prompt).to_formatted_prompt() | ||
retrieval_prompt = [{"role": "system", "content": self.retrieval_template + topk}] + prompt | ||
|
||
answer = self.completion_fn_instance(prompt=retrieval_prompt, **kwargs).get_completions()[0] | ||
record_sampling(prompt=retrieval_prompt, sampled=answer) | ||
return RetrievalCompletionResult(answer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\"\"\"\n", | ||
"We show here how to use the retrieval completion function to add context from documents to any OpenAI Evals task\n", | ||
"The toy example here will be to augment our Born-First task with a dataset of presidential birthdays\n", | ||
"\"\"\"\n", | ||
"\n", | ||
"# Download the dataset manually, or use curl\n", | ||
"!curl -O https://people.math.sc.edu/Burkardt/datasets/presidents/president_birthdays.csv" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import openai\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"df = pd.read_csv(\"president_birthdays.csv\").rename(columns={\" \\\"Name\\\"\": \"Name\", \" \\\"Month\\\"\": \"Month\", \" \\\"Day\\\"\": \"Day\", \" \\\"Year\\\"\": \"Year\"}).set_index(\"Index\")\n", | ||
"df[\"text\"] = df.apply(lambda r: f\"{r['Name']} was born on {r['Month']}/{r['Day']}/{r['Year']}\", axis=1)\n", | ||
"display(df.head())\n", | ||
"\n", | ||
"def embed(text):\n", | ||
" return openai.Embedding.create(\n", | ||
" model=\"text-embedding-ada-002\",\n", | ||
" input=text\n", | ||
" )[\"data\"][0][\"embedding\"]\n", | ||
"\n", | ||
"df[\"embedding\"] = df['text'].apply(embed)\n", | ||
"df[[\"text\", \"embedding\"]].to_csv(\"presidents_embeddings.csv\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\"\"\"\n", | ||
"We create a registry entry here in code. Notice we set number of retrieved documents k=2.\n", | ||
"\"\"\"\n", | ||
"\n", | ||
"registry_yaml = f\"\"\"\n", | ||
"retrieval/presidents/gpt-3.5-turbo:\n", | ||
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n", | ||
" args:\n", | ||
" completion_fn: gpt-3.5-turbo\n", | ||
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n", | ||
" k: 2\n", | ||
"\n", | ||
"retrieval/presidents/cot/gpt-3.5-turbo:\n", | ||
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n", | ||
" args:\n", | ||
" completion_fn: cot/gpt-3.5-turbo\n", | ||
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n", | ||
" k: 2\n", | ||
"\"\"\".strip()\n", | ||
"\n", | ||
"# Replace with path to your registry\n", | ||
"os.makedirs(\"completion_fns\", exist_ok=True)\n", | ||
"with open(\"completion_fns/retrieval.yaml\", \"w\") as f:\n", | ||
" f.write(registry_yaml)\n", | ||
"\n", | ||
"# GPT-3.5-turbo base: accuracy 0.7\n", | ||
"!oaieval gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n", | ||
"\n", | ||
"# GPT-3.5-turbo with retrieval: accuracy 0.9 -> The failure mode here is the retrieved president is incorrect: Andrew Johnson vs Andrew Jackson\n", | ||
"!oaieval retrieval/presidents/gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n", | ||
"\n", | ||
"# GPT-3.5-turbo with retrieval and chain-of-thought: accuracy 1.0\n", | ||
"!oaieval retrieval/presidents/cot/gpt-3.5-turbo born-first --max_samples 10 --registry_path ." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.9" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make the model a constructor parameter to make it more customizable.