Standardize metrics #1167 (Draft)
lintangsutawika wants to merge 26 commits into main from standardize_metrics.
The diff shown below is from 1 commit only.

Commits (26):
e7cd7d6  sample metrics that have both sample-wise and set-wise operations (lintangsutawika, Dec 19, 2023)
1d262a5  change how metrics are registered (lintangsutawika, Dec 19, 2023)
028f04c  loglikelihood and loglikelihood rolling modified (lintangsutawika, Dec 19, 2023)
6117c50  changed how metrics are calculated (lintangsutawika, Dec 19, 2023)
a808c66  Merge branch 'main' of https://github.com/EleutherAI/lm-evaluation-ha… (lintangsutawika, Dec 19, 2023)
c6a9158  update (lintangsutawika, Dec 27, 2023)
4d49dd0  aggregation to compute_metric (lintangsutawika, Dec 28, 2023)
9d6bc92  aggregation to compute_metric (lintangsutawika, Dec 28, 2023)
3888193  simplify registry (lintangsutawika, Dec 28, 2023)
039832e  removed passthrough fn (lintangsutawika, Dec 28, 2023)
e5b245c  remove aggregation (lintangsutawika, Dec 28, 2023)
20c10df  kwargs are added to metric_fn through partial at the beginning (lintangsutawika, Dec 28, 2023)
6a336b1  use HFEvaluateAdaptor for hf metrics (lintangsutawika, Dec 28, 2023)
150f11f  revert to just load metric_fn (lintangsutawika, Dec 28, 2023)
99ce4ef  process hf evaluate metrics (lintangsutawika, Dec 28, 2023)
439dca5  list tuple for string based multigpu collection (lintangsutawika, Dec 29, 2023)
aaf64aa  readded suport for aggregation (lintangsutawika, Jan 2, 2024)
787b23f  readd aggregation (lintangsutawika, Jan 2, 2024)
703e0d5  adjusted aggregation config (lintangsutawika, Jan 2, 2024)
2a573a1  adjust to be backwards compatible (lintangsutawika, Jan 2, 2024)
2054c2e  revert (lintangsutawika, Jan 2, 2024)
dfb4183  revert (lintangsutawika, Jan 2, 2024)
cda25fe  Merge branch 'main' into standardize_metrics (lintangsutawika, Jan 2, 2024)
470fb31  resolved git conflict (lintangsutawika, Jan 2, 2024)
dfb036b  resolved again (lintangsutawika, Jan 2, 2024)
de46fb9  reformat (lintangsutawika, Jan 2, 2024)
Commit shown: 6a336b154f62018a1de075f24b2e4762f02c1b15 (use HFEvaluateAdaptor for hf metrics)
lintangsutawika committed Dec 28, 2023
lm_eval/api/metrics.py (18 changes: 16 additions, 2 deletions)
@@ -159,15 +159,29 @@ def acc_mutual_info_fn(items):
     return mean(items)
 
 
-exact_match = evaluate.load("exact_match")
+class HFEvaluateAdaptor:
+    def __init__(self, *metric_args, **kwargs):
+
+        metric_object = evaluate.load(*metric_args)
+        self.hf_evaluate_fn = partial(metric_object, **kwargs)
+
+    def __call__(self, items):
+        refs = list(zip(*items))[0]
+        preds = list(zip(*items))[1]
+
+        return self.hf_evaluate_fn(
+            references=refs,
+            predictions=preds
+        )
+
+exact_match = evaluate.load("exact_match")
 
 @register_metric(
     metric="exact_match",
     higher_is_better=True,
     output_type="generate_until",
 )
-def exact_match_fn(**kwargs):
+def hf_evaluate_fn(**kwargs):
     return exact_match.compute(**kwargs)
 
 
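For orientation, here is a minimal, self-contained sketch of the adaptor pattern this commit introduces: wrap an HF Evaluate module once (binding any metric-specific kwargs), then call the wrapper over a list of (reference, prediction) pairs to get a set-wise score. The class name is hypothetical and, unlike the diff above, it wraps metric_object.compute (the documented evaluate API) rather than the module object itself.

```python
# Hypothetical sketch of the adaptor pattern, not the PR's exact code.
from functools import partial

import evaluate


class HFEvaluateAdaptorSketch:
    def __init__(self, metric_name, **kwargs):
        metric_object = evaluate.load(metric_name)
        # Bind metric-specific kwargs (e.g. ignore_case) once, up front.
        self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)

    def __call__(self, items):
        # items: list of (reference, prediction) tuples collected per document.
        refs, preds = zip(*items)
        return self.hf_evaluate_fn(references=list(refs), predictions=list(preds))


# Set-wise exact match over three examples:
em = HFEvaluateAdaptorSketch("exact_match", ignore_case=True)
print(em([("Paris", "paris"), ("42", "42"), ("cat", "dog")]))  # {'exact_match': 0.666...}
```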
lm_eval/api/registry.py (8 changes: 4 additions, 4 deletions)
@@ -1,7 +1,7 @@
 import os
 import evaluate
 from lm_eval.api.model import LM
-
+from lm_eval.api.metrics import HFEvaluateAdaptor
 import logging
 
 eval_logger = logging.getLogger("lm-eval")
@@ -115,7 +115,7 @@ def decorate(fn):
     return decorate
 
 
-def get_metric(name, hf_evaluate_metric=False):
+def get_metric(name, hf_evaluate_metric=False, **kwargs):
 
     if not hf_evaluate_metric:
         if name in METRIC_FUNCTION_REGISTRY:
@@ -126,8 +126,8 @@ def get_metric(name, hf_evaluate_metric=False):
            )
 
    try:
-        metric_object = evaluate.load(name)
-        return metric_object.compute
+        from lm_eval.metrics import HFEvaluateAdaptor
+        return HFEvaluateAdaptor(name, **kwargs)
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
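With this change, get_metric forwards metric kwargs into the adaptor at load time instead of leaving them to the task. A rough, simplified sketch of what the dispatch amounts to (hypothetical code, not the PR's):

```python
# Simplified, hypothetical sketch of the get_metric dispatch after this commit.
import evaluate

METRIC_FUNCTION_REGISTRY = {}  # name -> callable, populated by @register_metric


def get_metric_sketch(name, hf_evaluate_metric=False, **kwargs):
    # Prefer a natively registered metric unless the task explicitly
    # asked for an HF Evaluate metric.
    if not hf_evaluate_metric and name in METRIC_FUNCTION_REGISTRY:
        return METRIC_FUNCTION_REGISTRY[name]

    # Fall back to HF Evaluate; kwargs (e.g. ignore_case) are bound here
    # once, mirroring what HFEvaluateAdaptor does in the diff above.
    module = evaluate.load(name)

    def metric_fn(items):
        refs, preds = zip(*items)
        return module.compute(references=list(refs), predictions=list(preds), **kwargs)

    return metric_fn


# A task config requesting the HF "exact_match" metric with ignore_case:
em_fn = get_metric_sketch("exact_match", hf_evaluate_metric=True, ignore_case=True)
print(em_fn([("Paris", "paris"), ("42", "41")]))  # {'exact_match': 0.5}
```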
lm_eval/api/task.py (8 changes: 4 additions, 4 deletions)
@@ -17,7 +17,6 @@
 
 from typing import Union, List, Any, Tuple, Literal
 from collections.abc import Callable
-from functools import partial
 
 from lm_eval import utils
 from lm_eval.api import samplers
@@ -588,11 +587,11 @@ def __init__(
                     metric_name = metric_name.__name__
                 else:
                     metric_fn = get_metric(
-                        metric_name, hf_evaluate_metric
+                        metric_name, hf_evaluate_metric, **kwargs
                     )
 
                 self._metric_fn_kwargs[metric_name] = kwargs
-                self._metric_fn_list[metric_name] = partial(metric_fn, **kwargs) if kwargs != {} else metric_fn
+                self._metric_fn_list[metric_name] = metric_fn
 
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -1106,6 +1105,8 @@ def process_results(self, doc, results):
                 gold = type(result)(gold)
 
             for metric in self._metric_fn_list.keys():
+                result_dict[metric] = (gold, result)
+                continue
                 if self.multiple_target:
                     # in the case where we have multiple targets,
                     # return true if any are true
@@ -1141,7 +1142,6 @@ def process_results(self, doc, results):
                         result_score = self._metric_fn_list[metric](
                             references=[gold],
                             predictions=[result],
-                            **self._metric_fn_kwargs[metric],
                         )
                     except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                         result_score = self._metric_fn_list[metric]([gold, result])
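Taken together with the metrics.py change, process_results now defers scoring: for each document it records a (gold, result) pair per metric, and the metric function is applied once over all collected pairs. A toy, hypothetical sketch of that flow (collect_results and the lambda metric are illustrative, not from the PR):

```python
# Toy sketch of the deferred (set-wise) metric flow implied by this diff.
def process_results_sketch(gold, result):
    # Per document: do not score yet, just hand back the raw pair.
    return {"exact_match": (gold, result)}


def collect_results(per_doc_dicts, metric_fns):
    # After all documents are processed, run each metric once over its pairs.
    scores = {}
    for metric, fn in metric_fns.items():
        items = [d[metric] for d in per_doc_dicts]
        scores[metric] = fn(items)
    return scores


# End-to-end toy run with a trivial set-wise exact-match metric:
docs = [("paris", "paris"), ("42", "41"), ("cat", "cat")]
per_doc = [process_results_sketch(g, r) for g, r in docs]
metric_fns = {"exact_match": lambda items: sum(g == r for g, r in items) / len(items)}
print(collect_results(per_doc, metric_fns))  # {'exact_match': 0.666...}
```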