From c0dbe9d025bf2cf230a1c92ee30df4a6c8d9bf47 Mon Sep 17 00:00:00 2001 From: Andrew Kondrich <120423412+andrew-openai@users.noreply.github.com> Date: Wed, 12 Apr 2023 16:16:40 -0700 Subject: [PATCH] [Evals] Add retrieval completion fn and example (#656) --- evals/completion_fns/retrieval.py | 113 ++++++++++++++++++++++++++ examples/retrieval-completionfn.ipynb | 113 ++++++++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 evals/completion_fns/retrieval.py create mode 100644 examples/retrieval-completionfn.ipynb diff --git a/evals/completion_fns/retrieval.py b/evals/completion_fns/retrieval.py new file mode 100644 index 0000000000..40edc56464 --- /dev/null +++ b/evals/completion_fns/retrieval.py @@ -0,0 +1,113 @@ +""" +Extending Completion Functions with Embeddings-based retrieval from a fetched dataset +""" +from ast import literal_eval +from typing import Any, Optional, Union + +import numpy as np +import openai +import pandas as pd + +from evals.api import CompletionFn, CompletionResult +from evals.prompt.base import ChatCompletionPrompt, CompletionPrompt +from evals.record import record_sampling +from evals.registry import Registry + + +def load_embeddings(embeddings_and_text_path: str): + df = pd.read_csv(embeddings_and_text_path, converters={"embedding": literal_eval}) + assert ( + "text" in df.columns and "embedding" in df.columns + ), "The embeddings file must have columns named 'text' and 'embedding'" + return df + + +def find_top_k_closest_embeddings(embedded_prompt: list[float], embs: list[list[float]], k: int): + # Normalize the embeddings + norm_embedded_prompt = embedded_prompt / np.linalg.norm(embedded_prompt) + norm_embs = embs / np.linalg.norm(embs, axis=1)[:, np.newaxis] + + # Calculate cosine similarity + cosine_similarities = np.dot(norm_embs, norm_embedded_prompt) + + # Get the indices of the top k closest embeddings + top_k_indices = np.argsort(cosine_similarities)[-k:] + + return top_k_indices[::-1] + + +DEFAULT_RETRIEVAL_TEMPLATE = "Use the provided context to answer the question. " + + +class RetrievalCompletionResult(CompletionResult): + def __init__(self, response: str) -> None: + self.response = response + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + +class RetrievalCompletionFn(CompletionFn): + """ + This Completion Function uses embeddings to retrieve the top k relevant docs from a dataset to the prompt, then adds them to the context before calling the completion. + """ + + def __init__( + self, + completion_fn: str, + embeddings_and_text_path: str, + retrieval_template: str = DEFAULT_RETRIEVAL_TEMPLATE, + k: int = 4, + embedding_model: str = "text-embedding-ada-002", + registry: Optional[Registry] = None, + registry_path: Optional[str] = None, + **_kwargs: Any + ) -> None: + """ + Args: + retrieval_template: The template to use for the retrieval. The task prompt will be added to the end of this template. + k: The number of docs to retrieve from the dataset. + completion_fn: The completion function to use for the retrieval. + embeddings_and_text_path: The path to a CSV containing "text" and "embedding" columns. + registry: Upstream callers may pass in a registry to use. + registry_path: The path to a registry file to add to default registry. + _kwargs: Additional arguments to pass to the completion function instantiation. + """ + registry = Registry() if not registry else registry + if registry_path: + registry.add_registry_paths(registry_path) + + self.embeddings_df = load_embeddings(embeddings_and_text_path) + + self.embedding_model = embedding_model + self.k = k + + self.retrieval_template = retrieval_template + self.completion_fn_instance = registry.make_completion_fn(completion_fn) + + def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> RetrievalCompletionResult: + """ + Args: + prompt: The prompt to complete, in either text string or Chat format. + kwargs: Additional arguments to pass to the completion function call method. + """ + # Embed the prompt + embedded_prompt = openai.Embedding.create( + model=self.embedding_model, input=CompletionPrompt(prompt).to_formatted_prompt() + )["data"][0]["embedding"] + + embs = self.embeddings_df["embedding"].to_list() + + # Compute the cosine similarity between the prompt and the embeddings + topk = " ".join( + self.embeddings_df.iloc[ + find_top_k_closest_embeddings(embedded_prompt, embs, k=self.k) + ].text.values + ) + + prompt = ChatCompletionPrompt(prompt).to_formatted_prompt() + retrieval_prompt = [{"role": "system", "content": self.retrieval_template + topk}] + prompt + + answer = self.completion_fn_instance(prompt=retrieval_prompt, **kwargs).get_completions()[0] + record_sampling(prompt=retrieval_prompt, sampled=answer) + return RetrievalCompletionResult(answer) diff --git a/examples/retrieval-completionfn.ipynb b/examples/retrieval-completionfn.ipynb new file mode 100644 index 0000000000..2bdca581b5 --- /dev/null +++ b/examples/retrieval-completionfn.ipynb @@ -0,0 +1,113 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "We show here how to use the retrieval completion function to add context from documents to any OpenAI Evals task\n", + "The toy example here will be to augment our Born-First task with a dataset of presidential birthdays\n", + "\"\"\"\n", + "\n", + "# Download the dataset manually, or use curl\n", + "!curl -O https://people.math.sc.edu/Burkardt/datasets/presidents/president_birthdays.csv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import openai\n", + "import pandas as pd\n", + "\n", + "df = pd.read_csv(\"president_birthdays.csv\").rename(columns={\" \\\"Name\\\"\": \"Name\", \" \\\"Month\\\"\": \"Month\", \" \\\"Day\\\"\": \"Day\", \" \\\"Year\\\"\": \"Year\"}).set_index(\"Index\")\n", + "df[\"text\"] = df.apply(lambda r: f\"{r['Name']} was born on {r['Month']}/{r['Day']}/{r['Year']}\", axis=1)\n", + "display(df.head())\n", + "\n", + "def embed(text):\n", + " return openai.Embedding.create(\n", + " model=\"text-embedding-ada-002\",\n", + " input=text\n", + " )[\"data\"][0][\"embedding\"]\n", + "\n", + "df[\"embedding\"] = df['text'].apply(embed)\n", + "df[[\"text\", \"embedding\"]].to_csv(\"presidents_embeddings.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "We create a registry entry here in code. Notice we set number of retrieved documents k=2.\n", + "\"\"\"\n", + "\n", + "registry_yaml = f\"\"\"\n", + "retrieval/presidents/gpt-3.5-turbo:\n", + " class: evals.completion_fns.retrieval:RetrievalCompletionFn\n", + " args:\n", + " completion_fn: gpt-3.5-turbo\n", + " embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n", + " k: 2\n", + "\n", + "retrieval/presidents/cot/gpt-3.5-turbo:\n", + " class: evals.completion_fns.retrieval:RetrievalCompletionFn\n", + " args:\n", + " completion_fn: cot/gpt-3.5-turbo\n", + " embeddings_and_text_path: {os.path.abspath('presidents_embeddings.csv')}\n", + " k: 2\n", + "\"\"\".strip()\n", + "\n", + "# Replace with path to your registry\n", + "os.makedirs(\"completion_fns\", exist_ok=True)\n", + "with open(\"completion_fns/retrieval.yaml\", \"w\") as f:\n", + " f.write(registry_yaml)\n", + "\n", + "# GPT-3.5-turbo base: accuracy 0.7\n", + "!oaieval gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n", + "\n", + "# GPT-3.5-turbo with retrieval: accuracy 0.9 -> The failure mode here is the retrieved president is incorrect: Andrew Johnson vs Andrew Jackson\n", + "!oaieval retrieval/presidents/gpt-3.5-turbo born-first --max_samples 10 --registry_path .\n", + "\n", + "# GPT-3.5-turbo with retrieval and chain-of-thought: accuracy 1.0\n", + "!oaieval retrieval/presidents/cot/gpt-3.5-turbo born-first --max_samples 10 --registry_path ." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}