[evals] Refactor evals package to expose completion_fn (#515)
Changes from 1 commit
Commit message: completion_fn

PAIR=jasonwei
Co-authored-by: Jason Wei <[email protected]>
The first file (evidently `evals/__init__.py`, from the relative imports):

```diff
@@ -1,4 +1,4 @@
-from .api import check_sampled_text, completion_query, sample_freeform
+from .api import check_sampled_text, completion_query, sample_freeform, postprocess_sample_freeform, record_and_check_match
 from .base import ModelSpec, ModelSpecs
 from .data import get_csv, get_json, get_jsonl, get_jsonls, get_lines, iter_jsonls
 from .eval import Eval
```
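With this change the two new helpers are re-exported at the package root alongside the existing API. A quick sanity check of the resulting import surface:

```python
# The refactored helpers become importable from the package root.
from evals import postprocess_sample_freeform, record_and_check_match
```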
The second file (evidently `evals/api.py`, given the imports above):

```diff
@@ -97,6 +97,7 @@ def completion_query(
     return result, openai_create_prompt, metadata


+# TODO(hwc): remove this
 def check_sampled_text(
     model_spec: ModelSpec,
     prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt],
@@ -123,13 +124,6 @@ def check_sampled_text(
     =======
     The option that was picked, i.e., matched the completion, or None.
     """
-    if isinstance(expected, tuple):
-        expected = list(expected)
-    elif not isinstance(expected, list):
-        expected = [expected]
-    if options is None:
-        options = expected
-
     result, actual_prompt, metadata = completion_query(
         prompt=prompt,
         temperature=0.0,
```

Review comment on the `completion_query(...)` call: Do we need to rewrite `completion_query` with new fn?

```diff
@@ -139,6 +133,31 @@ def check_sampled_text(
     sampled = choice["text"].strip() if model_spec.strip_completion else choice["text"]

+    return record_and_check_match(
+        prompt=actual_prompt,
+        sampled=sampled,
+        expected=expected,
+        metadata=metadata,
+        separator=separator,
+        options=options,
+    )
+
+
+def record_and_check_match(
+    prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt],
+    sampled: str,
+    expected: Union[str, list[str], tuple[str]],
+    metadata: dict,
+    separator: Callable[[str], bool] = None,
+    options: Optional[list[str]] = None,
+):
+    if isinstance(expected, tuple):
+        expected = list(expected)
+    elif not isinstance(expected, list):
+        expected = [expected]
+    if options is None:
+        options = expected
+
     picked = None
     for option in options:
         if not sampled.startswith(option):
@@ -153,7 +172,7 @@ def check_sampled_text(
             break

     result = {
-        "prompt": actual_prompt,
+        "prompt": prompt,
         "sampled": sampled,
         "options": options,
         "picked": picked,
```
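The effect of this split is that the match-and-record step no longer has to go through `completion_query`: a sample obtained from any backend can be checked directly. A hypothetical usage sketch (the prompt, sample, and metadata are invented; it assumes a recorder is active, as it is inside an eval run):

```python
# Hypothetical usage of the newly exposed record_and_check_match:
# check a completion sampled from any source, not just completion_query.
from evals.api import record_and_check_match

picked = record_and_check_match(
    prompt="What is the capital of France? Answer:",  # invented example prompt
    sampled="Paris",                       # completion text from any backend
    expected=["Paris"],                    # may be a str, list, or tuple
    metadata={"source": "my_custom_backend"},
)
# `picked` is the option the sample matched, or None if nothing matched.
```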
```diff
@@ -175,7 +194,7 @@ def sample_freeform(
     top_p: float = 0.9,
     max_tokens: int = 512,
     stop: Optional[str] = None,
-    n_samples: int = None,
+    n_samples: Optional[int] = None,
     return_logprobs: bool = False,
     **kwargs,
 ) -> Union[str, list[str], dict]:
@@ -215,10 +234,51 @@ def sample_freeform(
         headers={},
         **kwargs,
     )
+    return postprocess_sample_freeform(
+        response,
+        actual_prompt,
+        metadata,
+        model_spec,
+        n_samples=n_samples,
+        return_logprobs=return_logprobs,
+        **kwargs)
+
+
+def postprocess_sample_freeform(
+    response: dict,
+    prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt],
+    metadata: dict,
+    model_spec: ModelSpec,
+    *,
+    n_samples: Optional[int] = None,
+    return_logprobs: bool = False,
+    **kwargs,
+) -> Union[str, list[str], dict]:
+    """
+    Records the sampled response, prompt and metadata, and returns the sampled text.
+    Typically called after `sample_freeform`.
+
+    ARGS
+    ====
+    `response`: The result of the API call.
+    `prompt`: See `completion_query`.
+    `n_samples`: The number of samples to generate (1 if None).
+    `return_logprobs`: If True, returns the tokens and corresponding logprobs
+        in addition to the sampled text.
+    `kwargs`: See `completion_query`.
+
+    RETURNS
+    =======
+    If `return_logprobs` is True, returns a dict with the sampled text, tokens,
+    and corresponding logprobs. If `n_samples` is None, the outer list is
+    removed from all values.
+    Otherwise, returns the sampled text, or a list of sampled texts if
+    `n_samples` is not None.
+    """
     sampled = [choice["text"] for choice in response["choices"]]
     if n_samples is None:
         sampled = sampled[0]
-    record_sampling(prompt=actual_prompt, sampled=sampled, metadata=metadata)
+    record_sampling(prompt=prompt, sampled=sampled, metadata=metadata)

     if return_logprobs:
         assert not model_spec.is_chat, "logprobs only works for non-chat models"
```
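Likewise, the record-and-unwrap logic at the tail of `sample_freeform` can now be reused with a response produced elsewhere, e.g. a cached or mocked API result. A minimal sketch, assuming the OpenAI completions response shape, an active recorder, and these `ModelSpec` fields:

```python
# Hypothetical usage of the newly exposed postprocess_sample_freeform,
# applied to a response that did not come from completion_query.
from evals.api import postprocess_sample_freeform
from evals.base import ModelSpec

spec = ModelSpec(name="example", model="text-davinci-003", is_chat=False)  # assumed fields
cached_response = {"choices": [{"text": "Paris"}]}  # mocked/cached API result

sampled = postprocess_sample_freeform(
    cached_response,
    "What is the capital of France? Answer:",  # prompt that produced the response
    {"source": "cache"},                       # metadata to record
    spec,
)
# n_samples is None, so the single sampled string "Paris" is returned.
```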
The third file (likely `evals/elsuite/utils.py`, which holds `PromptFn` and the matching utilities):

```diff
@@ -2,10 +2,19 @@
 import re
 import string
 from collections import Counter, defaultdict
+from typing import Union
+from typing_extensions import Protocol

 from evals.api import sample_freeform
 from evals.prompt.base import chat_prompt_to_text_prompt, is_chat_prompt
+
+from evals.base import ModelSpec
+from evals.prompt.base import (
+    OpenAICreateChatPrompt,
+    OpenAICreatePrompt,
+    Prompt,
+)


 def get_answer(text, answer_prompt):
     idx = text.rfind(answer_prompt)
```
```diff
@@ -135,3 +144,31 @@ def __call__(self, **kwargs):
             **self.completion_kwargs,
         )
         return completion, prompt


+class CompletionFn(Protocol):
+    def __call__(
+        self,
+        model_spec: ModelSpec,
+        prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt],
+        **kwargs
+    ) -> tuple[dict, Union[OpenAICreatePrompt, OpenAICreateChatPrompt], dict]:
+        """
+        ARGS
+        ====
+        `model_spec`: `ModelSpec` containing model details to use in the query.
+            This should be the dict returned by `registry.get_model()`.
+            If `model_spec` is not provided, we use the default model that was
+            initialized at the beginning of the run.
+        `prompt`: Either a `Prompt` object or a raw prompt that will get wrapped in
+            the appropriate `Prompt` class.
+        `kwargs`: Other arguments passed to the API.
+
+        RETURNS
+        =======
+        The result of the API call.
+        The prompt that was fed into the API call as a str.
+        A dict containing metadata about the query.
+        """
+        pass
```

Review comments on `class CompletionFn(Protocol):`:

- In general I like CompletionFn but there is some organization to be done, for example:
  Let's discuss?
- Ideally I'd like to have evals only use CompletionFn as opposed to picking between CompletionFn and completion_query (more accurately openai_completion_query, as @andrew-openai pointed out). Also happy to discuss if needed.
- One more thing to consider: we need CompletionFn subclasses to support both chat and non-chat inputs, which means implementing some generic casting behavior to go from chat to non-chat. Luckily we already have a lot of this, implemented in PromptFn and chat_prompt_to_text_prompt; we just need to add it to CompletionFn.
- The refactor needs to extend through the codebase, something like:
- Here is my current proposal for making
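To make the protocol concrete, here is a sketch (not part of the diff) of what a conforming implementation might look like, folding in the chat-to-text casting the reviewers ask for. `chat_prompt_to_text_prompt` and `is_chat_prompt` are real helpers already imported in this file; the rest of the casting logic is an assumption about how the suggestion could be realized:

```python
# Hypothetical CompletionFn implementation, not part of this PR.
# Uses the pre-1.0 openai SDK that evals targeted at the time.
import openai

from evals.base import ModelSpec
from evals.prompt.base import chat_prompt_to_text_prompt, is_chat_prompt


class OpenAICompletionFn:
    """Satisfies CompletionFn structurally; Protocols need no inheritance."""

    def __call__(self, model_spec: ModelSpec, prompt, **kwargs):
        # Cast chat-formatted prompts to plain text for non-chat models,
        # reusing chat_prompt_to_text_prompt as suggested in review.
        if is_chat_prompt(prompt) and not model_spec.is_chat:
            prompt = chat_prompt_to_text_prompt(prompt)

        result = openai.Completion.create(
            model=model_spec.model,
            prompt=prompt,
            **kwargs,
        )
        metadata = {"model": model_spec.model}
        # Matches the protocol's (response, prompt, metadata) return shape.
        return result, prompt, metadata
```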
Review comment: I believe we concluded we'll keep it but refactor it to use new fns?