From e8c09c867aa4f46cd409a40795cfdfd3d349032a Mon Sep 17 00:00:00 2001 From: Shane Gu <343165+rlbayes@users.noreply.github.com> Date: Mon, 3 Apr 2023 21:35:51 -0700 Subject: [PATCH] [evals] refactored modelgraded eval (#578) - moved functions to modelgraded/classify_utils.py - defined ModelGradedSpec and moved to modelgraded/base.py - unified interface on registry.py - other misc refactoring --- evals/elsuite/modelgraded/base.py | 77 +++++ evals/elsuite/modelgraded/classify.py | 279 +++++------------- evals/elsuite/modelgraded/classify_utils.py | 90 ++++++ evals/elsuite/utils.py | 50 ++-- evals/prompt/base.py | 6 +- evals/registry.py | 24 +- .../registry/eval_sets/test-modelgraded.yaml | 3 +- 7 files changed, 287 insertions(+), 242 deletions(-) create mode 100644 evals/elsuite/modelgraded/base.py create mode 100644 evals/elsuite/modelgraded/classify_utils.py diff --git a/evals/elsuite/modelgraded/base.py b/evals/elsuite/modelgraded/base.py new file mode 100644 index 0000000000..968ac35ad7 --- /dev/null +++ b/evals/elsuite/modelgraded/base.py @@ -0,0 +1,77 @@ +import string +from typing import TYPE_CHECKING, Optional, Union + +from evals.elsuite.modelgraded.classify_utils import ANSWER_PROMPTS, choice_to_str, expand_args_dict +from evals.prompt.base import OpenAICreateChatPrompt + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from pydantic.dataclasses import dataclass + + +@dataclass +class ModelGradedSpec: + prompt: Union[str, OpenAICreateChatPrompt] + choice_strings: Union[list[str], str] + eval_type: str + input_outputs: dict[str, str] + + choice_scores: Optional[Union[dict[str, Union[float, int]], str]] = None + multicomp_n: Optional[int] = None + append_answer_prompt: bool = False + args: Optional[dict[str, dict[str, str]]] = None + expand_args_dict: Optional[dict[str, dict[str, tuple[str]]]] = None + completion_sample_templates: Optional[dict[str, str]] = None + + key: Optional[str] = None # unused + group: Optional[str] = None # unused + + def __post_init__(self): + # 'choice_strings' is a list of strings that specifies the possible choices + if self.choice_strings == "from_n": + self.choice_strings = [str(i + 1) for i in range(self.multicomp_n)] + elif self.choice_strings == "from_n_abc": + self.choice_strings = [string.ascii_lowercase[i % 26] for i in range(self.multicomp_n)] + elif self.choice_strings == "from_n_ABC": + self.choice_strings = [string.ascii_uppercase[i % 26] for i in range(self.multicomp_n)] + # make sure each choice doesn't contain any punctuation + for s in self.choice_strings: + assert not any(c in s for c in string.punctuation), f"{s} contains punctuation" + + # (optional) 'choice_scores' is a dict that specifies the score for each choice string + # if 'choice_scores' is specified, 'scores/' are computed and added to metrics + if self.choice_scores: + if self.choice_scores == "from_strings": + self.choice_scores = {c: float(c) for c in self.choice_strings} + + # 'prompt' is a string that specifies the model-graded evaluation + assert isinstance(self.prompt, str), f"prompt must be a string, not {type(self.prompt)}" + if self.append_answer_prompt: + self.prompt += "\n\n" + ANSWER_PROMPTS[self.eval_type].format( + choices=choice_to_str(self.choice_strings) + ) + self.prompt = [{"role": "user", "content": self.prompt}] + + # 'input_outputs' is a dict that specifies the input and output keys in the sample + # output key is the model's raw response to input key. These are used for filling 'prompt' template. 
+ assert isinstance( + self.input_outputs, dict + ), f"input_outputs must be a dict, not {type(self.input_outputs)}" + + # (optional) 'args' is a dict of dicts that specifies additional arguments for 'prompt' + # each value in 'args' essentially defines a separate modelgraded classification eval and has own metrics! + self.args = self.args or {} + self.expanded_args_dict = expand_args_dict(self.args) + + # (optional) 'completion_sample_templates' + # each key must be one of 'input_outputs'.values(). If 'multicomp_n' > 1, this template is filled 'multicomp_n' times + # and the concatenated result is passed to 'prompt' template. + self.completion_sample_templates = self.completion_sample_templates or {} + assert all( + k in self.input_outputs.values() for k in self.completion_sample_templates + ), f"all {self.completion_sample_templates.keys()} must be in {self.input_outputs.values()}, " + if self.multicomp_n > 1: + assert ( + self.completion_sample_templates + ), "completion_sample_templates must be specified if multicomp_n > 1" diff --git a/evals/elsuite/modelgraded/classify.py b/evals/elsuite/modelgraded/classify.py index fc6911d270..d9d22bfe33 100644 --- a/evals/elsuite/modelgraded/classify.py +++ b/evals/elsuite/modelgraded/classify.py @@ -1,87 +1,25 @@ """ Generic eval that uses a prompt + classification. """ -import copy -import itertools import logging -import string from collections import Counter from random import Random -from typing import Callable, Iterable, Optional, Union +from typing import Optional, Union import openai import evals import evals.record from evals.base import ModelSpec -from evals.elsuite.utils import PromptFn, format_necessary, scrub_formatting_from_prompt - -INVALID_STR = "__invalid__" -CHOICE_KEY = "choice" -MATCH_FNS = { - "include": lambda x, y: float(x in y), - "exact": lambda x, y: float(x == y), - "endswith": lambda x, y: x.endswith(y), - "starts_or_endswith": lambda x, y: x.startswith(y) or x.endswith(y), -} - -ANSWER_PROMPTS = { - # e.g. "Yes" - "classify": "Answer the question by printing only a single choice from {choices} (without quotes or punctuation) corresponding to the correct answer with no other text.".strip(), - # e.g. "Yes\n The reasons are: ..." - "classify_cot": "First, answer by printing a single choice from {choices} (without quotes or punctuation) corresponding to the correct answer. Then, from the next line, explain your reasonings step by step.".strip(), - # e.g. "Let's think step by step. ...\nYes" - "cot_classify": """ -First, write out in a step by step manner your reasoning to be sure that your conclusion is correct. Avoid simply stating the correct answer at the outset. Then print only a single choice from {choices} (without quotes or punctuation) on its own line corresponding to the correct answer. At the end, repeat just the answer by itself on a new line. - -Reasoning:""".strip(), - "cot_classify_jp": """ -まず、一歩一歩あなたの推論を書き出してください。単に正しい答えを最初に述べることを避けてください。次に、{choices}(引用符や句読点なし)から正しい答えに対応する1つの選択肢を単独の行に書きだしてください。最後に、答えだけを新しい行に繰り返してください。 - -推論: - """.strip(), -} - - -def choice_to_str(choice_strings: Iterable[str]) -> str: - """Return a string of choices, e.g. '"Yes" or "No" or "Maybe"'.""" - return " or ".join(f'"{choice}"' for choice in choice_strings) - - -def get_choice(text: str, eval_type: str, match_fn: Callable, choice_strings: Iterable[str]) -> str: - """Clean the answer string to a choice string to one of choice_strings. Return '__invalid__.' 
if no match.""" - lines = text.strip().split("\n") - if eval_type.startswith("cot_classify"): - lines = lines[::-1] # reverse lines - for line in lines: - line = line.strip() - line = "".join(c for c in line if c not in string.punctuation) - if not line: - continue - for choice in choice_strings: - if match_fn(line, choice): - return choice - return INVALID_STR - - -def expand_args_dict(args_dict): - """Expand a dict of dicts, with namings. - - args_dict = { - "a": {"a1": 1, "a2": 2}, - "b": {"b1": 3, "b2": 4}, - } - expand_args_dict(args_dict) = { - "a=a1:b=b1": {"a": ("a1", 1), "b": ("b1", 3)}, - "a=a1:b=b2": {"a": ("a1", 1), "b": ("b2", 4)}, - ...} - """ - args_dict = {k: list(v.items()) for k, v in args_dict.items()} - keys = list(args_dict.keys()) - values = list(args_dict.values()) - new_values = [dict(zip(keys, v)) for v in itertools.product(*values)] - new_names = [":".join([f"{k}={v[0]}" for k, v in sorted(d.items())]) for d in new_values] - return dict(zip(new_names, new_values)) +from evals.elsuite.modelgraded.base import ModelGradedSpec +from evals.elsuite.modelgraded.classify_utils import ( + CHOICE_KEY, + INVALID_STR, + MATCH_FNS, + concat_n_completions, + get_choice, +) +from evals.elsuite.utils import PromptFn, format_prompt, scrub_formatting_from_prompt class ModelBasedClassify(evals.Eval): @@ -133,89 +71,15 @@ def __init__( else: self.eval_modelspec = ModelSpec(name=eval_model, model=eval_model, is_chat=True) - """import prompt and set attributes""" - modelgraded_specs = self.registry.get_modelgraded_spec(modelgraded_spec) - modelgraded_specs = copy.deepcopy(modelgraded_specs) # since pop() is used - - # 'choice_strings' is a list of strings that specifies the possible choices - self.choice_strings = modelgraded_specs.pop("choice_strings") - if self.choice_strings == "from_n": - self.choice_strings = [str(i + 1) for i in range(self.multicomp_n)] - elif self.choice_strings == "from_n_abc": - self.choice_strings = [string.ascii_lowercase[i % 26] for i in range(self.multicomp_n)] - elif self.choice_strings == "from_n_ABC": - self.choice_strings = [string.ascii_uppercase[i % 26] for i in range(self.multicomp_n)] - # make sure each choice doesn't contain any punctuation - for s in self.choice_strings: - assert not any(c in s for c in string.punctuation), f"{s} contains punctuation" - # (optional) 'choice_scores' is a dict that specifies the score for each choice string - # if 'choice_scores' is specified, 'scores/' are computed and added to metrics - self.choice_scores = modelgraded_specs.pop("choice_scores", {}) - if self.choice_scores == "from_strings": - self.choice_scores = {c: float(c) for c in self.choice_strings} - assert all( - isinstance(v, (int, float)) for v in self.choice_scores.values() - ), f"choice_scores must be a dict of floats, not {self.choice_scores}" - - # (optional) 'eval_type' is a string that specifies the type of classification algorithm - # - "classify": only answer - # - "cot_classify": reason then answer (chain-of-thought) <- most recommended - # - "classify_cot": answer then reason (explanation) - # if 'eval_type' is not supplied from modelgraded_specs, then it must be supplied as an argument. - # - Importantly, it also assumes the answer prompt needs to be appended to the prompt. 
- self.eval_type = modelgraded_specs.pop("eval_type", None) - if not self.eval_type: - append_answer_prompt = True # append answer prompt to prompt - assert eval_type, "eval_type must be specified, in modelgraded_spec or as an argument" - self.eval_type = eval_type - else: - assert ( - not eval_type - ), f"eval_type must be unspecified, if it is specified in modelgraded_spec" - append_answer_prompt = False - - # 'prompt' is a string that specifies the model-graded evaluation - prompt = modelgraded_specs.pop("prompt") - assert isinstance(prompt, str), f"prompt must be a string, not {type(prompt)}" - if append_answer_prompt: - prompt += "\n\n" + ANSWER_PROMPTS[self.eval_type].format( - choices=choice_to_str(self.choice_strings) - ) - self.prompt = [{"role": "user", "content": prompt}] - - # 'input_outputs' is a dict that specifies the input and output keys in the sample - # output key is the model's raw response to input key. These are used for filling 'prompt' template. - self.input_outputs = modelgraded_specs.pop("input_outputs") - assert isinstance( - self.input_outputs, dict - ), f"input_outputs must be a dict, not {type(self.input_outputs)}" - - # (optional) 'args' is a dict of dicts that specifies additional arguments for 'prompt' - # each value in 'args_dict' essentially defines a separate modelgraded classification eval and has own metrics! - # if 'modelgraded_spec_args' is specified in eval YAML, it is merged with 'args_dict' - self.args_dict = modelgraded_specs.pop("args", {}) - self.args_dict.update(modelgraded_spec_args or {}) - if self.args_dict: - self.expanded_args_dict = expand_args_dict(self.args_dict) - else: - self.expanded_args_dict = {} - - # (optional) 'completion_sample_templates' - # each key must be one of 'input_outputs'.values(). If 'multicomp_n' > 1, this template is filled 'multicomp_n' times - # and the concatenated result is passed to 'prompt' template. - self.completion_sample_templates = modelgraded_specs.pop("completion_sample_templates", {}) - assert all( - k in self.input_outputs.values() for k in self.completion_sample_templates - ), f"all {self.completion_sample_templates.keys()} must be in {self.input_outputs.values()}, " - if self.multicomp_n > 1: - assert ( - self.completion_sample_templates - ), "completion_sample_templates must be specified if multicomp_n > 1" - - # since we accept optional args, we need to check that all args are used - for key in ("key", "group"): - modelgraded_specs.pop(key, None) - assert not modelgraded_specs, f"Unused args: {modelgraded_specs}. Typo in YAML?" + spec_kwargs = {"multicomp_n": self.multicomp_n} + if eval_type: + spec_kwargs["eval_type"] = eval_type + spec_kwargs["append_answer_prompt"] = True # append answer prompt to prompt + if modelgraded_spec_args: + spec_kwargs["args"] = modelgraded_spec_args + self.mg: ModelGradedSpec = self.registry.get_modelgraded_spec( + modelgraded_spec, **spec_kwargs + ) def eval_sample(self, test_sample: dict, rng: Random) -> None: """Evaluate a single sample. 
@@ -229,22 +93,21 @@ def eval_sample(self, test_sample: dict, rng: Random) -> None: completions = {} if self.metaeval: # assert outputs exist in the data - for v in self.input_outputs.values(): + for v in self.mg.input_outputs.values(): assert v in test_sample, f"Missing output '{v}' in sample {test_sample.keys()}" completions[v] = test_sample[v] # remove outputs from the data test_sample = { - k: v for k, v in test_sample.items() if k not in list(self.input_outputs.values()) + k: v for k, v in test_sample.items() if k not in list(self.mg.input_outputs.values()) } - for k in self.input_outputs: + for k in self.mg.input_outputs: test_sample[k] = scrub_formatting_from_prompt(test_sample[k]) if not self.metaeval: try: - for k, v in self.input_outputs.items(): - if self.multicomp_n > 1 and v in self.completion_sample_templates: - completion = "" - completion_i_template = self.completion_sample_templates[v] + for k, v in self.mg.input_outputs.items(): + if self.multicomp_n > 1 and v in self.mg.completion_sample_templates: + completion_i_s = [] for i in range(self.multicomp_n): if len(self.model_specs.completions) > 1: # use a separate model for each completion @@ -259,14 +122,10 @@ def eval_sample(self, test_sample: dict, rng: Random) -> None: temperature=self.multicomp_temperature, ) completion_i, _ = get_input_completion() - completion += format_necessary( - completion_i_template, - i=i + 1, - i_abc=string.ascii_lowercase[i % 26], - i_ABC=string.ascii_uppercase[i % 26], - output=completion_i, - n=self.multicomp_n, - ) + completion_i_s.append(completion_i) + completion = concat_n_completions( + completion_i_s, self.mg.completion_sample_templates[v] + ) else: get_input_completion = PromptFn( test_sample[k], @@ -279,39 +138,41 @@ def eval_sample(self, test_sample: dict, rng: Random) -> None: self.invalid_request_during_completion += 1 return - try: - metrics = {} + metrics = {} + if self.mg.expanded_args_dict and len(self.mg.expanded_args_dict) > 1: + args_dict = self.mg.expanded_args_dict + elif self.mg.expanded_args_dict and len(self.mg.expanded_args_dict) == 1: + # if there is only one combination, don't bother with the metric name + args_dict = {CHOICE_KEY: v for v in self.mg.expanded_args_dict.values()} + else: + args_dict = {CHOICE_KEY: {}} + for metric, args in args_dict.items(): + args = {k: v[1] for k, v in args.items()} + prompt = format_prompt(self.mg.prompt, **args, **completions, **test_sample) evaluate = PromptFn( - self.prompt, + prompt, model_spec=self.eval_modelspec, max_tokens=self.max_tokens, ) - eval_kwargs = dict(**completions, **test_sample) - if self.expanded_args_dict and len(self.expanded_args_dict) > 1: - args_dict = self.expanded_args_dict - elif self.expanded_args_dict and len(self.expanded_args_dict) == 1: - # if there is only one combination, don't bother with the metric name - args_dict = {CHOICE_KEY: v for v in self.expanded_args_dict.values()} - else: - args_dict = {CHOICE_KEY: {}} - for metric, args in args_dict.items(): - args = {k: v[1] for k, v in args.items()} - evaluation, _ = evaluate(**args, **eval_kwargs) - choice = get_choice(evaluation, self.eval_type, self.match_fn, self.choice_strings) - if choice == INVALID_STR: - logging.warn( - f"Choices {self.choice_strings} not parsable for {self.eval_type}: {evaluation}" - ) - metrics[metric] = choice - if self.metaeval: - assert ( - metric in test_sample - ), f"Missing label for metric '{metric}' in sample {test_sample.keys()}" - metrics[metric + "_metascore"] = choice == test_sample[metric] - - except 
openai.error.InvalidRequestError: - self.invalid_request_during_evaluation += 1 - return + try: + evaluation, _ = evaluate(skip_format=True) + except openai.error.InvalidRequestError: + logging.warn(f"Invalid request during evaluation: {prompt}") + self.invalid_request_during_evaluation += 1 + return + choice = get_choice( + evaluation, self.mg.eval_type, self.match_fn, self.mg.choice_strings + ) + if choice == INVALID_STR: + logging.warn( + f"Choices {self.mg.choice_strings} not parsable for {self.mg.eval_type}: {evaluation}" + ) + metrics[metric] = choice + if self.metaeval: + assert ( + metric in test_sample + ), f"Missing label for metric '{metric}' in sample {test_sample.keys()}" + metrics[metric + "_metascore"] = choice == test_sample[metric] evals.record.record_metrics(**metrics) @@ -321,21 +182,26 @@ def run(self, recorder): samples = evals.get_jsonl(self.samples_jsonl) self.eval_all_samples(recorder, samples) + record_metrics = {} + record_metrics["invalid_request_during_completion"] = self.invalid_request_during_completion + record_metrics["invalid_request_during_evaluation"] = self.invalid_request_during_evaluation + all_sample_metrics = recorder.get_metrics() + if not all_sample_metrics: + return record_metrics - record_metrics = {} - if self.expanded_args_dict and len(self.expanded_args_dict) > 1: - metrics = sorted(self.expanded_args_dict) + if self.mg.expanded_args_dict and len(self.mg.expanded_args_dict) > 1: + metrics = sorted(self.mg.expanded_args_dict) else: metrics = [CHOICE_KEY] for metric in metrics: chosen = [m[metric] for m in all_sample_metrics if metric in m] # if there is a best choice, compute the score - if self.choice_scores: + if self.mg.choice_scores: # assumption: each INVALID_STR contributes the lowest score - lowest_score = min(self.choice_scores.values()) + lowest_score = min(self.mg.choice_scores.values()) scores = [ - self.choice_scores[choice] if choice != INVALID_STR else lowest_score + self.mg.choice_scores[choice] if choice != INVALID_STR else lowest_score for choice in chosen ] record_metrics[f"score/{metric}"] = sum(scores) / len(all_sample_metrics) @@ -349,7 +215,4 @@ def run(self, recorder): metascores = [m[metric + "_metascore"] for m in all_sample_metrics if metric in m] record_metrics[f"metascore/{metric}"] = sum(metascores) / len(all_sample_metrics) - record_metrics["invalid_request_during_completion"] = self.invalid_request_during_completion - record_metrics["invalid_request_during_evaluation"] = self.invalid_request_during_evaluation - return record_metrics diff --git a/evals/elsuite/modelgraded/classify_utils.py b/evals/elsuite/modelgraded/classify_utils.py new file mode 100644 index 0000000000..9860084b91 --- /dev/null +++ b/evals/elsuite/modelgraded/classify_utils.py @@ -0,0 +1,90 @@ +import itertools +import string +from typing import Callable, Iterable + +from evals.elsuite.utils import format_necessary + +INVALID_STR = "__invalid__" +CHOICE_KEY = "choice" + + +ANSWER_PROMPTS = { + # e.g. "Yes" + "classify": "Answer the question by printing only a single choice from {choices} (without quotes or punctuation) corresponding to the correct answer with no other text.".strip(), + # e.g. "Yes\n The reasons are: ..." + "classify_cot": "First, answer by printing a single choice from {choices} (without quotes or punctuation) corresponding to the correct answer. Then, from the next line, explain your reasonings step by step.".strip(), + # e.g. "Let's think step by step. 
...\nYes" + "cot_classify": """ +First, write out in a step by step manner your reasoning to be sure that your conclusion is correct. Avoid simply stating the correct answer at the outset. Then print only a single choice from {choices} (without quotes or punctuation) on its own line corresponding to the correct answer. At the end, repeat just the answer by itself on a new line. + +Reasoning:""".strip(), + "cot_classify_jp": """ +まず、一歩一歩あなたの推論を書き出してください。単に正しい答えを最初に述べることを避けてください。次に、{choices}(引用符や句読点なし)から正しい答えに対応する1つの選択肢を単独の行に書きだしてください。最後に、答えだけを新しい行に繰り返してください。 + +推論: + """.strip(), +} +MATCH_FNS = { + "include": lambda x, y: float(x in y), + "exact": lambda x, y: float(x == y), + "endswith": lambda x, y: x.endswith(y), + "starts_or_endswith": lambda x, y: x.startswith(y) or x.endswith(y), +} + + +def choice_to_str(choice_strings: Iterable[str]) -> str: + """Return a string of choices, e.g. '"Yes" or "No" or "Maybe"'.""" + return " or ".join(f'"{choice}"' for choice in choice_strings) + + +def get_choice(text: str, eval_type: str, match_fn: Callable, choice_strings: Iterable[str]) -> str: + """Clean the answer string to a choice string to one of choice_strings. Return '__invalid__.' if no match.""" + lines = text.strip().split("\n") + if eval_type.startswith("cot_classify"): + lines = lines[::-1] # reverse lines + for line in lines: + line = line.strip() + line = "".join(c for c in line if c not in string.punctuation) + if not line: + continue + for choice in choice_strings: + if match_fn(line, choice): + return choice + return INVALID_STR + + +def concat_n_completions(completions: Iterable[str], template_i: str) -> str: + """Concatenate n completions into a single text string.""" + completion = "" + for i, completion_i in enumerate(completions): + completion += format_necessary( + template_i, + i=i + 1, + i_abc=string.ascii_lowercase[i % 26], + i_ABC=string.ascii_uppercase[i % 26], + output=completion_i, + n=len(completions), + ) + return completion.strip() + + +def expand_args_dict(args_dict): + """Expand a dict of dicts, with namings. 
+ + args_dict = { + "a": {"a1": 1, "a2": 2}, + "b": {"b1": 3, "b2": 4}, + } + expand_args_dict(args_dict) = { + "a=a1:b=b1": {"a": ("a1", 1), "b": ("b1", 3)}, + "a=a1:b=b2": {"a": ("a1", 1), "b": ("b2", 4)}, + ...} + """ + if not args_dict: + return {} + args_dict = {k: list(v.items()) for k, v in args_dict.items()} + keys = list(args_dict.keys()) + values = list(args_dict.values()) + new_values = [dict(zip(keys, v)) for v in itertools.product(*values)] + new_names = [":".join([f"{k}={v[0]}" for k, v in sorted(d.items())]) for d in new_values] + return dict(zip(new_names, new_values)) diff --git a/evals/elsuite/utils.py b/evals/elsuite/utils.py index 615b51ac1c..5a83cd2017 100644 --- a/evals/elsuite/utils.py +++ b/evals/elsuite/utils.py @@ -4,7 +4,7 @@ from collections import Counter, defaultdict from evals.api import sample_freeform -from evals.prompt.base import chat_prompt_to_text_prompt, is_chat_prompt +from evals.prompt.base import OpenAICreatePrompt, chat_prompt_to_text_prompt, is_chat_prompt def get_answer(text, answer_prompt): @@ -93,11 +93,34 @@ def scrub_formatting_from_prompt(prompt): def format_necessary(template: str, **kwargs: dict[str, str]) -> str: """Format a template string with only necessary kwargs.""" keys = [k[1] for k in string.Formatter().parse(template) if k[1]] - assert all(k in kwargs for k in keys), f"Required: {keys}, got: {sorted(kwargs)}" + assert all( + k in kwargs for k in keys + ), f"Required: {keys}, got: {sorted(kwargs)}.\nTemplate:\n{template}" cur_keys = {k: kwargs[k] for k in keys} return template.format(**cur_keys) +def format_prompt(prompt: OpenAICreatePrompt, **kwargs: dict[str, str]) -> OpenAICreatePrompt: + """Format a prompt with only necessary kwargs.""" + # if any input kwargs is chat prompt, convert to text prompt + kwargs = { + k: chat_prompt_to_text_prompt(v, for_completion=False) if is_chat_prompt(v) else v + for k, v in kwargs.items() + } + if is_chat_prompt(prompt): + new_prompt = [] + for msg in prompt: + formatted_msg = copy.copy(msg) + if "content" in formatted_msg: + formatted_msg["content"] = format_necessary(formatted_msg["content"], **kwargs) + new_prompt.append(formatted_msg) + prompt = new_prompt + else: + # Prompt is a string + prompt = format_necessary(prompt, **kwargs) + return prompt + + class PromptFn: """Wrap calls to model with prompt""" @@ -108,25 +131,10 @@ def __init__(self, prompt, model_spec, max_tokens, temperature=0, completion_kwa self.temperature = temperature self.completion_kwargs = completion_kwargs or {} - def __call__(self, **kwargs): - # if any input kwargs is chat prompt, convert to text prompt - kwargs = { - k: chat_prompt_to_text_prompt(v, render_for_completion=False) - if is_chat_prompt(v) - else v - for k, v in kwargs.items() - } - if is_chat_prompt(self.prompt): - prompt = [] - for msg in self.prompt: - formatted_msg = copy.copy(msg) - if "content" in formatted_msg: - formatted_msg["content"] = format_necessary(formatted_msg["content"], **kwargs) - prompt.append(formatted_msg) - else: - # Prompt is a string - prompt = format_necessary(self.prompt, **kwargs) - + def __call__(self, skip_format: bool = False, **kwargs): + prompt = self.prompt + if not skip_format: + prompt = format_prompt(prompt, **kwargs) completion = sample_freeform( self.model_spec, prompt, diff --git a/evals/prompt/base.py b/evals/prompt/base.py index 71946a2e4f..7c2aa3be04 100644 --- a/evals/prompt/base.py +++ b/evals/prompt/base.py @@ -19,9 +19,7 @@ OpenAICreateChatPrompt = List[OpenAIChatMessage] # A chat log is a list of messages 
-def chat_prompt_to_text_prompt( - prompt: OpenAICreateChatPrompt, render_for_completion: bool = True -) -> str: +def chat_prompt_to_text_prompt(prompt: OpenAICreateChatPrompt, for_completion: bool = True) -> str: """ Render a chat prompt as a text prompt. User and assistant messages are separated by newlines and prefixed with "User: " and "Assistant: ", respectively, unless there is only one message. @@ -46,7 +44,7 @@ def chat_prompt_to_text_prompt( prefix = chat_to_prefixes.get(role, role.capitalize() + ": ") content = msg["content"] text += f"{prefix}{content}\n" - if render_for_completion: + if for_completion: text += "Assistant: " return text.lstrip() diff --git a/evals/registry.py b/evals/registry.py index f55bd65d61..89ce84dc9f 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -4,6 +4,7 @@ By convention, every eval name should start with {base_eval}.{split}. """ +import copy import difflib import functools import logging @@ -16,6 +17,7 @@ import yaml from evals.base import BaseEvalSpec, EvalSetSpec, EvalSpec, ModelSpec +from evals.elsuite.modelgraded.base import ModelGradedSpec from evals.utils.misc import make_object logger = logging.getLogger(__name__) @@ -33,8 +35,14 @@ def make_callable(self, spec): def get_class(self, spec: dict) -> Any: return make_object(spec.cls, **(spec.args if spec.args else {})) - def _dereference(self, name: str, d: dict, object: str, type: Type) -> dict: + def _dereference(self, name: str, d: dict, object: str, type: Type, **kwargs: dict) -> dict: if not name in d: + logger.warning( + ( + f"{object} '{name}' not found. " + f"Closest matches: {difflib.get_close_matches(name, d.keys(), n=5)}" + ) + ) return None def get_alias(): @@ -53,21 +61,22 @@ def get_alias(): name = alias spec = d[name] + if kwargs: + spec = copy.deepcopy(spec) + spec.update(kwargs) try: return type(**spec) except TypeError as e: - raise TypeError(f"Error while processing {object} {name}: {e}") + raise TypeError(f"Error while processing {object} '{name}': {e}") def get_model(self, name: str) -> ModelSpec: return self._dereference(name, self._models, "model", ModelSpec) - def get_modelgraded_spec(self, name: str) -> dict[str, Any]: - assert name in self._modelgraded_specs, ( - f"Modelgraded spec {name} not found. " - f"Closest matches: {difflib.get_close_matches(name, self._modelgraded_specs.keys(), n=5)}" + def get_modelgraded_spec(self, name: str, **kwargs: dict) -> dict[str, Any]: + return self._dereference( + name, self._modelgraded_specs, "modelgraded spec", ModelGradedSpec, **kwargs ) - return self._modelgraded_specs[name] def get_eval(self, name: str) -> EvalSpec: return self._dereference(name, self._evals, "eval", EvalSpec) @@ -177,4 +186,5 @@ def _modelgraded_specs(self): def _models(self): return self._load_registry([p / "models" for p in self._registry_paths]) + registry = Registry() diff --git a/evals/registry/eval_sets/test-modelgraded.yaml b/evals/registry/eval_sets/test-modelgraded.yaml index a86552b7ad..0a87046604 100644 --- a/evals/registry/eval_sets/test-modelgraded.yaml +++ b/evals/registry/eval_sets/test-modelgraded.yaml @@ -11,5 +11,4 @@ test-modelgraded: - joke-animals-vs-fruits - rap-people-vs-people - rap-animals-vs-fruits - - rap-people-vs-fruits - - mg-humor-people_jp + - rap-people-vs-fruits \ No newline at end of file