Add MMMU evals and runner (openai#1442)
## Eval details 📑

### Eval name

MMMU

### Eval description
A multi-modal version of MMLU published here:
https://arxiv.org/pdf/2311.16502.pdf

### What makes this a useful eval?
Tests a variety of subjects, along with image recognition and comprehension.

## Criteria for a good eval ✅

Below are some of the criteria we look for in a good eval. In general,
we are seeking cases where the model does not do a good job despite
being capable of generating a good response (note that there are some
things large language models cannot do, so those would not make good
evals).

Your eval should be:

- [x] Thematically consistent: The eval should be thematically
consistent. We'd like to see a number of prompts all demonstrating some
particular failure mode. For example, we can create an eval on cases
where the model fails to reason about the physical world.
- [x] Contains failures where a human can do the task, but either GPT-4
or GPT-3.5-Turbo could not.
- [x] Includes good signal around what is the right behavior. This means
either a correct answer for `Basic` evals or the `Fact` Model-graded
eval, or an exhaustive rubric for evaluating answers for the `Criteria`
Model-graded eval.
- [x] **Include at least 15 high-quality examples.**

If there is anything else that makes your eval worth including, please
document it below.

### Unique eval value

Multimodal, and covers a broad range of subjects.

## Eval structure 🏗️

Your eval should

- [x] Check that your YAML is registered at
`evals/registry/evals/{name}.yaml`
- [x] Ensure you have the right to use the data you submit via this eval

### Eval JSON data

Dataset defined here: https://huggingface.co/datasets/MMMU/MMMU
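
For reference, here is a minimal sketch of pulling one subject split with the Hugging Face `datasets` library (the subject `Art` and split `validation` are just illustrative; MMMU ships one config per subject, and the fields below are the ones the runner consumes):

```
from datasets import load_dataset

# Load a single MMMU subject config; each subject has dev/validation/test splits.
ds = load_dataset("MMMU/MMMU", name="Art", split="validation")

sample = ds[0]
print(sample["question"])       # question text (may reference <image 1>, <image 2>, ...)
print(sample["options"])        # stringified Python list of answer options
print(sample["answer"])         # e.g. "B" for multiple-choice questions
print(sample["question_type"])  # "multiple-choice" or "open"
print(sample["image_1"])        # PIL image; image_2 through image_7 may be None
```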

### Eval Results

on `gpt-4-vision-preview`:

```
{
  "mmmu-accounting": 0.5333333333333333,
  "mmmu-agriculture": 0.6333333333333333,
  "mmmu-architecture-and-engineering": 0.16666666666666666,
  "mmmu-art": 0.7333333333333333,
  "mmmu-art-theory": 0.8333333333333334,
  "mmmu-basic-medical-science": 0.6,
  "mmmu-biology": 0.43333333333333335,
  "mmmu-chemistry": 0.43333333333333335,
  "mmmu-clinical-medicine": 0.6333333333333333,
  "mmmu-computer-science": 0.6333333333333333,
  "mmmu-design": 0.7666666666666667,
  "mmmu-diagnostics-and-laboratory-medicine": 0.3,
  "mmmu-economics": 0.6333333333333333,
  "mmmu-electronics": 0.4,
  "mmmu-energy-and-power": 0.36666666666666664,
  "mmmu-finance": 0.43333333333333335,
  "mmmu-geography": 0.4,
  "mmmu-history": 0.6666666666666666,
  "mmmu-literature": 0.9,
  "mmmu-manage": 0.6,
  "mmmu-marketing": 0.6333333333333333,
  "mmmu-materials": 0.26666666666666666,
  "mmmu-math": 0.5,
  "mmmu-mechanical-engineering": 0.23333333333333334,
  "mmmu-music": 0.36666666666666664,
  "mmmu-pharmacy": 0.7666666666666667,
  "mmmu-physics": 0.43333333333333335,
  "mmmu-psychology": 0.7,
  "mmmu-public-health": 0.8,
  "mmmu-sociology": 0.5666666666666667
}
Average accuracy: 0.5455555555555556
```

Note that this is slightly lower than the MMMU paper's reported `0.568`. Some prompt engineering could likely close the gap, but I'll leave that as an exercise for later.
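
For anyone reproducing these numbers: assuming the per-subject evals are registered under the names shown above, a single subject can be run with `oaieval` and the full suite via the `mmmu` eval set, e.g.:

```
# Run one subject (the eval alias is inferred from the results keys above)
oaieval gpt-4-vision-preview mmmu-art

# Run every MMMU subject via the eval set in evals/registry/eval_sets/mmmu.yaml
oaievalset gpt-4-vision-preview mmmu
```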
etr2460 committed Dec 21, 2023
1 parent d30262c commit f20c305
Showing 3 changed files with 596 additions and 0 deletions.
174 changes: 174 additions & 0 deletions evals/elsuite/mmmu/eval.py
@@ -0,0 +1,174 @@
import ast
import base64
import logging
from io import BytesIO
from typing import Optional, Union
from urllib.parse import parse_qs, urlparse

from datasets import load_dataset
from PIL import Image
from pydantic import BaseModel

import evals
import evals.metrics
from evals.api import CompletionFn
from evals.formatting import make_abc
from evals.record import RecorderBase, record_match

logger = logging.getLogger(__name__)


class Sample(BaseModel):
    question: str
    answers: list[str]
    label: Union[int, str]
    question_type: str
    image_1: Optional[Image.Image]
    image_2: Optional[Image.Image]
    image_3: Optional[Image.Image]
    image_4: Optional[Image.Image]
    image_5: Optional[Image.Image]
    image_6: Optional[Image.Image]
    image_7: Optional[Image.Image]

    class Config:
        arbitrary_types_allowed = True


def get_dataset(url: str) -> list[Sample]:
    parsed = urlparse(url)
    query = parse_qs(parsed.query)
    query = {k: v[0] for k, v in query.items()}

    dataset = load_dataset("mmmu/mmmu", **query)

    return [
        Sample(
            question=sample["question"],
            answers=ast.literal_eval(sample["options"]),
            label=(
                ord(sample["answer"]) - ord("A")
                if sample["question_type"] == "multiple-choice"
                else sample["answer"]
            ),
            question_type=sample["question_type"],
            image_1=sample["image_1"],
            image_2=sample["image_2"],
            image_3=sample["image_3"],
            image_4=sample["image_4"],
            image_5=sample["image_5"],
            image_6=sample["image_6"],
            image_7=sample["image_7"],
        )
        for sample in dataset
    ]


class MMMU(evals.Eval):
    def __init__(
        self,
        completion_fns: list[CompletionFn],
        dataset: str,
        subject: str,
        *args,
        **kwargs,
    ):
        super().__init__(completion_fns, *args, **kwargs)
        assert len(completion_fns) == 1, "MMMU only supports one completion fn"
        self.dataset = dataset
        self.subject = subject

    def eval_sample(self, sample: Sample, rng):
        assert isinstance(sample, Sample)

        if sample.question_type == "multiple-choice":
            options, correct_answer = make_abc(
                answers=sample.answers,
                correct_idx=sample.label,
                rng=rng,
            )
            prompt = sample.question + "\n" + options
            system_prompt = f'You are an expert in {self.subject} whose job is to answer questions from the user using images. First, reason about the correct answer. Then write the answer in the following format where X is exactly one of A,B,C,D: "ANSWER: X"'
        else:
            correct_answer = sample.label
            prompt = sample.question
            system_prompt = f'You are an expert in {self.subject} whose job is to answer questions from the user using images. First, reason about the correct answer. Then write the answer in the following format where X is only the answer and nothing else: "ANSWER: X"'

        # Gather whichever of the sample's up-to-seven images are present.
        images = [
            image
            for image in [
                sample.image_1,
                sample.image_2,
                sample.image_3,
                sample.image_4,
                sample.image_5,
                sample.image_6,
                sample.image_7,
            ]
            if image is not None
        ]

        # Encode each image as a base64 PNG for inclusion as a data URL.
        base_64_images = []
        for image in images:
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue())
            base_64_images.append(img_str.decode())

        try:
            result = self.completion_fn(
                prompt=[
                    {
                        "role": "system",
                        "content": [
                            {
                                "type": "text",
                                "text": system_prompt,
                            },
                        ],
                    },
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt,
                            },
                        ]
                        + [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base_64_image}",
                                },
                            }
                            for base_64_image in base_64_images
                        ],
                    },
                ],
                temperature=0.0,
                max_tokens=4096,
            )
            sampled = result.get_completions()[0]
        except Exception as e:
            logging.info("Sampling failed!")
            logging.info(sample)
            logging.info(f"Prompt: {prompt}")
            logging.info(f"Error: {str(e)}")
            sampled = "ERROR: " + str(e)

        # The sample counts as correct if the completion contains the expected "ANSWER: X" string.
        match = sampled.find(f"ANSWER: {correct_answer}") != -1

        record_match(
            match,
            expected=correct_answer,
            picked=(correct_answer if match else None),
            sampled=sampled,
        )

    def run(self, recorder: RecorderBase):
        samples = get_dataset(self.dataset)
        self.eval_all_samples(recorder, samples)
        return {
            "accuracy": evals.metrics.get_accuracy(recorder.get_events("match")),
        }
3 changes: 3 additions & 0 deletions evals/registry/eval_sets/mmmu.yaml
@@ -0,0 +1,3 @@
mmmu:
  evals:
    - mmmu-*.validation.v1
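
The per-subject eval definitions live in the third changed file, presumably `evals/registry/evals/mmmu.yaml`. Going by the `MMMU` class arguments and the URL handling in `get_dataset` (only the query string is used, and its keys are passed straight through to `load_dataset`), each entry plausibly looks something like the sketch below; the exact IDs and URL scheme are assumptions:

```
mmmu-art:
  id: mmmu-art.validation.v1
  metrics: [accuracy]

mmmu-art.validation.v1:
  class: evals.elsuite.mmmu.eval:MMMU
  args:
    # hypothetical URL; only the ?name=...&split=... query matters to get_dataset
    dataset: mmmu://datasets?name=Art&split=validation
    subject: Art
```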