Skip to content

Commit

Permalink
[evals] minor refactoring cli (openai#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
rlbayes committed Mar 21, 2023
1 parent f118fca commit 3c718fc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
12 changes: 7 additions & 5 deletions evals/cli/oaieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import evals.base
import evals.record
from evals.base import ModelSpec, ModelSpecs
from evals.registry import registry
from evals.registry import Registry

logger = logging.getLogger(__name__)

Expand All @@ -24,7 +24,7 @@ def _purple(str):
return f"\033[1;35m{str}\033[0m"


def parse_args(args=sys.argv[1:]) -> argparse.Namespace:
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run evals through the API")
parser.add_argument("model", type=str, help="Name of a completion model.")
parser.add_argument("eval", type=str, help="Name of an eval. See registry.")
Expand All @@ -44,7 +44,7 @@ def parse_args(args=sys.argv[1:]) -> argparse.Namespace:
parser.add_argument("--local-run", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--dry-run", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--dry-run-logging", action=argparse.BooleanOptionalAction, default=True)
return parser.parse_args(args)
return parser


def n_ctx_from_model_name(model_name: str) -> Optional[int]:
Expand Down Expand Up @@ -122,7 +122,7 @@ def api_model_ids(self):
return [m["id"] for m in openai.Model.list()["data"]]


def run(args):
def run(args, registry: Optional[Registry] = None):
if args.debug:
logging.getLogger().setLevel(logging.DEBUG)

Expand All @@ -131,6 +131,7 @@ def run(args):
if args.max_samples is not None:
evals.eval.set_max_samples(args.max_samples)

registry = registry or Registry()
eval_spec = registry.get_eval(args.eval)
assert (
eval_spec is not None
Expand Down Expand Up @@ -224,7 +225,8 @@ def to_number(x):


def main():
args = parse_args()
parser = get_parser()
args = parser.parse_args(sys.argv[1:])
logging.basicConfig(
format="[%(asctime)s] [%(filename)s:%(lineno)d] %(message)s",
level=logging.INFO,
Expand Down
15 changes: 12 additions & 3 deletions evals/cli/oaievalset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import subprocess
from pathlib import Path
from typing import Optional

from evals.registry import Registry

Expand Down Expand Up @@ -41,7 +42,7 @@ def highlight(str: str) -> str:
return f"\033[1;32m>>> {str}\033[0m"


def main() -> None:
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run eval sets through the API")
parser.add_argument("model", type=str, help="Name of a completion model.")
parser.add_argument("eval_set", type=str, help="Name of eval set. See registry.")
Expand All @@ -57,9 +58,11 @@ def main() -> None:
default=True,
help="Exit if any oaieval command fails.",
)
args, unknown_args = parser.parse_known_args()
return parser

registry = Registry()

def run(args, unknown_args, registry: Optional[Registry] = None) -> None:
registry = registry or Registry()
commands: list[Task] = []
eval_set = registry.get_eval_set(args.eval_set)
for eval in registry.get_evals(eval_set.evals):
Expand Down Expand Up @@ -92,5 +95,11 @@ def main() -> None:
print(highlight("All done!"))


def main() -> None:
parser = get_parser()
args, unknown_args = parser.parse_known_args()
run(args, unknown_args)


if __name__ == "__main__":
main()
4 changes: 0 additions & 4 deletions evals/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@

DEFAULT_PATHS = [Path(__file__).parents[0].resolve() / "registry", Path.home() / ".evals"]

DEFAULT_SYSTEM_PATHS = [
Path(__file__).parents[0].resolve() / "registry",
]


class Registry:
def __init__(self, registry_paths: Sequence[Union[str, Path]] = DEFAULT_PATHS):
Expand Down

0 comments on commit 3c718fc

Please sign in to comment.