
[Evals] Add retrieval completion fn and example #656

Merged · 3 commits · Apr 12, 2023
111 changes: 111 additions & 0 deletions evals/completion_fns/retrieval.py
@@ -0,0 +1,111 @@
"""
Extending Completion Functions with Embeddings-based retrieval from a fetched dataset
"""
from ast import literal_eval
from typing import Any, Optional, Union

import numpy as np
import openai
import pandas as pd

from evals.api import CompletionFn, CompletionResult
from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt
from evals.record import record_sampling
from evals.registry import Registry


def load_embeddings(embeddings_and_text_path: str):
df = pd.read_csv(embeddings_and_text_path, converters={"embedding": literal_eval})
assert (
"text" in df.columns and "embedding" in df.columns
), "The embeddings file must have columns named 'text' and 'embedding'"
return df


def find_top_k_closest_embeddings(embedded_prompt: list[float], embs: list[list[float]], k: int):
# Normalize the embeddings
norm_embedded_prompt = embedded_prompt / np.linalg.norm(embedded_prompt)
norm_embs = embs / np.linalg.norm(embs, axis=1)[:, np.newaxis]

# Calculate cosine similarity
cosine_similarities = np.dot(norm_embs, norm_embedded_prompt)

# Get the indices of the top k closest embeddings
top_k_indices = np.argsort(cosine_similarities)[-k:]

return top_k_indices[::-1]


DEFAULT_RETRIEVAL_TEMPLATE = "Use the provided context to answer the question. "


class RetrievalCompletionResult(CompletionResult):
def __init__(self, response: str) -> None:
self.response = response

def get_completions(self) -> list[str]:
return [self.response.strip()]


class RetrievalCompletionFn(CompletionFn):
"""
This completion function embeds the prompt, uses embedding similarity to retrieve the k docs from a dataset most relevant to the prompt, and adds them to the context before calling the underlying completion function.
"""

def __init__(
self,
completion_fn: str,
embeddings_and_text_path: str,
retrieval_template: str = DEFAULT_RETRIEVAL_TEMPLATE,
k: int = 4,
registry: Optional[Registry] = None,
registry_path: Optional[str] = None,
**_kwargs: Any
) -> None:
"""
Args:
completion_fn: The completion function to call once the retrieved docs have been added to the context.
embeddings_and_text_path: The path to a CSV containing "text" and "embedding" columns.
retrieval_template: The template to use for the retrieval. The task prompt will be added to the end of this template.
k: The number of docs to retrieve from the dataset.
registry: Upstream callers may pass in a registry to use.
registry_path: The path to a registry file to add to default registry.
_kwargs: Additional arguments to pass to the completion function instantiation.
"""
registry = registry or Registry()
if registry_path:
registry.add_registry_paths(registry_path)

self.embeddings_df = load_embeddings(embeddings_and_text_path)

self.k = k

self.retrieval_template = retrieval_template
self.completion_fn_instance = registry.make_completion_fn(completion_fn)

def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCompletionResult:
"""
Args:
prompt: The prompt to complete, in either text string or Chat format.
kwargs: Additional arguments to pass to the completion function call method.
"""
# Embed the prompt
embedded_prompt = openai.Embedding.create(
model="text-embedding-ada-002", input=CompletionPrompt(prompt).to_formatted_prompt()
[Contributor review comment: We should make the model a constructor parameter to make it more customizable.]
)["data"][0]["embedding"]

embs = self.embeddings_df["embedding"].to_list()

# Retrieve the k docs closest to the prompt by cosine similarity and join their text
topk = " ".join(
self.embeddings_df.iloc[
find_top_k_closest_embeddings(embedded_prompt, embs, k=self.k)
].text.values
)

prompt = ChatCompletionPrompt(prompt).to_formatted_prompt()
retrieval_prompt = [{"role": "system", "content": self.retrieval_template + topk}] + prompt

answer = self.completion_fn_instance(prompt=retrieval_prompt, **kwargs).get_completions()[0]
record_sampling(prompt=retrieval_prompt, sampled=answer)
return RetrievalCompletionResult(answer)
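
Following the inline review comment above, here is a minimal sketch of the suggested change: make the embedding model a constructor parameter. The parameter name embedding_model and its default are assumptions, not part of this diff.

# Sketch only: reviewer-suggested variant of RetrievalCompletionFn,
# not part of this PR. embedding_model is a hypothetical parameter.
class RetrievalCompletionFn(CompletionFn):
    def __init__(
        self,
        completion_fn: str,
        embeddings_and_text_path: str,
        retrieval_template: str = DEFAULT_RETRIEVAL_TEMPLATE,
        k: int = 4,
        embedding_model: str = "text-embedding-ada-002",  # hypothetical
        registry: Optional[Registry] = None,
        registry_path: Optional[str] = None,
        **_kwargs: Any,
    ) -> None:
        ...  # setup unchanged from the diff above
        self.embedding_model = embedding_model

    def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any):
        # The embedding call would then use the configurable model:
        embedded_prompt = openai.Embedding.create(
            model=self.embedding_model,
            input=CompletionPrompt(prompt).to_formatted_prompt(),
        )["data"][0]["embedding"]
        ...  # retrieval and completion unchanged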
113 changes: 113 additions & 0 deletions examples/retrieval-completionfn.ipynb
@@ -0,0 +1,113 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"We show here how to use the retrieval completion function to add context from documents to any OpenAI Evals task\n",
"The toy example here will be to augment our Born-First task with a dataset of presidential birthdays\n",
"\"\"\"\n",
"\n",
"# Download the dataset manually, or use curl\n",
"!curl -O https://people.math.sc.edu/Burkardt/datasets/presidents/president_birthdays.csv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import openai\n",
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"president_birthdays.csv\").rename(columns={\" \\\"Name\\\"\": \"Name\", \" \\\"Month\\\"\": \"Month\", \" \\\"Day\\\"\": \"Day\", \" \\\"Year\\\"\": \"Year\"}).set_index(\"Index\")\n",
"df[\"text\"] = df.apply(lambda r: f\"{r['Name']} was born on {r['Month']}/{r['Day']}/{r['Year']}\", axis=1)\n",
"display(df.head())\n",
"\n",
"def embed(text):\n",
" return openai.Embedding.create(\n",
" model=\"text-embedding-ada-002\",\n",
" input=text\n",
" )[\"data\"][0][\"embedding\"]\n",
"\n",
"df[\"embedding\"] = df['text'].apply(embed)\n",
"df[[\"text\", \"embedding\"]].to_csv(\"presidents_embeddings.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"We create a registry entry here in code. Notice we set number of retrieved documents k=2.\n",
"\"\"\"\n",
"\n",
"registry_yaml = f\"\"\"\n",
"retrieval/presidents/gpt-3.5-turbo:\n",
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n",
" args:\n",
" completion_fn: gpt-3.5-turbo\n",
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n",
" k: 2\n",
"\n",
"retrieval/presidents/cot/gpt-3.5-turbo:\n",
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n",
" args:\n",
" completion_fn: cot/gpt-3.5-turbo\n",
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n",
" k: 2\n",
"\"\"\".strip()\n",
"\n",
"# Replace with path to your registry\n",
"os.makedirs(\"completion_fns\", exist_ok=True)\n",
"with open(\"completion_fns/retrieval.yaml\", \"w\") as f:\n",
" f.write(registry_yaml)\n",
"\n",
"# GPT-3.5-turbo base: accuracy 0.7\n",
"!oaieval gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n",
"\n",
"# GPT-3.5-turbo with retrieval: accuracy 0.9 -> The failure mode here is the retrieved president is incorrect: Andrew Johnson vs Andrew Jackson\n",
"!oaieval retrieval/presidents/gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n",
"\n",
"# GPT-3.5-turbo with retrieval and chain-of-thought: accuracy 1.0\n",
"!oaieval retrieval/presidents/cot/gpt-3.5-turbo born-first --max_samples 10 --registry_path ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
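
Once the registry YAML above is written, the completion function can also be called directly from Python rather than through the oaieval CLI. A minimal sketch, reusing only the Registry calls that appear in the diff above; the registry path and the question text are illustrative assumptions.

# Sketch only: direct use of the registered retrieval completion function.
from evals.registry import Registry

registry = Registry()
# "." is the local registry root containing completion_fns/retrieval.yaml,
# mirroring the --registry_path . flag passed to oaieval above.
registry.add_registry_paths(".")

fn = registry.make_completion_fn("retrieval/presidents/gpt-3.5-turbo")
result = fn(prompt=[{"role": "user", "content": "Was Washington born before Lincoln?"}])
print(result.get_completions()[0])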