Standardize metrics #1167

Draft · wants to merge 26 commits into base: main

Changes from all commits (26 commits)
e7cd7d6  sample metrics that have both sample-wise and set-wise operations (lintangsutawika, Dec 19, 2023)
1d262a5  change how metrics are registered (lintangsutawika, Dec 19, 2023)
028f04c  loglikelihood and loglikelihood rolling modified (lintangsutawika, Dec 19, 2023)
6117c50  changed how metrics are calculated (lintangsutawika, Dec 19, 2023)
a808c66  Merge branch 'main' of https://github.com/EleutherAI/lm-evaluation-ha… (lintangsutawika, Dec 19, 2023)
c6a9158  update (lintangsutawika, Dec 27, 2023)
4d49dd0  aggregation to compute_metric (lintangsutawika, Dec 28, 2023)
9d6bc92  aggregation to compute_metric (lintangsutawika, Dec 28, 2023)
3888193  simplify registry (lintangsutawika, Dec 28, 2023)
039832e  removed passthrough fn (lintangsutawika, Dec 28, 2023)
e5b245c  remove aggregation (lintangsutawika, Dec 28, 2023)
20c10df  kwargs are added to metric_fn through partial at the beginning (lintangsutawika, Dec 28, 2023)
6a336b1  use HFEvaluateAdaptor for hf metrics (lintangsutawika, Dec 28, 2023)
150f11f  revert to just load metric_fn (lintangsutawika, Dec 28, 2023)
99ce4ef  process hf evaluate metrics (lintangsutawika, Dec 28, 2023)
439dca5  list tuple for string based multigpu collection (lintangsutawika, Dec 29, 2023)
aaf64aa  readded suport for aggregation (lintangsutawika, Jan 2, 2024)
787b23f  readd aggregation (lintangsutawika, Jan 2, 2024)
703e0d5  adjusted aggregation config (lintangsutawika, Jan 2, 2024)
2a573a1  adjust to be backwards compatible (lintangsutawika, Jan 2, 2024)
2054c2e  revert (lintangsutawika, Jan 2, 2024)
dfb4183  revert (lintangsutawika, Jan 2, 2024)
cda25fe  Merge branch 'main' into standardize_metrics (lintangsutawika, Jan 2, 2024)
470fb31  resolved git conflict (lintangsutawika, Jan 2, 2024)
dfb036b  resolved again (lintangsutawika, Jan 2, 2024)
de46fb9  reformat (lintangsutawika, Jan 2, 2024)
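The through-line of these commits shows up in the diff below: a metric no longer needs a register_aggregation entry paired with a no-op passthrough registered via register_metric; the aggregation itself becomes the registered function. A minimal, self-contained sketch of that decorator pattern follows. The dict-based registry and helper names here are illustrative stand-ins, not the harness's actual lm_eval.api.registry code; only the decorator shape (metric, higher_is_better, output_type) is taken from the diff.

# Illustrative registry; the real one lives in lm_eval.api.registry.
METRIC_REGISTRY = {}


def register_metric(metric, higher_is_better, output_type):
    # Accept a single metric name or a list of names, as the diff does
    # with metric=["acc", "acc_norm"].
    names = metric if isinstance(metric, list) else [metric]

    def decorate(fn):
        for name in names:
            METRIC_REGISTRY[name] = {
                "fn": fn,
                "higher_is_better": higher_is_better,
                "output_type": output_type,
            }
        return fn

    return decorate


@register_metric(
    metric=["acc", "acc_norm"],
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
)
def aggregate_acc_fn(items):
    # Mean over per-sample 0/1 accuracies, standing in for mean() in metrics.py.
    return sum(items) / len(items)


print(METRIC_REGISTRY["acc"]["fn"]([1, 0, 1, 1]))  # 0.75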
lm_eval/api/metrics.py: 180 changes (53 additions, 127 deletions)
@@ -25,24 +25,44 @@ def median(arr):
    return arr[len(arr) // 2]


# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
@register_aggregation("weighted_mean")
def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
)
def perplexity(items):
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
@register_metric(
metric=["word_perplexity", "byte_perplexity"],
higher_is_better=False,
output_type="loglikelihood_rolling",
)
def weighted_perplexity(items): # This is a passthrough function
return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
@register_metric(
metric="bits_per_byte",
higher_is_better=False,
output_type="loglikelihood_rolling",
)
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)


@register_aggregation("f1")
@register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
)
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
Expand All @@ -52,16 +72,23 @@ def f1_score(items):
return np.max(fscore)


@register_aggregation("matthews_corrcoef")
@register_metric(
metric="mcc",
higher_is_better=True,
output_type="multiple_choice",
)
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
# print(preds)
return sklearn.metrics.matthews_corrcoef(golds, preds)


@register_aggregation("bleu")
@register_metric(
metric="bleu",
higher_is_better=True,
output_type="generate_until",
)
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
Expand All @@ -79,7 +106,11 @@ def bleu(items):
return sacrebleu.corpus_bleu(preds, refs).score


@register_aggregation("chrf")
@register_metric(
metric="chrf",
higher_is_better=True,
output_type="generate_until",
)
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Expand All @@ -94,7 +125,11 @@ def chrf(items):
return sacrebleu.corpus_chrf(preds, refs).score


@register_aggregation("ter")
@register_metric(
metric="ter",
higher_is_better=True,
output_type="generate_until",
)
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
Expand All @@ -111,33 +146,21 @@ def ter(items):


@register_metric(
    metric="acc",
    metric=["acc", "acc_norm"],
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):  # This is a passthrough function
    return items
def aggregate_acc_fn(items):
    return mean(items)


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items
def acc_mutual_info_fn(items):
    return mean(items)


exact_match = evaluate.load("exact_match")
@@ -147,52 +170,11 @@ def acc_mutual_info_fn(items):  # This is a passthrough function
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
def hf_evaluate_fn(**kwargs):
    return exact_match.compute(**kwargs)


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items


def pop_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
@@ -207,61 +189,10 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))


@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
@@ -309,11 +240,6 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(scores_for_ground_truths)


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)
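
As a sanity check on the perplexity-family aggregations kept above, here is a small self-contained example (toy numbers, not drawn from any task) that reproduces the formulas used by perplexity(), weighted_perplexity(), and bits_per_byte() in this file:

import math

loglikelihoods = [-1.2, -0.7, -2.0]        # per-document loglikelihoods
rolling = [(-35.0, 12), (-50.0, 20)]       # (loglikelihood, word or byte count) pairs

# perplexity: exp of the negative mean loglikelihood
perplexity = math.exp(-sum(loglikelihoods) / len(loglikelihoods))

# weighted mean: sum of loglikelihoods over sum of weights
lls, weights = zip(*rolling)
weighted = sum(lls) / sum(weights)

weighted_perplexity = math.exp(-weighted)  # word_perplexity / byte_perplexity
bits_per_byte = -weighted / math.log(2)    # same ratio, converted to bits

print(perplexity, weighted_perplexity, bits_per_byte)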

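
The corpus-level scorers that bleu(), chrf(), and the exact_match registration delegate to can also be exercised on their own. A quick usage sketch with toy strings, assuming sacrebleu and the Hugging Face evaluate package are installed:

import evaluate
import sacrebleu

preds = ["the cat sat on the mat", "hello there"]
refs = [["the cat sat on the mat", "hello there"]]  # one reference stream, aligned with preds

print(sacrebleu.corpus_bleu(preds, refs).score)  # corpus BLEU, as called in bleu()
print(sacrebleu.corpus_chrf(preds, refs).score)  # chrF++, as called in chrf()
print(sacrebleu.corpus_ter(preds, refs).score)   # Translation Error Rate

exact_match = evaluate.load("exact_match")
print(exact_match.compute(predictions=preds, references=refs[0]))  # {'exact_match': 1.0}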