Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text recognition (OCR) #9

Merged
merged 16 commits into from
Feb 12, 2024
Merged
Prev Previous commit
Next Next commit
Improve benchmarks
  • Loading branch information
VikParuchuri committed Feb 2, 2024
commit 260fc0c09d71ec628582dfd5f75320111e6e0df0
51 changes: 34 additions & 17 deletions benchmark/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,74 @@
from collections import defaultdict

from benchmark.scoring import overlap_score
from surya.detection import batch_detection
from surya.input.processing import slice_polys_from_image
from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
from surya.ocr import run_ocr
from surya.recognition import batch_recognition
from surya.ocr import run_ocr, run_recognition
from surya.postprocessing.text import draw_text_on_image
from surya.settings import settings
import os
import time
from tabulate import tabulate
import datasets
import json


def main():
parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
parser.add_argument("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=100)
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
args = parser.parse_args()

det_model = load_detection_model()
det_processor = load_detection_processor()
rec_model = load_recognition_model() # Prune model moes to only include languages we need
rec_model = load_recognition_model()
rec_processor = load_recognition_processor()

dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
split = "train"
if args.max:
split = f"train[:{args.max}]"

dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)
images = list(dataset["image"])
images = [i.convert("RGB") for i in images]
bboxes = list(dataset["bboxes"])
line_text = list(dataset["text"])
languages = list(dataset["language"])

print(f"Loaded {len(images)} images. Running OCR...")

lang_list = []
for l in languages:
if not isinstance(l, list):
lang_list.append([l])
else:
lang_list.append(l)

predictions_by_image = run_ocr(images, lang_list, det_model, det_processor, rec_model, rec_processor)
predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)

image_scores = defaultdict(list)
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, languages)):
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
image_score = overlap_score(pred["text_lines"], ref_text)
for l in lang:
image_scores[l].append(image_score)

print(image_scores)
image_avgs = {l: sum(scores) / len(scores) for l, scores in image_scores.items()}
print(image_avgs)

result_path = os.path.join(args.results_dir, "rec_bench")
os.makedirs(result_path, exist_ok=True)

with open(os.path.join(result_path, "results.json"), "w+") as f:
json.dump(image_scores, f)

if args.debug:
for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
pred_image = draw_text_on_image(bbox, pred["text_lines"], image.size)
pred_image.save(os.path.join(result_path, pred_image_name))
ref_image = draw_text_on_image(bbox, ref_text, image.size)
ref_image.save(os.path.join(result_path, ref_image_name))
image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))

print(f"Wrote results to {result_path}")


if __name__ == "__main__":
Expand Down
Loading
Loading