Add OCRBench #91

Merged · 2 commits · Feb 24, 2024
run.py (3 additions, 1 deletion)
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 from vlmeval.smp import *
-from vlmeval.evaluate import COCO_eval, YOrN_eval, MMVet_eval, multiple_choice_eval, VQAEval, MathVista_eval, LLaVABench_eval
+from vlmeval.evaluate import COCO_eval, YOrN_eval, MMVet_eval, multiple_choice_eval, VQAEval, MathVista_eval, LLaVABench_eval, OCRBench_eval
 from vlmeval.inference import infer_data_job, prefetch_acc
 from vlmeval.config import supported_VLM
 from vlmeval.utils import dataset_URLs, DATASET_TYPE, abbr2full
@@ -86,6 +86,8 @@ def main():
             COCO_eval(result_file)
         elif dataset_name == 'MMVet':
             MMVet_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose)
+        elif dataset_name == 'OCRBench':
+            OCRBench_eval(result_file)
         elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA', 'DocVQA'], dataset_name):
             VQAEval(result_file, dataset_name)
         elif listinstr(['MathVista'], dataset_name):
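With this dispatch in place, OCRBench runs through the same entry point as the other benchmarks: a command along the lines of python run.py --data OCRBench --model qwen_chat (flags as in the repository README; the model name is illustrative, any key in supported_VLM works) first writes the prediction file via infer_data_job, then hands it to OCRBench_eval.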
vlmeval/evaluate/OCRBench.py (new file, 44 additions)
@@ -0,0 +1,44 @@
+from vlmeval.smp import *
+OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
+                  "Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
+                  "Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
+def OCRBench_eval(eval_file):
+    logger = get_logger('Evaluation')
+
+    data = load(eval_file)
+    lt = len(data)
+    lines = [data.iloc[i] for i in range(lt)]
+    for i in tqdm(range(len(lines))):
+        line = lines[i]
+        predict = str(line['prediction'])
+        answers = eval(line['answer'])
+        category = line['category']
+        if category == "Handwritten Mathematical Expression Recognition":
+            for j in range(len(answers)):
+                answer = answers[j].strip().replace("\n"," ").replace(" ","")
+                predict = predict.strip().replace("\n"," ").replace(" ","")
+                if answer in predict:
+                    OCRBench_score[category]+= 1
+                    break
+        else:
+            for j in range(len(answers)):
+                answer = answers[j].lower().strip().replace("\n"," ")
+                predict = predict.lower().strip().replace("\n"," ")
+                if answer in predict:
+                    OCRBench_score[category]+= 1
+                    break
+    final_score_dict = {}
+    final_score_dict['Text Recognition']=OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
+    final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
+    final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
+    final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
+    final_score_dict['Handwritten Mathematical Expression Recognition'] = OCRBench_score['Handwritten Mathematical Expression Recognition']
+    final_score_dict['Final Score'] = final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA'] + final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction'] + final_score_dict['Handwritten Mathematical Expression Recognition']
+    final_score_dict['Final Score Norm'] = float(final_score_dict['Final Score'])/10
+    score_pth = eval_file.replace('.xlsx','_score.json')
+    dump(final_score_dict, score_pth)
+    logger.info(f'OCRBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
+    logger.info(f'Score: ')
+    for key, value in final_score_dict.items():
+        logger.info('{}:{}'.format(key, value))
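The scorer is plain substring matching after light normalization: each row's answer field holds a stringified list of acceptable answers (hence the eval call), and a category counter is bumped the first time any normalized answer occurs inside the normalized prediction. Handwritten math strips all whitespace before comparing; every other category compares case-insensitively. A self-contained sketch of that rule, on invented strings:

# Standalone illustration of the matching rule in OCRBench_eval above.
# Inputs are made-up examples, not real OCRBench rows.
def hit(prediction, answers, math=False):
    if math:   # Handwritten Mathematical Expression Recognition: drop all whitespace
        norm = lambda s: s.strip().replace("\n", " ").replace(" ", "")
    else:      # all other categories: case-insensitive, newlines folded to spaces
        norm = lambda s: s.lower().strip().replace("\n", " ")
    pred = norm(prediction)
    return any(norm(ans) in pred for ans in answers)

assert hit("The sign reads OPEN 24 HOURS", ["open 24 hours"])
assert hit("x ^ 2 + y ^ 2 = 1", ["x^2+y^2=1"], math=True)
assert not hit("illegible", ["total: 42"])

Two details worth noting: OCRBench_score lives at module level, so a second OCRBench_eval call in the same process would accumulate on top of the first run's counts; and Final Score Norm divides by 10, which reads as a percentage assuming OCRBench's 1,000-item test set.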

vlmeval/evaluate/__init__.py (2 additions, 1 deletion)
@@ -5,4 +5,5 @@
 from .vqa_eval import VQAEval
 from .mathvista_eval import MathVista_eval
 from .llavabench import LLaVABench_eval
-from .misc import build_judge
+from .misc import build_judge
+from .OCRBench import OCRBench_eval
vlmeval/utils/dataset_config.py (5 additions, 2 deletions)
@@ -26,6 +26,7 @@
     "DocVQA_VAL": "https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv",
     'AI2D_TEST': "https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv",
     "LLaVABench": "https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv",
+    "OCRBench": 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
 }
 
 dataset_md5_dict = {
@@ -53,7 +54,8 @@
     'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
     "DocVQA_VAL": 'ee0d8ae5527439438d08e154ef65d735',
     "AI2D_TEST": "0f593e0d1c7df9a3d69bf1f947e71975",
-    "LLaVABench": "d382a093f749a697820d3dadd61c8428"
+    "LLaVABench": "d382a093f749a697820d3dadd61c8428",
+    "OCRBench": 'e953d98a987cc6e26ef717b61260b778',
 }
 
 img_root_map = {k: k for k in dataset_URLs}
@@ -73,6 +75,7 @@
     'ChartQA_VALTEST_HUMAN': 'ChartQA',
     'HallusionBench': 'Hallusion',
     'DocVQA_VAL': 'DocVQA',
+    "OCRBench": 'OCRBench',
 })
 
 assert set(dataset_URLs) == set(img_root_map) == set(dataset_md5_dict)
@@ -85,7 +88,7 @@ def DATASET_TYPE(dataset):
         return 'Y/N'
     elif 'coco' in dataset:
         return 'Caption'
-    elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'llavabench', 'mmvet'], dataset):
+    elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'llavabench', 'mmvet', 'OCRBench'], dataset):
         return 'VQA'
     else:
         return 'QA'
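The assert above is what ties the three registries together: a dataset present in dataset_URLs but absent from dataset_md5_dict or img_root_map fails at import time, which is why this PR touches all of them. The md5 entry lets the downloaded TSV be integrity-checked; a hypothetical sketch of such a check (the helper below is invented for illustration, not VLMEvalKit API):

import hashlib

def md5_of(path):
    # Hash the file in 1 MiB chunks to keep memory use flat.
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            h.update(chunk)
    return h.hexdigest()

# e.g. against the entry added in this PR:
# assert md5_of('OCRBench.tsv') == 'e953d98a987cc6e26ef717b61260b778'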