[Evals] Add retrieval completion fn and example (openai#656)
andrew-openai committed Apr 12, 2023
1 parent a17feae commit 1ccd9a9
Showing 2 changed files with 226 additions and 0 deletions.
113 changes: 113 additions & 0 deletions evals/completion_fns/retrieval.py
@@ -0,0 +1,113 @@
"""
Extending Completion Functions with Embeddings-based retrieval from a fetched dataset
"""
from ast import literal_eval
from typing import Any, Optional, Union

import numpy as np
import openai
import pandas as pd

from evals.api import CompletionFn, CompletionResult
from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt
from evals.record import record_sampling
from evals.registry import Registry


def load_embeddings(embeddings_and_text_path: str):
    # Embeddings are serialized as stringified Python lists in the CSV, so
    # literal_eval converts each one back into a list of floats on load.
    df = pd.read_csv(embeddings_and_text_path, converters={"embedding": literal_eval})
    assert (
        "text" in df.columns and "embedding" in df.columns
    ), "The embeddings file must have columns named 'text' and 'embedding'"
    return df


def find_top_k_closest_embeddings(embedded_prompt: list[float], embs: list[list[float]], k: int):
    # Normalize the embeddings
    norm_embedded_prompt = embedded_prompt / np.linalg.norm(embedded_prompt)
    norm_embs = embs / np.linalg.norm(embs, axis=1)[:, np.newaxis]

    # Calculate cosine similarity
    cosine_similarities = np.dot(norm_embs, norm_embedded_prompt)

    # Get the indices of the top k closest embeddings
    top_k_indices = np.argsort(cosine_similarities)[-k:]

    # argsort is ascending, so reverse to put the most similar embedding first
    return top_k_indices[::-1]


DEFAULT_RETRIEVAL_TEMPLATE = "Use the provided context to answer the question. "


class RetrievalCompletionResult(CompletionResult):
    def __init__(self, response: str) -> None:
        self.response = response

    def get_completions(self) -> list[str]:
        return [self.response.strip()]


class RetrievalCompletionFn(CompletionFn):
"""
This Completion Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion.
"""

    def __init__(
        self,
        completion_fn: str,
        embeddings_and_text_path: str,
        retrieval_template: str = DEFAULT_RETRIEVAL_TEMPLATE,
        k: int = 4,
        embedding_model: str = "text-embedding-ada-002",
        registry: Optional[Registry] = None,
        registry_path: Optional[str] = None,
        **_kwargs: Any
    ) -> None:
        """
        Args:
            completion_fn: The completion function to use for the retrieval.
            embeddings_and_text_path: The path to a CSV containing "text" and "embedding" columns.
            retrieval_template: The template to use for the retrieval. The task prompt will be added to the end of this template.
            k: The number of docs to retrieve from the dataset.
            embedding_model: The model used to embed the prompt; it should match the model used to build the embeddings CSV.
            registry: Upstream callers may pass in a registry to use.
            registry_path: The path to a registry file to add to the default registry.
            _kwargs: Additional arguments to pass to the completion function instantiation.
        """
        registry = Registry() if not registry else registry
        if registry_path:
            registry.add_registry_paths(registry_path)

        self.embeddings_df = load_embeddings(embeddings_and_text_path)

        self.embedding_model = embedding_model
        self.k = k

        self.retrieval_template = retrieval_template
        self.completion_fn_instance = registry.make_completion_fn(completion_fn)

    def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCompletionResult:
        """
        Args:
            prompt: The prompt to complete, in either text string or Chat format.
            kwargs: Additional arguments to pass to the completion function call method.
        """
        # Embed the prompt
        embedded_prompt = openai.Embedding.create(
            model=self.embedding_model, input=CompletionPrompt(prompt).to_formatted_prompt()
        )["data"][0]["embedding"]

        embs = self.embeddings_df["embedding"].to_list()

        # Compute the cosine similarity between the prompt and the embeddings,
        # then join the text of the top-k closest documents into one context string
        topk = " ".join(
            self.embeddings_df.iloc[
                find_top_k_closest_embeddings(embedded_prompt, embs, k=self.k)
            ].text.values
        )

        # Prepend the retrieved context as a system message, then append the
        # original prompt in chat format
        prompt = ChatCompletionPrompt(prompt).to_formatted_prompt()
        retrieval_prompt = [{"role": "system", "content": self.retrieval_template + topk}] + prompt

        answer = self.completion_fn_instance(prompt=retrieval_prompt, **kwargs).get_completions()[0]
        record_sampling(prompt=retrieval_prompt, sampled=answer)
        return RetrievalCompletionResult(answer)
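
For reference, a minimal usage sketch (not part of this commit) that drives the new class directly from Python instead of through a registry YAML entry. The CSV path and question are placeholders, and it assumes "gpt-3.5-turbo" resolves to a completion function in the default registry:

from evals.completion_fns.retrieval import RetrievalCompletionFn

# Hypothetical example: presidents_embeddings.csv is built as in the notebook below
retrieval_fn = RetrievalCompletionFn(
    completion_fn="gpt-3.5-turbo",
    embeddings_and_text_path="presidents_embeddings.csv",
    k=2,
)
result = retrieval_fn("Was Abraham Lincoln born before Andrew Johnson?")
print(result.get_completions()[0])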
113 changes: 113 additions & 0 deletions examples/retrieval-completionfn.ipynb
@@ -0,0 +1,113 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"We show here how to use the retrieval completion function to add context from documents to any OpenAI Evals task\n",
"The toy example here will be to augment our Born-First task with a dataset of presidential birthdays\n",
"\"\"\"\n",
"\n",
"# Download the dataset manually, or use curl\n",
"!curl -O https://people.math.sc.edu/Burkardt/datasets/presidents/president_birthdays.csv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import openai\n",
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"president_birthdays.csv\").rename(columns={\" \\\"Name\\\"\": \"Name\", \" \\\"Month\\\"\": \"Month\", \" \\\"Day\\\"\": \"Day\", \" \\\"Year\\\"\": \"Year\"}).set_index(\"Index\")\n",
"df[\"text\"] = df.apply(lambda r: f\"{r['Name']} was born on {r['Month']}/{r['Day']}/{r['Year']}\", axis=1)\n",
"display(df.head())\n",
"\n",
"def embed(text):\n",
" return openai.Embedding.create(\n",
" model=\"text-embedding-ada-002\",\n",
" input=text\n",
" )[\"data\"][0][\"embedding\"]\n",
"\n",
"df[\"embedding\"] = df['text'].apply(embed)\n",
"df[[\"text\", \"embedding\"]].to_csv(\"presidents_embeddings.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"We create a registry entry here in code. Notice we set number of retrieved documents k=2.\n",
"\"\"\"\n",
"\n",
"registry_yaml = f\"\"\"\n",
"retrieval/presidents/gpt-3.5-turbo:\n",
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n",
" args:\n",
" completion_fn: gpt-3.5-turbo\n",
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n",
" k: 2\n",
"\n",
"retrieval/presidents/cot/gpt-3.5-turbo:\n",
" class: evals.completion_fns.retrieval:RetrievalCompletionFn\n",
" args:\n",
" completion_fn: cot/gpt-3.5-turbo\n",
" embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n",
" k: 2\n",
"\"\"\".strip()\n",
"\n",
"# Replace with path to your registry\n",
"os.makedirs(\"completion_fns\", exist_ok=True)\n",
"with open(\"completion_fns/retrieval.yaml\", \"w\") as f:\n",
" f.write(registry_yaml)\n",
"\n",
"# GPT-3.5-turbo base: accuracy 0.7\n",
"!oaieval gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n",
"\n",
"# GPT-3.5-turbo with retrieval: accuracy 0.9 -> The failure mode here is the retrieved president is incorrect: Andrew Johnson vs Andrew Jackson\n",
"!oaieval retrieval/presidents/gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n",
"\n",
"# GPT-3.5-turbo with retrieval and chain-of-thought: accuracy 1.0\n",
"!oaieval retrieval/presidents/cot/gpt-3.5-turbo born-first --max_samples 10 --registry_path ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
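
For intuition about the ranking step in retrieval.py, here is a small self-contained sketch (not part of the commit) of the same normalize-then-dot-product cosine-similarity ranking, run on toy 2-D vectors:

import numpy as np

# Toy 2-D "embeddings": row 0 is nearly parallel to the query,
# row 2 is orthogonal to it.
embs = np.array([[1.0, 0.1], [0.5, 0.5], [0.0, 1.0]])
query = np.array([1.0, 0.0])

norm_query = query / np.linalg.norm(query)
norm_embs = embs / np.linalg.norm(embs, axis=1)[:, np.newaxis]
cosine_similarities = norm_embs @ norm_query

k = 2
top_k = np.argsort(cosine_similarities)[-k:][::-1]
print(cosine_similarities)  # ~[0.995, 0.707, 0.0]
print(top_k)                # [0, 1]: the two rows closest to the query, best first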
