Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[evals] moved modelgraded specs to registry #392

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[evals] moved modelgraded specs to registry
  • Loading branch information
rlbayes committed Mar 21, 2023
commit 8e7f28128cb53d638124c120335c4adaf3ae1c25
8 changes: 7 additions & 1 deletion evals/cli/oaieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,13 @@ def to_number(x):
extra_eval_params = parse_extra_eval_params(args.extra_eval_params)

eval_class = registry.get_class(eval_spec)
eval = eval_class(model_specs=model_specs, seed=args.seed, name=eval_name, **extra_eval_params)
eval = eval_class(
model_specs=model_specs,
seed=args.seed,
name=eval_name,
registry=registry,
**extra_eval_params,
)
result = eval.run(recorder)
recorder.record_final_report(result)

Expand Down
9 changes: 2 additions & 7 deletions evals/elsuite/modelgraded/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@
import evals
import evals.record
from evals.base import ModelSpec
from evals.elsuite.utils import (
PromptFn,
format_necessary,
load_modelgraded_specs,
scrub_formatting_from_prompt,
)
from evals.elsuite.utils import PromptFn, format_necessary, scrub_formatting_from_prompt

INVALID_STR = "__invalid__"
CHOICE_KEY = "choice"
Expand Down Expand Up @@ -135,7 +130,7 @@ def __init__(
)

"""import prompt and set attributes"""
modelgraded_specs = load_modelgraded_specs(modelgraded_spec_file)
modelgraded_specs = self.registry.get_modelgraded_spec(modelgraded_spec_file)

# 'choice_strings' is a list of strings that specifies the possible choices
self.choice_strings = modelgraded_specs.pop("choice_strings")
Expand Down
9 changes: 0 additions & 9 deletions evals/elsuite/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
import copy
import os
import re
import string
from collections import Counter, defaultdict

import yaml

from evals.api import sample_freeform
from evals.prompt.base import chat_prompt_to_text_prompt, is_chat_prompt


def load_modelgraded_specs(spec_file: str) -> str:
current_dir = os.path.dirname(os.path.abspath(__file__))
yaml_path = os.path.join(current_dir, "../registry/modelgraded", f"{spec_file}.yaml")
return yaml.load(open(yaml_path, "r"), Loader=yaml.FullLoader)


def get_answer(text, answer_prompt):
idx = text.rfind(answer_prompt)
if idx == -1:
Expand Down
9 changes: 6 additions & 3 deletions evals/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
"""
import abc
import asyncio
import concurrent.futures
import logging
import os
import random
import concurrent.futures
from multiprocessing.pool import ThreadPool
from typing import Any, Awaitable, Callable, Dict, List, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

from tqdm import tqdm

from .base import ModelSpec, ModelSpecs
from .record import Recorder, RecorderBase
from .record import RecorderBase
from .registry import Registry

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
model_specs: ModelSpecs,
seed: int = 20220722,
name: str = "no_name_eval.default",
registry: Optional[Registry] = None,
):
splits = name.split(".")
if len(splits) < 2:
Expand All @@ -61,6 +63,7 @@ def __init__(
self.model_specs = model_specs
self.seed = seed
self.name = name
self.registry = registry or Registry()

def eval_sample(self, sample: Any, rng: random.Random):
raise NotImplementedError()
Expand Down
37 changes: 37 additions & 0 deletions evals/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
By convention, every eval name should start with {base_eval}.{split}.
"""

import difflib
rlbayes marked this conversation as resolved.
Show resolved Hide resolved
import functools
import logging
import os
Expand Down Expand Up @@ -58,6 +59,16 @@ def get_alias():
except TypeError as e:
raise TypeError(f"Error while processing {object} {name}: {e}")

def get_modelgraded_spec(self, name: str) -> dict[str, Any]:
assert name in self._modelgraded_specs, (
f"Modelgraded spec {name} not found. "
f"Closest matches: {difflib.get_close_matches(name, self._modelgraded_specs.keys(), n=5)}"
)
path = self._modelgraded_specs[name]
with open(path, "r") as f:
spec = yaml.safe_load(f)
return spec

def get_eval(self, name: str) -> EvalSpec:
return self._dereference(name, self._evals, "eval", EvalSpec)

Expand Down Expand Up @@ -136,6 +147,10 @@ def _process_directory(self, registry, path):
self._process_file(registry, file)

def _load_registry(self, paths):
"""Load registry from a list of paths.

Each path or yaml specifies a dictionary of name -> spec.
"""
registry = {}
for path in paths:
logging.info(f"Loading registry from {path}")
Expand All @@ -146,6 +161,24 @@ def _load_registry(self, paths):
self._process_file(registry, path)
return registry

def _load_registry_paths_only(self, paths, prefixes: Sequence[str] = []):
rlbayes marked this conversation as resolved.
Show resolved Hide resolved
"""Load registry from a list of paths.

Each path or yaml specifies one registry entry.
"""
registry = {}
for path in paths:
for subpath in Path(path).glob("*"):
if os.path.isdir(subpath):
self._load_registry_paths_only(
[subpath], prefixes=prefixes + [os.path.basename(subpath)]
)
else:
name = ".".join(prefixes + [os.path.splitext(os.path.basename(subpath))[0]])
assert name not in registry, f"duplicate entry: {name} from {subpath}"
registry[name] = subpath
return registry

@functools.cached_property
def _eval_sets(self):
return self._load_registry([p / "eval_sets" for p in self._registry_paths])
Expand All @@ -154,5 +187,9 @@ def _eval_sets(self):
def _evals(self):
return self._load_registry([p / "evals" for p in self._registry_paths])

@functools.cached_property
def _modelgraded_specs(self):
return self._load_registry_paths_only([p / "modelgraded" for p in self._registry_paths])


registry = Registry()