diff --git a/lm_eval/mixins.py b/lm_eval/mixins.py index ac2e292986..61de932102 100644 --- a/lm_eval/mixins.py +++ b/lm_eval/mixins.py @@ -35,7 +35,8 @@ def majority_vote( self, sampled_answers: List[T], correct_answer: T, - is_equiv : Callable[[T, T], bool] = lambda x, y: x==y + is_equiv : Callable[[T, T], bool] = lambda x, y: x==y, + invalid_answer: T = None ): """ Performs majority voting on a list of candidate answers. @@ -47,6 +48,9 @@ def majority_vote( correct_answer: T, ground truth. is_equiv: Callable[[T, T], bool], a function that determines when two answers should be treated as equivalent. Default is T-equivalence, i.e `lambda x y: x==y`. + invalid_answer: T, answer that corresponds to a parsing failure from a sample. + If passed as arg, no votes for invalid answer should be counted, but it should + count against pass_rate. Returns: acc: int, 0/1 for correct/incorrect pass_rate: float, proportion of `sampled_answers` equivalent to `correct_answer` @@ -57,7 +61,17 @@ def majority_vote( return 0, 0, [] answer_votes = {} - for answer in sampled_answers: + + # we only count votes for successfully parsed answers, as we choose not + # to allow a model to vote for [invalidanswer] as its response. + # however, we do want to calculate pass_rate as a function of + # total K = *num. sampled answers*. + if invalid_answer: + valid_sampled_answers = [answer for answer in sampled_answers if answer != invalid_answer] + else: + valid_sampled_answers = sampled_answers + + for answer in valid_sampled_answers: if answer in answer_votes: answer_votes[answer] += 1 else: