fix two bugs when run with qasper_bool and toxigen
AndyZwei committed Oct 19, 2023
1 parent ef33202 commit a007bac
Showing 2 changed files with 25 additions and 7 deletions.
29 changes: 23 additions & 6 deletions lm_eval/__main__.py
@@ -2,11 +2,10 @@
 import re
 import json
 import fnmatch
-import jsonlines
 import argparse
 import logging
 from pathlib import Path
 
+import numpy as np
 from lm_eval import evaluator, utils
 from lm_eval.api.registry import ALL_TASKS
 from lm_eval.logger import eval_logger, SPACING
@@ -15,6 +14,14 @@
 from typing import Union
 
 
+def _handle_non_serializable(o):
+    if isinstance(o, np.int64):
+        return int(o)
+    elif isinstance(o, set):
+        return list(o)
+    raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")
+
+
 def parse_eval_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
     parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
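
Editor's note: the `_handle_non_serializable` helper added above addresses JSON serialization failures; aggregated results can contain numpy scalars and sets, which the standard `json` encoder rejects. A minimal sketch of the failure and the fix, not part of the commit (the sample dict is illustrative; the helper is copied from the diff for self-containment):

    import json

    import numpy as np


    def _handle_non_serializable(o):
        # Copied from the diff above for self-containment.
        if isinstance(o, np.int64):
            return int(o)
        elif isinstance(o, set):
            return list(o)
        raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")


    results = {"acc": np.int64(1), "tasks": {"qasper_bool", "toxigen"}}

    # json.dumps(results) alone raises:
    #   TypeError: Object of type int64 is not JSON serializable
    print(json.dumps(results, indent=2, default=_handle_non_serializable))
    # acc stays a number; the set is emitted as a list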
@@ -103,6 +110,12 @@ def parse_eval_args() -> argparse.Namespace:
         default="INFO",
         help="Log error when tasks are not registered.",
     )
+    parser.add_argument(
+        "--huggingface_token",
+        type=str,
+        default=None,
+        help="Hugging Face token for downloading gated datasets such as toxigen; create one at https://huggingface.co/settings/tokens",
+    )
     return parser.parse_args()
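
Editor's note: with the flag in place, a gated task can be launched as, e.g., `python -m lm_eval --model hf --tasks toxigen --huggingface_token <token>` (other arguments elided); the default of `None` keeps the new login step in `cli_evaluate` below a no-op when the flag is omitted.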


@@ -119,7 +132,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             " --limit SHOULD ONLY BE USED FOR TESTING."
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )
+    if args.huggingface_token:
+        from huggingface_hub import login
+        login(token=args.huggingface_token)
     if args.include_path is not None:
         eval_logger.info(f"Including path: {args.include_path}")
         include_path(args.include_path)
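
Editor's note: the import is deferred into the branch, so `huggingface_hub` is only loaded when a token is actually supplied. `login()` validates the token and caches it, so the later `datasets` download of gated sets such as toxigen can authenticate without further plumbing.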
@@ -195,7 +211,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if results is not None:
         if args.log_samples:
             samples = results.pop("samples")
-        dumped = json.dumps(results, indent=2, default=lambda o: str(o))
+        dumped = json.dumps(results, indent=2, default=_handle_non_serializable)
         if args.show_config:
             print(dumped)
@@ -210,9 +226,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
                 re.sub("/|=", "__", args.model_args), task_name
             )
             filename = path.joinpath(f"{output_name}.jsonl")
-
-            with jsonlines.open(filename, "w") as f:
-                f.write_all(samples[task_name])
+            samples_dumped = json.dumps(
+                samples[task_name], indent=2, default=_handle_non_serializable
+            )
+            filename.open("w").write(samples_dumped)
 
             print(
                 f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
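
Editor's note: one behavioral consequence of dropping `jsonlines` is that the per-task samples file now holds a single pretty-printed JSON array rather than line-delimited records, even though the `.jsonl` extension is kept. A hedged sketch of reading it back (the filename is illustrative):

    import json
    from pathlib import Path

    # Before this commit: one JSON object per line, parsed line by line.
    # After: one indented JSON document, parsed with a single call.
    samples = json.loads(Path("pretrained__model_toxigen.jsonl").read_text())
    for doc in samples:
        print(doc)  # actual record keys depend on the task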
3 changes: 2 additions & 1 deletion lm_eval/tasks/__init__.py
@@ -15,7 +15,8 @@
 
 import logging
 
-eval_logger = logging.getLogger('lm-eval')
+eval_logger = logging.getLogger("lm-eval")
+
 
 def register_configurable_task(config: Dict[str, str]) -> int:
     SubClass = type(
