Skip to content

Commit

Permalink
Prevent parsing failures from winning the majority vote
Browse files Browse the repository at this point in the history
  • Loading branch information
haileyschoelkopf committed Oct 16, 2023
1 parent a1432cc commit db96c31
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions lm_eval/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def majority_vote(
self,
sampled_answers: List[T],
correct_answer: T,
is_equiv : Callable[[T, T], bool] = lambda x, y: x==y
is_equiv : Callable[[T, T], bool] = lambda x, y: x==y,
invalid_answer: T = None
):
"""
Performs majority voting on a list of candidate answers.
Expand All @@ -47,6 +48,9 @@ def majority_vote(
correct_answer: T, ground truth.
is_equiv: Callable[[T, T], bool], a function that determines when two answers
should be treated as equivalent. Default is plain equality, i.e. `lambda x, y: x == y`.
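invalid_answer: T, sentinel value that marks a sample whose answer could not be parsed.
If passed, samples equal to it cast no vote in the majority vote, but they should
still count against pass_rate. Note: the filter is only applied when this value is
truthy, so a falsy sentinel (e.g. `""` or `0`) is treated as if it were not passed.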
Returns:
acc: int, 1/0 for correct/incorrect
pass_rate: float, proportion of `sampled_answers` equivalent to `correct_answer`
Expand All @@ -57,7 +61,17 @@ def majority_vote(
return 0, 0, []

answer_votes = {}
for answer in sampled_answers:

# Only successfully parsed answers are allowed to vote: we choose not to let
# a model vote for [invalidanswer] as its response.
# However, pass_rate should still be calculated as a function of the
# total K = *number of sampled answers*, including the invalid ones.
if invalid_answer:
valid_sampled_answers = [answer for answer in sampled_answers if answer != invalid_answer]
else:
valid_sampled_answers = sampled_answers

for answer in valid_sampled_answers:
if answer in answer_votes:
answer_votes[answer] += 1
else:
Expand Down

0 comments on commit db96c31

Please sign in to comment.