diff --git a/.github/workflows/test_eval.yaml b/.github/workflows/test_eval.yaml index 947d1ddabd..50c2526fcf 100644 --- a/.github/workflows/test_eval.yaml +++ b/.github/workflows/test_eval.yaml @@ -15,7 +15,7 @@ jobs: with: fetch-depth: 0 lfs: true - + - name: Install Git LFS run: | sudo apt-get install git-lfs @@ -47,8 +47,7 @@ jobs: echo "Processing $file" first_key=$(python .github/workflows/parse_yaml.py $file) echo "Eval Name: $first_key" - oaieval dummy-chat $first_key --max_samples 10 - oaieval dummy-completion $first_key --max_samples 10 + oaieval dummy $first_key --max_samples 10 done else echo "No new YAML files found in evals/registry/evals" diff --git a/MANIFEST.in b/MANIFEST.in index 188c3de7d1..a264d0a61e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ recursive-include evals *.py recursive-include evals *.yaml recursive-include evals *.sql -recursive-include evals *.jsonl +recursive-include evals/registry/data *.jsonl diff --git a/README.md b/README.md index 6044c94e66..4e08465efc 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,23 @@ # Evals -Evals is a framework for evaluating OpenAI models and an open-source registry of benchmarks. - -You can use Evals to create and run evaluations that: -- use datasets to generate prompts, -- measure the quality of completions provided by an OpenAI model, and -- compare performance across different datasets and models. - -With Evals, we aim to make it as simple as possible to build an eval while writing as little code as possible. To get started, we recommend that you follow these steps **in order**: -1. Read through this doc and follow the [setup instructions below](README.md#Setup). -2. Learn how to run existing evals: [run-evals.md](docs/run-evals.md). -3. Familiarize yourself with the existing eval templates: [eval-templates.md](docs/eval-templates.md). -4. Walk through the process for building an eval: [build-eval.md](docs/build-eval.md) -5. See an example of implementing custom eval logic: [custom-eval.md](docs/custom-eval.md). +Evals is a framework for evaluating LLMs (large language models) or systems built using LLMs as components. It also includes an open-source registry of challenging evals. + +We now support evaluating the behavior of any system including prompt chains or tool-using agents, via the [Completion Function Protocol](docs/completion-fns.md). + +With Evals, we aim to make it as simple as possible to build an eval while writing as little code as possible. An "eval" is a task used to evaluate the quality of a system's behavior. To get started, we recommend that you follow these steps: + +To get set up with evals, follow the [setup instructions below](README.md#Setup). + +#### Running evals +- Learn how to run existing evals: [run-evals.md](docs/run-evals.md). +- Familiarize yourself with the existing eval templates: [eval-templates.md](docs/eval-templates.md). + +#### Writing evals +- Walk through the process for building an eval: [build-eval.md](docs/build-eval.md) +- See an example of implementing custom eval logic: [custom-eval.md](docs/custom-eval.md). + +#### Writing CompletionFns +- Write your own completion functions: [completion-fns.md](docs/completion-fns.md) If you think you have an interesting eval, please open a PR with your contribution. OpenAI staff actively review these evals when considering improvements to upcoming models. 
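For orientation before the file-by-file changes below: the Completion Function Protocol that this PR introduces boils down to a callable that accepts a prompt (a plain string or a list of chat-format messages) and returns an object exposing `get_completions() -> list[str]`. The following sketch is illustrative only, assuming the `CompletionFn`/`CompletionResult` interfaces added in `evals/api.py` later in this diff; the `EchoCompletionFn` name and behavior are invented for the example and are not part of the PR.

```python
from typing import Union

from evals.api import CompletionFn, CompletionResult


class EchoCompletionResult(CompletionResult):
    def __init__(self, text: str) -> None:
        self.text = text

    def get_completions(self) -> list[str]:
        # One entry per completion; most implementations return a single item.
        return [self.text]


class EchoCompletionFn(CompletionFn):
    def __call__(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> CompletionResult:
        # Accept either a raw text prompt or OpenAI chat-format messages.
        if isinstance(prompt, list):
            prompt = prompt[-1].get("content", "")
        return EchoCompletionResult(prompt)
```

Registered under a key in a YAML file (as `docs/completion-fns.md` below describes), such a class can be passed to `oaieval` in place of a model name.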
diff --git a/docs/build-eval.md b/docs/build-eval.md index 240237e185..000254ee27 100644 --- a/docs/build-eval.md +++ b/docs/build-eval.md @@ -60,7 +60,7 @@ In general, running the same eval name against the same model should always give ## Running the eval -You can now run your eval on your data from the CLI with your choice of model: +You can now run your eval on your data from the CLI with your choice of model or completion function: ``` oaieval gpt-3.5-turbo ``` diff --git a/docs/completion-fn-protocol.md b/docs/completion-fn-protocol.md new file mode 100644 index 0000000000..2f9f6c0399 --- /dev/null +++ b/docs/completion-fn-protocol.md @@ -0,0 +1,41 @@ +### The Completion Function Protocol + +Here are the interfaces needed to implement the completion function protocol. Any implementation of this interface can be used inside `oaieval`. + +Reference implementations: +- [OpenAICompletionFn](../evals/completion_fns/openai.py) +- [LangChainLLMCompletionFn](../evals/completion_fns/langchain_llm.py) + +#### CompletionFn +Completion functions should implement the `CompletionFn` interface: +```python +class CompletionFn(Protocol): + def __call__( + self, + prompt: Union[str, list[dict[str, str]]], + **kwargs, + ) -> CompletionResult: +``` + +We take a `prompt` representing a single sample from an eval. These prompts can be represented as either a text string or a list of messages in [OpenAI Chat format](https://platform.openai.com/docs/guides/chat/introduction). To work with the existing evals, Completion Function implementations would need to handle both types of inputs, but we provide helper functionality to convert Chat formatted messages into a text string if that is the preferred input for your program: +```python +from evals.prompt.base import CompletionPrompt + +# chat_prompt: list[dict[str, str]] -> text_prompt: str +text_prompt = CompletionPrompt(chat_prompt).to_formatted_prompt() +``` + +#### CompletionResult +The completion function should return an object implementing the `CompletionResult` interface: +```python +class CompletionResult(ABC): + @abstractmethod + def get_completions(self) -> list[str]: + pass +``` +The `get_completions` method returns a list of string completions. Each element should be considered a unique completion (in most cases this will be a list of length 1). + +#### Using your CompletionFn +This is all that's needed to implement a Completion function that works with our existing Evals, allowing you to more easily evaluate your end-to-end logic on tasks. + +See [completion-fns.md](completion-fns.md) to see how to register and use your completion function with `oaieval`. diff --git a/docs/completion-fns.md b/docs/completion-fns.md new file mode 100644 index 0000000000..4d93d2be6a --- /dev/null +++ b/docs/completion-fns.md @@ -0,0 +1,49 @@ +# Completion Functions + +## What are completion functions +In [run-evals.md](run-evals.md), we learned how to make calls to `oaieval` to run an eval against a completion function. Completion Functions are generalizations of model completions, where a "completion" is some text output that would be our answer to the prompt. For example, if "Who played the girl elf in the hobbit?" is our prompt, the correct completion is "Evangeline Lilly". While we can just test a model directly to see if it generates "Evangeline Lilly", we can imagine doing numerous other operations under the hood to improve our ability to answer this question, like giving the model access to a browser to look up the answer before responding. 
Making it easy to implement these kinds of under-the-hood operations before responding is the motivation behind building Completion Functions. + +## How to implement completion functions +A completion function needs to implement some interfaces that make it usable within Evals. At its core, it is just standardizing inputs to be a text string or [Chat conversation](https://platform.openai.com/docs/guides/chat), and the output to be a list of text strings. Implementing this interface will allow you to run your Completion Function against any eval in Evals. + +The exact interfaces needed are described in detail in [completion-fn-protocol.md](completion-fn-protocol.md). + +We include some example implementations inside `evals/completion_fns`. For example, the [`LangChainLLMCompletionFn`](../evals/completion_fns/langchain_llm.py) implements a way to generate completions from [LangChain LLMs](https://python.langchain.com/en/latest/modules/models/llms/getting_started.html). We can then use these completion functions with `oaieval`: +``` +oaieval langchain/llm/flan-t5-xl test-match +``` + +## Registering Completion Functions +Once you have written a completion function, we need to make the class visible to the `oaieval` CLI. Similar to how we register our evals, we also register Completion Functions inside `evals/registry/completion_fns` as `yaml` files. Here is the registration for our langchain LLM completion function: +```yaml +langchain/llm/flan-t5-xl: + class: evals.completion_fns.langchain_llm:LangChainLLMCompletionFn + args: + llm: HuggingFaceHub + llm_kwargs: + repo_id: google/flan-t5-xl +``` +Here is how it breaks down: +`langchain/llm/flan-t5-xl`: This is the top level key that will be used to access this completion function with `oaieval`. +`class`: This is the path to your implementation of the completion function protocol. This class needs to be importable within your python environment. +`args`: These are arguments that are passed to your completion function when it is instantiated. + + +### Developing Completion Functions outside of Evals +It is possible to register CompletionFunctions without directly modifying the registry or code inside `Evals` by using the `--registry_path` argument. As an example, let's say I want to use `MyCompletionFn` located inside `~/my_project/`: +``` +my_project +├── my_completion_fn.py +└── completion_fns + └── my_completion_fn.yaml +``` + +If `my_project` is importable within the python environment (accessible via PYTHONPATH), we can structure `my_completion_fn.yaml` as: +``` +my_completion_fn: + class: my_project.my_completion_fn:MyCompletionFn +``` +Then, we can make calls to `oaieval` using: +``` +oaieval my_completion_fn test-match --registry_path ~/my_project +``` diff --git a/docs/run-evals.md b/docs/run-evals.md index 958b466b89..0cc1e3943f 100644 --- a/docs/run-evals.md +++ b/docs/run-evals.md @@ -4,12 +4,15 @@ We provide two command line interfaces (CLIs): `oaieval` for running a single ev ## Running an eval -When using the `oaieval` command, you will need to provide both the model you wish to evaluate as well as the eval to run. E.g., +When using the `oaieval` command, you will need to provide the completion function you wish to evaluate as well as the eval to run. E.g., ```sh oaieval gpt-3.5-turbo test-match ``` +The valid eval names are specified in the YAML files under `evals/registry/evals` and their corresponding implementations can be found in `evals/elsuite`.
-In this example, `gpt-3.5-turbo` is the model to evaluate, and `test-match` is the eval to run. The valid model names are those which you have access to via the API. The valid eval names are specified in the YAML files under `evals/registry/evals`, and their corresponding implementations can be found in `evals/elsuite`. +In this example, `gpt-3.5-turbo` is an OpenAI model that we dynamically instantiate as a completion function using `OpenAIChatCompletionFn(model=gpt-3.5-turbo)`. Any implementation of the `CompletionFn` protocol can be run against `oaieval`. By default, we support calling `oaieval` with any model available on the OpenAI API or with CompletionFunctions available in [`evals/registry/completion_fns`](../evals/registry/completion_fns/). We are always interested in adding more completion functions and we encourage you to implement your own to reflect specific use cases. + +More details on `CompletionFn` can be found here: [`completion-fns.md`](completion-fns.md) These CLIs can accept various flags to modify their default behavior. For example: - If you wish to log to a Snowflake database (which you have already set up as described in the [README](../README.md)), add `--no-local-run`. diff --git a/evals/__init__.py b/evals/__init__.py index f21e608720..e57d24a9a8 100644 --- a/evals/__init__.py +++ b/evals/__init__.py @@ -1,4 +1,8 @@ -from .api import check_sampled_text, completion_query, sample_freeform -from .base import ModelSpec, ModelSpecs +from .api import CompletionFn, CompletionResult, DummyCompletionFn, record_and_check_match +from .completion_fns.openai import ( + OpenAIChatCompletionFn, + OpenAICompletionFn, + OpenAICompletionResult, +) from .data import get_csv, get_json, get_jsonl, get_jsonls, get_lines, iter_jsonls from .eval import Eval diff --git a/evals/api.py b/evals/api.py index 51dc2291f3..71e16e3ab3 100644 --- a/evals/api.py +++ b/evals/api.py @@ -4,131 +4,73 @@ """ import logging -from typing import Callable, Optional, Union +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable -from evals.base import ModelSpec -from evals.prompt.base import ( - ChatCompletionPrompt, - CompletionPrompt, - OpenAICreateChatPrompt, - OpenAICreatePrompt, - Prompt, -) -from evals.record import record_match, record_sampling -from evals.utils.api_utils import ( - openai_chat_completion_create_retrying, - openai_completion_create_retrying, -) +from evals.prompt.base import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt +from evals.record import record_match logger = logging.getLogger(__name__) -def completion_query( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - **kwargs, -) -> tuple[dict, Union[OpenAICreatePrompt, OpenAICreateChatPrompt], dict]: - """ - Query the API for a completion. - - ARGS - ==== - `model_spec`: `ModelSpec` containing model details to use in the query. - This should be the dict returned by `registry.get_model()`. - If `model_spec` is not provided, we use the default model that was - intialized at the beginning of the run. - `prompt`: Either a `Prompt` object or a raw prompt that will get wrapped in - the approriate `Prompt` class. - `kwargs`: Other arguments passed to the API. - - RETURNS - ======= - The result of the API call. - The prompt that was fed into the API call as a str. - A dict containing metadata about the query.
- """ - if not isinstance(prompt, Prompt): - assert ( - isinstance(prompt, str) - or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt)) - or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt)) - or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt)) - ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]" +class CompletionResult(ABC): + @abstractmethod + def get_completions(self) -> list[str]: + pass - if model_spec.is_chat: - prompt = ChatCompletionPrompt( - raw_prompt=prompt, - ) - else: - prompt = CompletionPrompt( - raw_prompt=prompt, - ) - openai_create_prompt: Union[ - OpenAICreatePrompt, OpenAICreateChatPrompt - ] = prompt.to_openai_create_prompt() +@runtime_checkable +class CompletionFn(Protocol): + def __call__( + self, + prompt: Union[str, OpenAICreateChatPrompt], + **kwargs, + ) -> CompletionResult: + """ + ARGS + ==== + `prompt`: Either a `Prompt` object or a raw prompt that will get wrapped in + the approriate `Prompt` class. + `kwargs`: Other arguments passed to the API. - extra_args = { - key: model_spec.extra_options.get(key, kwargs.get(key)) - for key in set(kwargs) | set(model_spec.extra_options) - } + RETURNS + ======= + The result of the API call. + The prompt that was fed into the API call as a str. + """ - if model_spec.is_chat: - result = openai_chat_completion_create_retrying( - model=model_spec.model, - engine=model_spec.engine, - api_base=model_spec.api_base, - api_key=model_spec.api_key, - messages=openai_create_prompt, - **extra_args, - ) - else: - result = openai_completion_create_retrying( - model=model_spec.model, - engine=model_spec.engine, - api_base=model_spec.api_base, - api_key=model_spec.api_key, - prompt=openai_create_prompt, - **extra_args, - ) - metadata = {} - if result: - metadata["completion_id"] = result.get("id", None) - metadata["model"] = result.get("model", None) +class DummyCompletionResult(CompletionResult): + def get_completions(self) -> list[str]: + return ["This is a dummy response."] - if model_spec.is_chat: - for choice in result["choices"]: - choice["text"] = choice["message"]["content"] - return result, openai_create_prompt, metadata +class DummyCompletionFn(CompletionFn): + def __call__( + self, prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], **kwargs + ) -> CompletionResult: + return DummyCompletionResult() -def check_sampled_text( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], +def record_and_check_match( + prompt: Any, + sampled: str, expected: Union[str, list[str], tuple[str]], - *, - options: Optional[list[str]] = None, separator: Callable[[str], bool] = None, -) -> Optional[str]: + options: Optional[list[str]] = None, +): """ - Generates a completion using the prompt, checks whether the completion is - one of the expected completions, and then records the result. + Records and checks if a sampled response from a CompletionFn matches the expected result. - ARGS - ==== - `model_spec`: See `completion_query`. - `prompt`: See `completion_query`. - `options`: The list of canonical options, defaults to `expected` if None. - The completion will be converted to one of these options. - `expected`: The desired completion or the list of desired completions. - `separator`: A callable which check the character sampled after the option - to see if it is a valid separator. + Args: + prompt: The input prompt. 
+ sampled: The sampled response from the model. + expected: The expected response or list of responses. + separator: Optional function to check if a character is a separator. + options: Optional list of options to match against the sampled response. - RETURNS - ======= - The option that was picked, i.e., matched the completion, or None. + Returns: + The matched option or None if no match found. """ if isinstance(expected, tuple): expected = list(expected) @@ -137,15 +79,6 @@ def check_sampled_text( if options is None: options = expected - result, actual_prompt, metadata = completion_query( - prompt=prompt, - temperature=0.0, - model_spec=model_spec, - ) - choice = result["choices"][0] - - sampled = choice["text"].strip() if model_spec.strip_completion else choice["text"] - picked = None for option in options: if not sampled.startswith(option): @@ -160,7 +93,7 @@ def check_sampled_text( break result = { - "prompt": actual_prompt, + "prompt": prompt, "sampled": sampled, "options": options, "picked": picked, @@ -168,90 +101,5 @@ def check_sampled_text( match = picked in expected result["expected"] = expected result["match"] = match - result["metadata"] = metadata - record_sampling(**result) - record_match(match, expected=expected, picked=picked, sampled=sampled) + record_match(match, expected=expected, picked=picked, sampled=sampled, options=options) return picked - - -def sample_freeform( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - *, - temperature: float = 1.0, - top_p: float = 0.9, - max_tokens: int = 512, - stop: Optional[str] = None, - n_samples: int = None, - return_logprobs: bool = False, - **kwargs, -) -> Union[str, list[str], dict]: - """ - Samples a freeform response from the specified model, records the sampling, - and returns the sampled text. - - ARGS - ==== - `model_spec`: See `completion_query`. - `prompt`: See `completion_query`. - `temperature`: Passed to `openai.Completion.create`. - `top_p`: Passed to `openai.Completion.create`. - `max_tokens`: Passed to `openai.Completion.create`. - `stop`: Passed to `openai.Completion.create`. - `n_samples`: The number of samples to generate (1 if None). - `return_logprobs`: If True, returns the tokens and corresponding logprobs - in addition to the sampled text. - `kwargs`: See `completion_query`. - - RETURNS - ======= - If `return_logprobs` is True, returns a dict with the sampled text, tokens, - and corresponding logprobs. If `n_samples` is None, the outer list is - removed from all values. - Otherwise, returns the sampled text, or a list of sampled texts if - `n_samples` is not None. 
- """ - response, actual_prompt, metadata = completion_query( - prompt=prompt, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - stop=stop, - n=(1 if n_samples is None else n_samples), - model_spec=model_spec, - headers={}, - **kwargs, - ) - sampled = [choice["text"] for choice in response["choices"]] - if n_samples is None: - sampled = sampled[0] - record_sampling(prompt=actual_prompt, sampled=sampled, metadata=metadata) - - if return_logprobs: - assert not model_spec.is_chat, "logprobs only works for non-chat models" - assert not kwargs.get("logprobs") is None - - def _maybe_tokens(logprobs: Optional[dict]) -> Optional[list[str]]: - return logprobs["tokens"] if logprobs is not None else None - - def _maybe_logprobs(logprobs: Optional[dict]) -> Optional[list[float]]: - return logprobs["token_logprobs"] if logprobs is not None else None - - def _maybe_top_logprobs(logprobs: Optional[dict]) -> Optional[list[dict[str, float]]]: - return [dict(x) for x in logprobs["top_logprobs"]] if logprobs is not None else None - - tokens = [_maybe_tokens(choice["logprobs"]) for choice in response["choices"]] - logprobs = [_maybe_logprobs(choice["logprobs"]) for choice in response["choices"]] - top_logprobs = [_maybe_top_logprobs(choice["logprobs"]) for choice in response["choices"]] - if n_samples is None: - tokens = tokens[0] - logprobs = logprobs[0] - top_logprobs = top_logprobs[0] - return { - "text": sampled, - "tokens": tokens, - "logprobs": logprobs, - "top_logprobs": top_logprobs, - } - - return sampled diff --git a/evals/base.py b/evals/base.py index ae55185c6f..1a5c34cfb1 100644 --- a/evals/base.py +++ b/evals/base.py @@ -5,45 +5,23 @@ import base64 import datetime import os -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence if TYPE_CHECKING: from dataclasses import dataclass else: from pydantic.dataclasses import dataclass - @dataclass -class ModelSpec: +class CompletionFnSpec: """ - Specification for a model. + Specification for a CompletionFn. 
""" - - name: str - model: Optional[str] = None - engine: Optional[str] = None - api_base: Optional[str] = None - - is_chat: bool = False - - encoding: Optional[str] = None - organization: Optional[str] = None - api_key: Optional[str] = None - extra_options: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - headers: Optional[Mapping[str, Any]] = None - strip_completion: bool = True - n_ctx: Optional[int] = None - format: Optional[str] = None + cls: str + args: Optional[Dict[str, Any]] = None key: Optional[str] = None group: Optional[str] = None - def __post_init__(self): - if self.extra_options is None: - self.extra_options = {} - if self.headers is None: - self.headers = {} - @dataclass class BaseEvalSpec: @@ -89,52 +67,9 @@ class EvalSetSpec: group: Optional[str] = None -@dataclass -class ModelSpecs: - completions_: Optional[Sequence[ModelSpec]] = None - embedding_: Optional[ModelSpec] = None - ranking_: Optional[ModelSpec] = None - - @property - def embedding(self) -> ModelSpec: - if self.embedding_ is None: - raise ValueError("Embedding model was not specified") - return self.embedding_ - - @property - def ranking(self) -> ModelSpec: - if self.ranking_ is None: - raise ValueError("Ranking model was not specified") - return self.ranking_ - - @property - def completion(self) -> ModelSpec: - if self.completions_ is None: - raise ValueError("Completion model was not specified") - return self.completions_[0] - - @property - def completions(self) -> Sequence[ModelSpec]: - if self.completions_ is None: - raise ValueError("Completion model was not specified") - return self.completions_ - - @property - def names(self) -> dict[str, Sequence[str]]: - dict = {} - if self.completions_ is not None: - dict["completions"] = [model.name for model in self.completions_] - if self.embedding_ is not None: - dict["embedding"] = [self.embedding_.name] - if self.ranking_ is not None: - dict["ranking"] = [self.ranking_.name] - return dict - - @dataclass class RunSpec: - model_name: str - model_names: dict[str, Sequence[str]] + completion_fns: list[str] eval_name: str base_eval: str split: str diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py index 8c47217f52..a72b194ae0 100644 --- a/evals/cli/oaieval.py +++ b/evals/cli/oaieval.py @@ -5,7 +5,6 @@ import logging import shlex import sys -from functools import cached_property from typing import Any, Mapping, Optional import openai @@ -14,7 +13,6 @@ import evals.api import evals.base import evals.record -from evals.base import ModelSpec, ModelSpecs from evals.registry import Registry logger = logging.getLogger(__name__) @@ -26,12 +24,13 @@ def _purple(str): def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run evals through the API") - parser.add_argument("model", type=str, help="Name of a completion model.") + parser.add_argument( + "completion_fn", + type=str, + help="One or more CompletionFn URLs, separated by commas (,). A CompletionFn can either be the name of a model available in the OpenAI API or a key in the registry (see evals/registry/completion_fns).", + ) parser.add_argument("eval", type=str, help="Name of an eval. 
See registry.") - parser.add_argument("--embedding_model", type=str, default="") - parser.add_argument("--ranking_model", type=str, default="") parser.add_argument("--extra_eval_params", type=str, default="") - parser.add_argument("--modelspec_extra_options", type=str, default="") parser.add_argument("--max_samples", type=int, default=None) parser.add_argument("--cache", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--visible", action=argparse.BooleanOptionalAction, default=None) @@ -41,6 +40,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument( "--log_to_file", type=str, default=None, help="Log to a file instead of stdout" ) + parser.add_argument( + "--registry_path", type=str, default=None, action="append", help="Path to the registry" + ) parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--local-run", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--dry-run", action=argparse.BooleanOptionalAction, default=False) @@ -48,102 +50,7 @@ def get_parser() -> argparse.ArgumentParser: return parser -def parse_extra_eval_params(param_str: Optional[str]) -> Mapping[str, Any]: - """Parse a string of the form "key1=value1,key2=value2" into a dict.""" - if not param_str: - return {} - - def to_number(x): - try: - return int(x) - except: - pass - try: - return float(x) - except: - pass - return x - - str_dict = dict(kv.split("=") for kv in param_str.split(",")) - return {k: to_number(v) for k, v in str_dict.items()} - - -def n_ctx_from_model_name(model_name: str) -> Optional[int]: - """Returns n_ctx for a given API model name. Model list last updated 2023-03-14.""" - # note that for most models, the max tokens is n_ctx + 1 - DICT_OF_N_CTX_BY_MODEL_NAME_PREFIX: dict[str, int] = { - "dummy-": 2048, - "gpt-3.5-turbo-": 4096, - "gpt-4-": 8192, - "gpt-4-32k-": 32768, - } - DICT_OF_N_CTX_BY_MODEL_NAME: dict[str, int] = { - "ada": 2048, - "text-ada-001": 2048, - "babbage": 2048, - "text-babbage-001": 2048, - "curie": 2048, - "text-curie-001": 2048, - "davinci": 2048, - "text-davinci-001": 2048, - "code-davinci-002": 8000, - "text-davinci-002": 4096, - "text-davinci-003": 4096, - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-0301": 4096, - "gpt-4": 8192, - "gpt-4-0314": 8192, - "gpt-4-32k": 32768, - "gpt-4-32k-0314": 32768, - } - # first, look for a prefix match - for model_prefix, n_ctx in DICT_OF_N_CTX_BY_MODEL_NAME_PREFIX.items(): - if model_name.startswith(model_prefix): - return n_ctx - # otherwise, look for an exact match and return None if not found - return DICT_OF_N_CTX_BY_MODEL_NAME.get(model_name, None) - - -class ModelResolver: - # This is a temporary method to identify which models are chat models. - # Eventually, the OpenAI API should expose this information directly. 
- CHAT_MODELS = { - "gpt-3.5-turbo", - "gpt-3.5-turbo-0301", - "gpt-4", - "gpt-4-0314", - "gpt-4-32k", - "gpt-4-32k-0314", - "dummy-chat", - } - - DUMMY_MODELS = { - "dummy-chat", - "dummy-completion", - } - - def resolve(self, name: str) -> ModelSpec: - if name in self.DUMMY_MODELS: - result = ModelSpec(name=name, model=name, is_chat=(name in self.CHAT_MODELS)) - return result - - if name in self.api_model_ids: - result = ModelSpec( - name=name, - model=name, - is_chat=(name in self.CHAT_MODELS), - n_ctx=n_ctx_from_model_name(name), - ) - return result - - raise ValueError(f"Couldn't find model: {name}") - - @cached_property - def api_model_ids(self): - return [m["id"] for m in openai.Model.list()["data"]] - - -def run(args, model_resolver: ModelResolver, registry: Optional[Registry] = None): +def run(args, registry: Optional[Registry] = None): if args.debug: logging.getLogger().setLevel(logging.DEBUG) @@ -153,27 +60,19 @@ def run(args, model_resolver: ModelResolver, registry: Optional[Registry] = None evals.eval.set_max_samples(args.max_samples) registry = registry or Registry() + if args.registry_path: + registry.add_registry_paths(args.registry_path) + eval_spec = registry.get_eval(args.eval) assert ( eval_spec is not None ), f"Eval {args.eval} not found. Available: {list(sorted(registry._evals.keys()))}" - def get_model(name: str) -> ModelSpec: - return model_resolver.resolve(name) - - completion_model_specs = [get_model(model) for model in args.model.split(",")] - - for spec in completion_model_specs: - spec.extra_options = parse_extra_eval_params(args.modelspec_extra_options) - - model_specs = ModelSpecs( - completions_=completion_model_specs, - embedding_=get_model(args.embedding_model) if args.embedding_model else None, - ranking_=get_model(args.ranking_model) if args.ranking_model else None, - ) + completion_fns = args.completion_fn.split(",") + completion_fn_instances = [registry.make_completion_fn(url) for url in completion_fns] run_config = { - "model_specs": model_specs, + "completion_fns": completion_fns, "eval_spec": eval_spec, "seed": args.seed, "max_samples": args.max_samples, @@ -183,11 +82,9 @@ def get_model(name: str) -> ModelSpec: }, } - model_name = model_specs.completions_[0].name if len(model_specs.completions_) > 0 else "n/a" eval_name = eval_spec.key run_spec = evals.base.RunSpec( - model_name=model_name, - model_names=model_specs.names, + completion_fns=completion_fns, eval_name=eval_name, base_eval=eval_name.split(".")[0], split=eval_name.split(".")[1], @@ -195,7 +92,7 @@ def get_model(name: str) -> ModelSpec: created_by=args.user, ) if args.record_path is None: - record_path = f"/tmp/evallogs/{run_spec.run_id}_{args.model}_{args.eval}.jsonl" + record_path = f"/tmp/evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl" else: record_path = args.record_path if args.dry_run: @@ -212,11 +109,30 @@ def get_model(name: str) -> ModelSpec: run_url = f"{run_spec.run_id}" logger.info(_purple(f"Run started: {run_url}")) + def parse_extra_eval_params(param_str: Optional[str]) -> Mapping[str, Any]: + """Parse a string of the form "key1=value1,key2=value2" into a dict.""" + if not param_str: + return {} + + def to_number(x): + try: + return int(x) + except: + pass + try: + return float(x) + except: + pass + return x + + str_dict = dict(kv.split("=") for kv in param_str.split(",")) + return {k: to_number(v) for k, v in str_dict.items()} + extra_eval_params = parse_extra_eval_params(args.extra_eval_params) eval_class = registry.get_class(eval_spec) eval = 
eval_class( - model_specs=model_specs, + completion_fns=completion_fn_instances, seed=args.seed, name=eval_name, registry=registry, @@ -245,7 +161,7 @@ def main(): logging.getLogger("openai").setLevel(logging.WARN) if hasattr(openai.error, "set_display_cause"): openai.error.set_display_cause() - run(args, model_resolver=ModelResolver()) + run(args) if __name__ == "__main__": diff --git a/evals/completion_fns/__init__.py b/evals/completion_fns/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/evals/completion_fns/cot.py b/evals/completion_fns/cot.py new file mode 100644 index 0000000000..dbf8cbfcc0 --- /dev/null +++ b/evals/completion_fns/cot.py @@ -0,0 +1,65 @@ +""" +Extending Completion Functions with Chain-of-Thought +""" +from evals.prompt.base import ChatCompletionPrompt +from evals.record import record_sampling +from evals.registry import Registry + +DEFAULT_COT_TEMPLATE = "\nBefore answering, reason in a step-by-step manner as to get the right answer, then conclude with the answer." +DEFAULT_EXTRACT_ANSWER_TEMPLATE = ( + "\nGiven the above reasoning, the answer in the format requested by the question is:" +) + + +class ChainOfThoughtCompletionResult: + def __init__(self, response) -> None: + self.response = response + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + +class ChainOfThoughtCompletionFn: + def __init__( + self, + cot_template: str = DEFAULT_COT_TEMPLATE, + extract_answer_template: str = DEFAULT_EXTRACT_ANSWER_TEMPLATE, + cot_completion_fn: str = None, + extract_completion_fn: str = None, + registry: Registry = None, + registry_path: str = None, + **kwargs + ) -> None: + registry = Registry() if not registry else registry + if registry_path: + registry.add_registry_paths(registry_path) + + if extract_completion_fn is None: + extract_completion_fn = cot_completion_fn + + # This model will use chain of thought to answer the question + self.cot_template = cot_template + self.cot_completion_fn_instance = registry.make_completion_fn(cot_completion_fn) + + # This model will extract the answer from the chain of thought + self.extract_answer_template = extract_answer_template + self.extract_completion_fn_instance = registry.make_completion_fn(extract_completion_fn) + + def __call__(self, prompt, **kwargs) -> ChainOfThoughtCompletionResult: + # Ensure it is in string format + prompt = ChatCompletionPrompt(prompt).to_formatted_prompt() + + cot_prompt = prompt + [{"role": "assistant", "content": self.cot_template}] + answer = self.cot_completion_fn_instance(prompt=cot_prompt, **kwargs).get_completions()[0] + record_sampling(prompt=cot_prompt, sampled=answer) + + extraction_prompt = cot_prompt + [ + {"role": "assistant", "content": answer}, + {"role": "assistant", "content": self.extract_answer_template}, + ] + extracted_answer = self.extract_completion_fn_instance( + prompt=extraction_prompt, **kwargs + ).get_completions()[0] + record_sampling(prompt=extraction_prompt, sampled=extracted_answer) + + return ChainOfThoughtCompletionResult(extracted_answer) diff --git a/evals/completion_fns/langchain_llm.py b/evals/completion_fns/langchain_llm.py new file mode 100644 index 0000000000..1b3f020f70 --- /dev/null +++ b/evals/completion_fns/langchain_llm.py @@ -0,0 +1,33 @@ +import importlib +from typing import Optional + +from langchain.llms import BaseLLM + +from evals.prompt.base import CompletionPrompt +from evals.record import record_sampling + + +class LangChainLLMCompletionResult: + def __init__(self, response) -> None: + 
self.response = response + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + +class LangChainLLMCompletionFn: + def __init__(self, llm: str, llm_kwargs: Optional[dict] = {}, **kwargs) -> None: + # Import and resolve self.llm to an instance of llm argument here, assuming it's always a subclass of BaseLLM + module = importlib.import_module("langchain.llms") + LLMClass = getattr(module, llm) + + if issubclass(LLMClass, BaseLLM): + self.llm = LLMClass(**llm_kwargs) + else: + raise ValueError(f"{llm} is not a subclass of BaseLLM") + + def __call__(self, prompt, **kwargs) -> LangChainLLMCompletionResult: + prompt = CompletionPrompt(prompt).to_formatted_prompt() + response = self.llm(prompt) + record_sampling(prompt=prompt, sampled=response) + return LangChainLLMCompletionResult(response) diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py new file mode 100644 index 0000000000..4e09cca581 --- /dev/null +++ b/evals/completion_fns/openai.py @@ -0,0 +1,137 @@ +from typing import Any, Optional, Union + +from evals.prompt.base import ( + ChatCompletionPrompt, + CompletionPrompt, + OpenAICreateChatPrompt, + OpenAICreatePrompt, + Prompt, +) +from evals.record import record_sampling +from evals.utils.api_utils import ( + openai_chat_completion_create_retrying, + openai_completion_create_retrying, +) + + +class OpenAIBaseCompletionResult: + def __init__(self, raw_data: Any, prompt: Any): + self.raw_data = raw_data + self.prompt = prompt + + def get_completions(self) -> list[str]: + raise NotImplementedError + + +class OpenAIChatCompletionResult(OpenAIBaseCompletionResult): + def get_completions(self) -> list[str]: + completions = [] + if self.raw_data and "choices" in self.raw_data: + for choice in self.raw_data["choices"]: + if "message" in choice: + completions.append(choice["message"]["content"]) + return completions + + +class OpenAICompletionResult(OpenAIBaseCompletionResult): + def get_completions(self) -> list[str]: + completions = [] + if self.raw_data and "choices" in self.raw_data: + for choice in self.raw_data["choices"]: + if "text" in choice: + completions.append(choice["text"]) + return completions + + +class OpenAICompletionFn: + def __init__( + self, + model: Optional[str] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + n_ctx: Optional[int] = None, + extra_options: Optional[dict] = {}, + **kwargs, + ): + self.model = model + self.api_base = api_base + self.api_key = api_key + self.n_ctx = n_ctx + self.extra_options = extra_options + + def __call__( + self, + prompt: Union[str, OpenAICreateChatPrompt], + **kwargs, + ) -> OpenAICompletionResult: + if not isinstance(prompt, Prompt): + assert ( + isinstance(prompt, str) + or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt)) + or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt)) + or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt)) + ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]" + + prompt = CompletionPrompt( + raw_prompt=prompt, + ) + + openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt() + + result = openai_completion_create_retrying( + model=self.model, + api_base=self.api_base, + api_key=self.api_key, + prompt=openai_create_prompt, + **{**kwargs, **self.extra_options}, + ) + result = OpenAICompletionResult(raw_data=result, prompt=openai_create_prompt) + 
record_sampling(prompt=result.prompt, sampled=result.get_completions()) + return result + + +class OpenAIChatCompletionFn: + def __init__( + self, + model: Optional[str] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + n_ctx: Optional[int] = None, + extra_options: Optional[dict] = {}, + **kwargs, + ): + self.model = model + self.api_base = api_base + self.api_key = api_key + self.n_ctx = n_ctx + self.extra_options = extra_options + + def __call__( + self, + prompt: Union[str, OpenAICreateChatPrompt], + **kwargs, + ) -> OpenAIChatCompletionResult: + if not isinstance(prompt, Prompt): + assert ( + isinstance(prompt, str) + or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt)) + or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt)) + or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt)) + ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]" + + prompt = ChatCompletionPrompt( + raw_prompt=prompt, + ) + + openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt() + + result = openai_chat_completion_create_retrying( + model=self.model, + api_base=self.api_base, + api_key=self.api_key, + messages=openai_create_prompt, + **{**kwargs, **self.extra_options}, + ) + result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt) + record_sampling(prompt=result.prompt, sampled=result.get_completions()) + return result diff --git a/evals/elsuite/basic/fuzzy_match.py b/evals/elsuite/basic/fuzzy_match.py index fdee1092a6..e9d7aab655 100644 --- a/evals/elsuite/basic/fuzzy_match.py +++ b/evals/elsuite/basic/fuzzy_match.py @@ -1,5 +1,7 @@ -import evals import numpy as np + +import evals +from evals.api import CompletionFn from evals.elsuite import utils from evals.record import RecorderBase @@ -7,40 +9,40 @@ class FuzzyMatch(evals.Eval): def __init__( self, - model_specs: evals.ModelSpecs, + completion_fns: list[CompletionFn], samples_jsonl: str, *args, max_tokens: int = 500, **kwargs, ): - super().__init__(model_specs, *args, **kwargs) + super().__init__(completion_fns, *args, **kwargs) + assert len(completion_fns) == 1, "FuzzyMatch only supports one completion fn" self.max_tokens = max_tokens self.samples_jsonl = samples_jsonl def eval_sample(self, test_sample, rng): + del rng prompt, correct_answers = test_sample["input"], test_sample["ideal"] - generated_answer = evals.sample_freeform( - self.model_spec, - prompt, - temperature=0.0, + result = self.completion_fn( + prompt=prompt, + temperature=0.0, # Q: why are these hardcoded? 
max_tokens=16, ) - matches = [ - utils.fuzzy_match(generated_answer, correct_answer) - for correct_answer in correct_answers - ] + sampled = result.get_completions()[0] + + matches = [utils.fuzzy_match(sampled, correct_answer) for correct_answer in correct_answers] evals.record.record_match( True in matches, expected=correct_answers, - picked=[generated_answer for i in range(len(correct_answers)) if matches[i]], + picked=[sampled for i in range(len(correct_answers)) if matches[i]], ) evals.record.record_metrics( accuracy=float(True in matches), - f1_score=utils.f1_score(generated_answer, correct_answers), + f1_score=utils.f1_score(sampled, correct_answers), ) def run(self, recorder: RecorderBase): - samples = evals.get_jsonl(self.samples_jsonl) + samples = self.get_samples() self.eval_all_samples(recorder, samples) return { diff --git a/evals/elsuite/basic/includes.py b/evals/elsuite/basic/includes.py index af16600628..487c0e2534 100644 --- a/evals/elsuite/basic/includes.py +++ b/evals/elsuite/basic/includes.py @@ -1,36 +1,38 @@ from typing import Any +import numpy as np + import evals -import evals.elsuite.utils import evals.metrics -import numpy as np +from evals.api import CompletionFn +from evals.elsuite import utils class Includes(evals.Eval): def __init__( self, - model_specs: evals.ModelSpecs, + completion_fns: list[CompletionFn], samples_jsonl: str, *args, - max_tokens: int = 500, **kwargs, ): - super().__init__(model_specs, *args, **kwargs) - self.max_tokens = max_tokens + super().__init__(completion_fns, *args, **kwargs) + assert len(completion_fns) == 1, "Includes only supports one completion fn" self.samples_jsonl = samples_jsonl def eval_sample(self, sample: Any, *_): - sampled = evals.sample_freeform( - self.model_spec, sample["input"], max_tokens=self.max_tokens - ) - includes_answer = any( - [evals.elsuite.utils.get_answer(sampled, ref) for ref in sample["ideal"]] + prompt = sample["input"] + result = self.completion_fn( + prompt=prompt, ) + sampled = result.get_completions()[0] + + includes_answer = any([utils.get_answer(sampled, ref) for ref in sample["ideal"]]) evals.record.record_metrics(accuracy=float(includes_answer)) return includes_answer def run(self, recorder): - samples = evals.get_jsonl(self.samples_jsonl) + samples = self.get_samples() self.eval_all_samples(recorder, samples) events = recorder.get_scores("accuracy") return { diff --git a/evals/elsuite/basic/match.py b/evals/elsuite/basic/match.py index ecd5092ac6..289a1758dd 100644 --- a/evals/elsuite/basic/match.py +++ b/evals/elsuite/basic/match.py @@ -2,13 +2,14 @@ import evals import evals.metrics +from evals.api import CompletionFn from evals.prompt.base import is_chat_prompt class Match(evals.Eval): def __init__( self, - model_specs: evals.ModelSpecs, + completion_fns: list[CompletionFn], samples_jsonl: str, *args, max_tokens: int = 500, @@ -16,7 +17,8 @@ def __init__( few_shot_jsonl: str = None, **kwargs, ): - super().__init__(model_specs, *args, **kwargs) + super().__init__(completion_fns, *args, **kwargs) + assert len(completion_fns) == 1, "Match only supports one completion fn" self.max_tokens = max_tokens self.samples_jsonl = samples_jsonl self.num_few_shot = num_few_shot @@ -34,10 +36,20 @@ def eval_sample(self, sample: Any, *_): prompt += s["sample"] prompt += sample["input"][-1:] - return evals.check_sampled_text(self.model_spec, prompt, expected=sample["ideal"]) + result = self.completion_fn( + prompt=prompt, + temperature=0.0, + ) + sampled = result.get_completions()[0] + + return 
evals.record_and_check_match( + prompt=prompt, + sampled=sampled, + expected=sample["ideal"], + ) def run(self, recorder): - samples = evals.get_jsonl(self.samples_jsonl) + samples = self.get_samples() self.eval_all_samples(recorder, samples) events = recorder.get_events("match") return { diff --git a/evals/elsuite/modelgraded/classify.py b/evals/elsuite/modelgraded/classify.py index 454d0b7ab4..3f12b34daf 100644 --- a/evals/elsuite/modelgraded/classify.py +++ b/evals/elsuite/modelgraded/classify.py @@ -10,7 +10,7 @@ import evals import evals.record -from evals.base import ModelSpec +from evals import CompletionFn, DummyCompletionFn, OpenAIChatCompletionFn from evals.elsuite.modelgraded.base import ModelGradedSpec from evals.elsuite.modelgraded.classify_utils import ( CHOICE_KEY, @@ -28,7 +28,7 @@ class ModelBasedClassify(evals.Eval): def __init__( self, - model_specs: evals.ModelSpecs, + completion_fns: list[CompletionFn], samples_jsonl: str, modelgraded_spec: str, *args, @@ -38,13 +38,12 @@ def __init__( multicomp_temperature: float = 0.4, samples_renamings: Optional[dict[str, str]] = None, eval_type: Optional[str] = None, - eval_model: str = "gpt-3.5-turbo", metaeval: bool = False, modelgraded_spec_args: Optional[dict[str, dict[str, str]]] = None, **kwargs, ): - super().__init__(model_specs, *args, **kwargs) - n_models = len(self.model_specs.completions) + super().__init__(completion_fns, *args, **kwargs) + n_models = len(self.completion_fns) self.max_tokens = max_tokens self.samples_jsonl = samples_jsonl self.match_fn = MATCH_FNS[match_fn] @@ -61,15 +60,15 @@ def __init__( self.samples_renamings = samples_renamings or {} # check if multiple models are specified - if len(self.model_specs.completions) > 1: + if len(self.completion_fns) > 1: assert ( self.multicomp_n == n_models - ), f"multicomp_n={self.multicomp_n} must be equal to the number of models={len(self.model_specs.completions)} if multiple models are specified." + ), f"multicomp_n={self.multicomp_n} must be equal to the number of models={len(self.completion_fns)} if multiple models are specified." 
- if self.model_spec.name == "dummy-completion" or self.model_spec.name == "dummy-chat": - self.eval_modelspec = self.model_spec + if isinstance(self.completion_fn, DummyCompletionFn): + self.eval_completion_fn = self.completion_fn else: - self.eval_modelspec = ModelSpec(name=eval_model, model=eval_model, is_chat=True) + self.eval_completion_fn = OpenAIChatCompletionFn(model="gpt-3.5-turbo") spec_kwargs = {"multicomp_n": self.multicomp_n} if modelgraded_spec_args: @@ -108,15 +107,15 @@ def eval_sample(self, test_sample: dict, rng: Random) -> None: if self.multicomp_n > 1 and v in self.mg.completion_sample_templates: completion_i_s = [] for i in range(self.multicomp_n): - if len(self.model_specs.completions) > 1: + if len(self.completion_fns) > 1: # use a separate model for each completion - model_spec = self.model_specs.completions[i] + completion_fn = self.completion_fns[i] else: # use the single model for all completions - model_spec = self.model_spec + completion_fn = self.completion_fn get_input_completion = PromptFn( test_sample[k], - model_spec=model_spec, + completion_fn=completion_fn, max_tokens=self.max_tokens, temperature=self.multicomp_temperature, ) @@ -128,7 +127,7 @@ def eval_sample(self, test_sample: dict, rng: Random) -> None: else: get_input_completion = PromptFn( test_sample[k], - model_spec=self.model_spec, + completion_fn=self.completion_fn, max_tokens=self.max_tokens, ) completion, _ = get_input_completion() @@ -150,7 +149,7 @@ def eval_sample(self, test_sample: dict, rng: Random) -> None: prompt = self.mg.format(**args, **completions, **test_sample) evaluate = PromptFn( prompt, - model_spec=self.eval_modelspec, + completion_fn=self.eval_completion_fn, max_tokens=self.max_tokens, ) try: @@ -178,7 +177,7 @@ def eval_sample(self, test_sample: dict, rng: Random) -> None: return choice def run(self, recorder): - samples = evals.get_jsonl(self.samples_jsonl) + samples = self.get_samples() self.eval_all_samples(recorder, samples) record_metrics = {} diff --git a/evals/elsuite/translate.py b/evals/elsuite/translate.py index 42cf8c7784..40f381d81e 100644 --- a/evals/elsuite/translate.py +++ b/evals/elsuite/translate.py @@ -4,13 +4,14 @@ import evals import evals.metrics +from evals.api import CompletionFn from evals.prompt.base import is_chat_prompt class Translate(evals.Eval): def __init__( self, - model_specs: evals.ModelSpecs, + completion_fns: list[CompletionFn], samples_jsonl: str, *args, max_tokens: int = 500, @@ -18,7 +19,8 @@ def __init__( few_shot_jsonl: str = None, **kwargs, ): - super().__init__(model_specs, *args, **kwargs) + super().__init__(completion_fns, *args, **kwargs) + assert len(completion_fns) == 1, "Translate only supports one completion fn" self.max_tokens = max_tokens self.samples_jsonl = samples_jsonl @@ -45,7 +47,11 @@ def eval_sample(self, sample: Any, *_): elif not isinstance(expected, list): expected = [expected] - sampled = evals.sample_freeform(self.model_spec, prompt, max_tokens=self.max_tokens) + result = self.completion_fn( + prompt=prompt, + max_tokens=self.max_tokens, + ) + sampled = result.get_completions()[0] score = None if expected is not None: @@ -61,7 +67,7 @@ def eval_sample(self, sample: Any, *_): return match def run(self, recorder): - samples = evals.get_jsonl(self.samples_jsonl) + samples = self.get_samples() self.eval_all_samples(recorder, samples) events = recorder.get_events("match") diff --git a/evals/elsuite/utils.py b/evals/elsuite/utils.py index 5a83cd2017..cc4f327741 100644 --- a/evals/elsuite/utils.py +++ 
b/evals/elsuite/utils.py @@ -2,9 +2,16 @@ import re import string from collections import Counter, defaultdict +from typing import Optional, Union -from evals.api import sample_freeform -from evals.prompt.base import OpenAICreatePrompt, chat_prompt_to_text_prompt, is_chat_prompt +from evals import CompletionFn +from evals.prompt.base import ( + OpenAICreateChatPrompt, + OpenAICreatePrompt, + Prompt, + chat_prompt_to_text_prompt, + is_chat_prompt, +) def get_answer(text, answer_prompt): @@ -122,27 +129,52 @@ def format_prompt(prompt: OpenAICreatePrompt, **kwargs: dict[str, str]) -> OpenA class PromptFn: - """Wrap calls to model with prompt""" - - def __init__(self, prompt, model_spec, max_tokens, temperature=0, completion_kwargs=None): + """ + Wrap calls to a completion_fn with a prompt template with applicable keyword args. + This will pass many args relevant to OpenAI Completion API, may be ignored by other completion_fn. + """ + + def __init__( + self, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + completion_fn: CompletionFn, + max_tokens: int, + temperature: int = 0, + n_samples: Optional[int] = None, + completion_kwargs: Optional[dict] = {}, + ): self.prompt = prompt self.max_tokens = max_tokens - self.model_spec = model_spec + self.completion_fn = completion_fn self.temperature = temperature - self.completion_kwargs = completion_kwargs or {} - - def __call__(self, skip_format: bool = False, **kwargs): - prompt = self.prompt - if not skip_format: - prompt = format_prompt(prompt, **kwargs) - completion = sample_freeform( - self.model_spec, - prompt, + self.completion_kwargs = completion_kwargs + self.n_samples = n_samples + + def __call__(self, **kwargs): + # if any input kwargs is chat prompt, convert to text prompt + kwargs = { + k: chat_prompt_to_text_prompt(v) if is_chat_prompt(v) else v for k, v in kwargs.items() + } + if is_chat_prompt(self.prompt): + prompt = [] + for msg in self.prompt: + formatted_msg = copy.copy(msg) + if "content" in formatted_msg: + formatted_msg["content"] = format_necessary(formatted_msg["content"], **kwargs) + prompt.append(formatted_msg) + else: + # Prompt is a string + prompt = format_necessary(self.prompt, **kwargs) + + result = self.completion_fn( + prompt=prompt, max_tokens=self.max_tokens, temperature=self.temperature, top_p=1, frequency_penalty=0, presence_penalty=0, + n=(1 if self.n_samples is None else self.n_samples), **self.completion_kwargs, ) - return completion, prompt + sampled = result.get_completions()[0] + return sampled, prompt diff --git a/evals/eval.py b/evals/eval.py index 845123e0cf..7e9a1775be 100644 --- a/evals/eval.py +++ b/evals/eval.py @@ -12,9 +12,11 @@ from tqdm import tqdm -from .base import ModelSpec, ModelSpecs +from evals.api import CompletionFn + from .record import RecorderBase from .registry import Registry +from .data import get_jsonl logger = logging.getLogger(__name__) @@ -51,32 +53,29 @@ class Eval(abc.ABC): def __init__( self, - model_specs: ModelSpecs, + completion_fns: list[CompletionFn], seed: int = 20220722, name: str = "no_name_eval.default", registry: Optional[Registry] = None, + samples_jsonl: Optional[str] = None, ): splits = name.split(".") if len(splits) < 2: raise ValueError(f"Eval name must at least have .. 
Got name {name}") - self.model_specs = model_specs + self.completion_fns = completion_fns self.seed = seed self.name = name self.registry = registry or Registry() + self.samples_jsonl = samples_jsonl def eval_sample(self, sample: Any, rng: random.Random): raise NotImplementedError() - @classmethod - def create_and_run(cls, model_specs: ModelSpecs, *args, **kwargs) -> Dict[str, float]: - logging.info(f"Running {cls.__name__} with {model_specs}, args: {args}, kwargs: {kwargs}") - return cls(model_specs).run(*args, **kwargs) - @property - def model_spec(self) -> ModelSpec: - """Helper for more ergonomic access to a single model.""" - return self.model_specs.completion + def completion_fn(self) -> CompletionFn: + """Helper for more ergonomic access to a single CompletionFn.""" + return self.completion_fns[0] @abc.abstractmethod def run(self, recorder: RecorderBase) -> Dict[str, float]: @@ -109,6 +108,7 @@ def eval_all_samples( recorder: RecorderBase, samples, show_progress=True, + record_raw_sample=True, ): """ Evaluate all provided samples in parallel. @@ -126,7 +126,6 @@ def eval_sample(args): base_name, split = self.name.split(".")[0:2] sample_id = f"{base_name}.{split}.{idx}" with recorder.as_default_recorder(sample_id): - recorder.record_raw(sample) seed = f"{sample_id}:{self.seed}".encode("utf-8") rng = random.Random(seed) return idx, self.eval_sample(sample, rng) @@ -153,3 +152,11 @@ def worker_thread(args): iter = pool.imap_unordered(worker_thread, work_items) idx_and_result = list(tqdm(iter, total=len(work_items), disable=not show_progress)) return [r for _, r in sorted(idx_and_result)] + + def get_samples(self): + if self.samples_jsonl is None: + raise ValueError( + "To use `get_samples`, you must provide a `samples_jsonl` path." + "Got `None`.") + + return get_jsonl(self.samples_jsonl) diff --git a/evals/prompt/base.py b/evals/prompt/base.py index 7c2aa3be04..93eb84f8fb 100644 --- a/evals/prompt/base.py +++ b/evals/prompt/base.py @@ -64,10 +64,9 @@ class Prompt(ABC): """ @abstractmethod - def to_openai_create_prompt(self): + def to_formatted_prompt(self): """ - Return the actual data to be passed as the `prompt` field to either `openai.ChatCompletion.create`, - if the model is a chat model, or `openai.Completion.create` otherwise. + Return the actual data to be passed as the `prompt` field to your model. See the above types to see what each API call is able to handle. """ @@ -82,12 +81,12 @@ class CompletionPrompt(Prompt): A `Prompt` object that wraps prompts to be compatible with non chat models, which use `openai.Completion.create`. 
""" - raw_prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt] + raw_prompt: Union[str, OpenAICreateChatPrompt] - def _render_chat_prompt_as_text(self, prompt: OpenAICreateChatPrompt) -> OpenAICreatePrompt: + def _render_chat_prompt_as_text(self, prompt: OpenAICreateChatPrompt) -> str: return chat_prompt_to_text_prompt(prompt) - def to_openai_create_prompt(self) -> OpenAICreatePrompt: + def to_formatted_prompt(self) -> str: if is_chat_prompt(self.raw_prompt): return self._render_chat_prompt_as_text(self.raw_prompt) return self.raw_prompt @@ -110,7 +109,7 @@ def _render_text_as_chat_prompt(self, prompt: str) -> OpenAICreateChatPrompt: """ return text_prompt_to_chat_prompt(prompt) - def to_openai_create_prompt(self) -> OpenAICreateChatPrompt: + def to_formatted_prompt(self) -> OpenAICreateChatPrompt: if is_chat_prompt(self.raw_prompt): return self.raw_prompt return self._render_text_as_chat_prompt(self.raw_prompt) diff --git a/evals/record.py b/evals/record.py index d02ab20d5f..0c4823950c 100644 --- a/evals/record.py +++ b/evals/record.py @@ -329,7 +329,7 @@ class Recorder(RecorderBase): def __init__( self, log_path: Optional[str], - run_spec: evals.base.RunSpec, + run_spec: RunSpec, snowflake_connection: Optional[SnowflakeConnection] = None, ) -> None: super().__init__(run_spec) @@ -353,7 +353,8 @@ def __init__( command=query, params={ "run_id": run_spec.run_id, - "model_name": jsondumps(run_spec.model_names), + # TODO: model_name -> completion_fns + "model_name": jsondumps(run_spec.completion_fns), "eval_name": run_spec.eval_name, "base_eval": run_spec.base_eval, "split": run_spec.split, diff --git a/evals/registry.py b/evals/registry.py index 89ce84dc9f..9e91ccbc7b 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -10,13 +10,16 @@ import logging import os import re -from functools import partial +from functools import cached_property from pathlib import Path -from typing import Any, Iterator, Sequence, Type, Union +from typing import Any, Iterator, Optional, Sequence, Type, Union +import openai import yaml -from evals.base import BaseEvalSpec, EvalSetSpec, EvalSpec, ModelSpec +from evals import OpenAIChatCompletionFn, OpenAICompletionFn +from evals.api import CompletionFn, DummyCompletionFn +from evals.base import BaseEvalSpec, CompletionFnSpec, EvalSetSpec, EvalSpec from evals.elsuite.modelgraded.base import ModelGradedSpec from evals.utils.misc import make_object @@ -25,12 +28,87 @@ DEFAULT_PATHS = [Path(__file__).parents[0].resolve() / "registry", Path.home() / ".evals"] +def n_ctx_from_model_name(model_name: str) -> Optional[int]: + """Returns n_ctx for a given API model name. 
diff --git a/evals/prompt/base.py b/evals/prompt/base.py
index 7c2aa3be04..93eb84f8fb 100644
--- a/evals/prompt/base.py
+++ b/evals/prompt/base.py
@@ -64,10 +64,9 @@ class Prompt(ABC):
     """

     @abstractmethod
-    def to_openai_create_prompt(self):
+    def to_formatted_prompt(self):
         """
-        Return the actual data to be passed as the `prompt` field to either `openai.ChatCompletion.create`,
-        if the model is a chat model, or `openai.Completion.create` otherwise.
+        Return the actual data to be passed as the `prompt` field to your model.
         See the above types to see what each API call is able to handle.
         """

@@ -82,12 +81,12 @@ class CompletionPrompt(Prompt):
     A `Prompt` object that wraps prompts to be compatible with non chat models, which use
     `openai.Completion.create`.
     """

-    raw_prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt]
+    raw_prompt: Union[str, OpenAICreateChatPrompt]

-    def _render_chat_prompt_as_text(self, prompt: OpenAICreateChatPrompt) -> OpenAICreatePrompt:
+    def _render_chat_prompt_as_text(self, prompt: OpenAICreateChatPrompt) -> str:
         return chat_prompt_to_text_prompt(prompt)

-    def to_openai_create_prompt(self) -> OpenAICreatePrompt:
+    def to_formatted_prompt(self) -> str:
         if is_chat_prompt(self.raw_prompt):
             return self._render_chat_prompt_as_text(self.raw_prompt)
         return self.raw_prompt
@@ -110,7 +109,7 @@ def _render_text_as_chat_prompt(self, prompt: str) -> OpenAICreateChatPrompt:
         """
         return text_prompt_to_chat_prompt(prompt)

-    def to_openai_create_prompt(self) -> OpenAICreateChatPrompt:
+    def to_formatted_prompt(self) -> OpenAICreateChatPrompt:
         if is_chat_prompt(self.raw_prompt):
             return self.raw_prompt
         return self._render_text_as_chat_prompt(self.raw_prompt)
diff --git a/evals/record.py b/evals/record.py
index d02ab20d5f..0c4823950c 100644
--- a/evals/record.py
+++ b/evals/record.py
@@ -329,7 +329,7 @@ class Recorder(RecorderBase):
     def __init__(
         self,
         log_path: Optional[str],
-        run_spec: evals.base.RunSpec,
+        run_spec: RunSpec,
         snowflake_connection: Optional[SnowflakeConnection] = None,
     ) -> None:
         super().__init__(run_spec)
@@ -353,7 +353,8 @@ def __init__(
             command=query,
             params={
                 "run_id": run_spec.run_id,
-                "model_name": jsondumps(run_spec.model_names),
+                # TODO: model_name -> completion_fns
+                "model_name": jsondumps(run_spec.completion_fns),
                 "eval_name": run_spec.eval_name,
                 "base_eval": run_spec.base_eval,
                 "split": run_spec.split,
diff --git a/evals/registry.py b/evals/registry.py
index 89ce84dc9f..9e91ccbc7b 100644
--- a/evals/registry.py
+++ b/evals/registry.py
@@ -10,13 +10,16 @@
 import logging
 import os
 import re
-from functools import partial
+from functools import cached_property
 from pathlib import Path
-from typing import Any, Iterator, Sequence, Type, Union
+from typing import Any, Iterator, Optional, Sequence, Type, Union

+import openai
 import yaml

-from evals.base import BaseEvalSpec, EvalSetSpec, EvalSpec, ModelSpec
+from evals import OpenAIChatCompletionFn, OpenAICompletionFn
+from evals.api import CompletionFn, DummyCompletionFn
+from evals.base import BaseEvalSpec, CompletionFnSpec, EvalSetSpec, EvalSpec
 from evals.elsuite.modelgraded.base import ModelGradedSpec
 from evals.utils.misc import make_object

@@ -25,12 +28,87 @@
 DEFAULT_PATHS = [Path(__file__).parents[0].resolve() / "registry", Path.home() / ".evals"]


+def n_ctx_from_model_name(model_name: str) -> Optional[int]:
+    """Returns n_ctx for a given API model name.
+    Model list last updated 2023-03-14."""
+    # note that for most models, the max tokens is n_ctx + 1
+    DICT_OF_N_CTX_BY_MODEL_NAME_PREFIX: dict[str, int] = {
+        "gpt-3.5-turbo-": 4096,
+        "gpt-4-": 8192,
+        "gpt-4-32k-": 32768,
+    }
+    DICT_OF_N_CTX_BY_MODEL_NAME: dict[str, int] = {
+        "ada": 2048,
+        "text-ada-001": 2048,
+        "babbage": 2048,
+        "text-babbage-001": 2048,
+        "curie": 2048,
+        "text-curie-001": 2048,
+        "davinci": 2048,
+        "text-davinci-001": 2048,
+        "code-davinci-002": 8000,
+        "text-davinci-002": 4096,
+        "text-davinci-003": 4096,
+        "gpt-3.5-turbo": 4096,
+        "gpt-3.5-turbo-0301": 4096,
+        "gpt-4": 8192,
+        "gpt-4-0314": 8192,
+        "gpt-4-32k": 32768,
+        "gpt-4-32k-0314": 32768,
+    }
+    # first, look for a prefix match
+    for model_prefix, n_ctx in DICT_OF_N_CTX_BY_MODEL_NAME_PREFIX.items():
+        if model_name.startswith(model_prefix):
+            return n_ctx
+    # otherwise, look for an exact match and return None if not found
+    return DICT_OF_N_CTX_BY_MODEL_NAME.get(model_name, None)
+
+
 class Registry:
     def __init__(self, registry_paths: Sequence[Union[str, Path]] = DEFAULT_PATHS):
         self._registry_paths = [Path(p) if isinstance(p, str) else p for p in registry_paths]

-    def make_callable(self, spec):
-        return partial(make_object(spec.cls).create_and_run, **(spec.args or {}))
+    def add_registry_paths(self, paths: list[Union[str, Path]]):
+        self._registry_paths.extend([Path(p) if isinstance(p, str) else p for p in paths])
+
+    @cached_property
+    def api_model_ids(self):
+        return [m["id"] for m in openai.Model.list()["data"]]
+
+    def make_completion_fn(self, name: str) -> CompletionFn:
+        """
+        Create a CompletionFn. The name can be one of the following formats:
+        1. openai-model-id (e.g. "gpt-3.5-turbo")
+        2. completion-fn-id (from the registry)
+        """
+
+        if name == "dummy":
+            return DummyCompletionFn()
+
+        n_ctx = n_ctx_from_model_name(name)
+
+        CHAT_MODELS = {
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-0301",
+            "gpt-4",
+            "gpt-4-0314",
+            "gpt-4-32k",
+            "gpt-4-32k-0314",
+        }
+
+        if name in CHAT_MODELS:
+            return OpenAIChatCompletionFn(model=name, n_ctx=n_ctx)
+        elif name in self.api_model_ids:
+            return OpenAICompletionFn(model=name, n_ctx=n_ctx)
+
+        # No match, so try to find a completion-fn-id in the registry
+        spec = self.get_completion_fn(name)
+        if spec is None:
+            raise ValueError(f"Could not find CompletionFn in the registry with ID {name}")
+
+        spec.args["registry"] = self
+        instance = make_object(spec.cls)(**spec.args or {})
+        assert isinstance(instance, CompletionFn), f"{name} must be a CompletionFn"
+        return instance

     def get_class(self, spec: dict) -> Any:
         return make_object(spec.cls, **(spec.args if spec.args else {}))
@@ -70,14 +148,18 @@ def get_alias():
         except TypeError as e:
             raise TypeError(f"Error while processing {object} '{name}': {e}")

-    def get_model(self, name: str) -> ModelSpec:
-        return self._dereference(name, self._models, "model", ModelSpec)
-
     def get_modelgraded_spec(self, name: str, **kwargs: dict) -> dict[str, Any]:
+        assert name in self._modelgraded_specs, (
+            f"Modelgraded spec {name} not found. "
+            f"Closest matches: {difflib.get_close_matches(name, self._modelgraded_specs.keys(), n=5)}"
+        )
         return self._dereference(
             name, self._modelgraded_specs, "modelgraded spec", ModelGradedSpec, **kwargs
         )

+    def get_completion_fn(self, name: str) -> CompletionFnSpec:
+        return self._dereference(name, self._completion_fns, "completion_fn", CompletionFnSpec)
+
     def get_eval(self, name: str) -> EvalSpec:
         return self._dereference(name, self._evals, "eval", EvalSpec)

@@ -170,6 +252,10 @@ def _load_registry(self, paths):
                 self._process_file(registry, path)
         return registry

+    @functools.cached_property
+    def _completion_fns(self):
+        return self._load_registry([p / "completion_fns" for p in self._registry_paths])
+
     @functools.cached_property
     def _eval_sets(self):
         return self._load_registry([p / "eval_sets" for p in self._registry_paths])

@@ -182,9 +268,5 @@ def _evals(self):
         return self._load_registry([p / "evals" for p in self._registry_paths])

     @functools.cached_property
     def _modelgraded_specs(self):
         return self._load_registry([p / "modelgraded" for p in self._registry_paths])

-    @functools.cached_property
-    def _models(self):
-        return self._load_registry([p / "models" for p in self._registry_paths])
-

 registry = Registry()
diff --git a/evals/registry/completion_fns/cot.yaml b/evals/registry/completion_fns/cot.yaml
new file mode 100644
index 0000000000..66c1ad4fc8
--- /dev/null
+++ b/evals/registry/completion_fns/cot.yaml
@@ -0,0 +1,14 @@
+cot/text-davinci-003:
+  class: evals.completion_fns.cot:ChainOfThoughtCompletionFn
+  args:
+    cot_completion_fn: text-davinci-003
+
+cot/gpt-3.5-turbo:
+  class: evals.completion_fns.cot:ChainOfThoughtCompletionFn
+  args:
+    cot_completion_fn: gpt-3.5-turbo
+
+cot/flan-t5-xl:
+  class: evals.completion_fns.cot:ChainOfThoughtCompletionFn
+  args:
+    cot_completion_fn: langchain/llm/flan-t5-xl
diff --git a/evals/registry/completion_fns/langchain_llms.yaml b/evals/registry/completion_fns/langchain_llms.yaml
new file mode 100644
index 0000000000..1e01d75e3d
--- /dev/null
+++ b/evals/registry/completion_fns/langchain_llms.yaml
@@ -0,0 +1,20 @@
+langchain/llm/gpt-3.5-turbo:
+  class: evals.completion_fns.langchain_llm:LangChainLLMCompletionFn
+  args:
+    llm: OpenAI
+    llm_kwargs:
+      model_name: gpt-3.5-turbo
+
+langchain/llm/text-davinci-003:
+  class: evals.completion_fns.langchain_llm:LangChainLLMCompletionFn
+  args:
+    llm: OpenAI
+    llm_kwargs:
+      model_name: text-davinci-003
+
+langchain/llm/flan-t5-xl:
+  class: evals.completion_fns.langchain_llm:LangChainLLMCompletionFn
+  args:
+    llm: HuggingFaceHub
+    llm_kwargs:
+      repo_id: google/flan-t5-xl
diff --git a/evals/utils/api_utils.py b/evals/utils/api_utils.py
index ecddfde0d2..751440e311 100644
--- a/evals/utils/api_utils.py
+++ b/evals/utils/api_utils.py
@@ -6,42 +6,6 @@
 import backoff
 import openai

-
-def generate_dummy_chat_completion():
-    return {
-        "id": "dummy-id",
-        "object": "chat.completion",
-        "created": 12345,
-        "model": "dummy-chat",
-        "usage": {"prompt_tokens": 56, "completion_tokens": 6, "total_tokens": 62},
-        "choices": [
-            {
-                "message": {"role": "assistant", "content": "This is a dummy response."},
-                "finish_reason": "stop",
-                "index": 0,
-            }
-        ],
-    }
-
-
-def generate_dummy_completion():
-    return {
-        "id": "dummy-id",
-        "object": "text_completion",
-        "created": 12345,
-        "model": "dummy-completion",
-        "choices": [
-            {
-                "text": "This is a dummy response.",
-                "index": 0,
-                "logprobs": None,
-                "finish_reason": "stop",
-            }
-        ],
-        "usage": {"prompt_tokens": 5, "completion_tokens": 6, "total_tokens": 11},
-    }
-
-
 @backoff.on_exception(
     wait_gen=backoff.expo,
     exception=(
@@ -59,9 +23,6 @@ def openai_completion_create_retrying(*args, **kwargs):
     Helper function for creating a completion.
     `args` and `kwargs` match what is accepted by `openai.Completion.create`.
     """
-    if kwargs["model"] == "dummy-completion":
-        return generate_dummy_completion()
-
     result = openai.Completion.create(*args, **kwargs)
     if "error" in result:
         logging.warning(result)
@@ -86,9 +47,6 @@ def openai_chat_completion_create_retrying(*args, **kwargs):
     Helper function for creating a chat completion.
     `args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`.
     """
-    if kwargs["model"] == "dummy-chat":
-        return generate_dummy_chat_completion()
-
     result = openai.ChatCompletion.create(*args, **kwargs)
     if "error" in result:
         logging.warning(result)
diff --git a/pyproject.toml b/pyproject.toml
index 26b226e2e4..9a7fc04b61 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evals"
-version = "0.1.1"
+version = "1.0.0"
 requires-python = ">=3.9"
 dependencies = [
     "mypy",
@@ -24,6 +24,7 @@ dependencies = [
     "pyyaml",
     "sacrebleu",
     "matplotlib",
+    "setuptools_scm",
 ]

 [project.scripts]
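With the dummy short-circuits removed, the two retrying helpers above always hit the OpenAI API; the dummy path now lives in `DummyCompletionFn` instead. A usage sketch, assuming the pre-1.0 `openai` Python client used throughout this codebase and an API key configured in the environment:

```python
from evals.utils.api_utils import openai_chat_completion_create_retrying

# kwargs are passed straight through to openai.ChatCompletion.create,
# wrapped in exponential backoff on transient API errors.
result = openai_chat_completion_create_retrying(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(result["choices"][0]["message"]["content"])
```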