Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text recognition (OCR) #9

Merged
merged 16 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ jobs:
poetry install
poetry remove torch
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Run benchmark test
- name: Run detection benchmark test
run: |
poetry run python benchmark/detection.py --max 2
poetry run python scripts/verify_benchmark_scores.py results/benchmark/doclaynet_bench/results.json
poetry run python scripts/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
- name: Run recognition benchmark test
run: |
poetry run python benchmark/recognition.py --max 2
poetry run python scripts/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition



1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ wandb
notebooks
results
data
slices

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
185 changes: 138 additions & 47 deletions README.md

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions benchmark/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from surya.benchmark.bbox import get_pdf_lines
from surya.benchmark.metrics import precision_recall
from surya.benchmark.tesseract import tesseract_bboxes, tesseract_parallel
from surya.model.segformer import load_model, load_processor
from surya.model.processing import open_pdf, get_page_images
from surya.detection import batch_inference
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
from surya.benchmark.tesseract import tesseract_parallel
from surya.model.detection.segformer import load_model, load_processor
from surya.input.processing import open_pdf, get_page_images
from surya.detection import batch_detection
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.postprocessing.util import rescale_bbox
from surya.settings import settings
import os
Expand Down Expand Up @@ -42,9 +42,9 @@ def main():
image_sizes = [img.size for img in images]
correct_boxes = get_pdf_lines(args.pdf_path, image_sizes)
else:
pathname = "doclaynet_bench"
pathname = "det_bench"
# These have already been shuffled randomly, so sampling from the start is fine
dataset = datasets.load_dataset(settings.BENCH_DATASET_NAME, split=f"train[:{args.max}]")
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
images = list(dataset["image"])
images = [i.convert("RGB") for i in images]
correct_boxes = []
Expand All @@ -54,7 +54,7 @@ def main():
correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])

start = time.time()
predictions = batch_inference(images, model, processor)
predictions = batch_detection(images, model, processor)
surya_time = time.time() - start

start = time.time()
Expand Down
2 changes: 1 addition & 1 deletion benchmark/pymupdf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from surya.benchmark.bbox import get_pdf_lines
from surya.postprocessing.heatmap import draw_bboxes_on_image

from surya.model.processing import open_pdf, get_page_images
from surya.input.processing import open_pdf, get_page_images
from surya.settings import settings


Expand Down
159 changes: 159 additions & 0 deletions benchmark/recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import argparse
from collections import defaultdict

from benchmark.scoring import overlap_score
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.ocr import run_recognition
from surya.postprocessing.text import draw_text_on_image
from surya.settings import settings
from surya.languages import CODE_TO_LANGUAGE
from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
import os
import datasets
import json
import time
from tabulate import tabulate

KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"]


def main():
    """Benchmark surya OCR recognition (and optionally tesseract) on a dataset.

    Loads the recognition benchmark dataset, runs surya recognition over every
    image, scores the predicted text lines against the reference lines with a
    fuzzy overlap score, and writes per-language scores plus summary stats to
    JSON. Prints a summary table for a fixed set of key languages.
    """
    parser = argparse.ArgumentParser(description="Benchmark OCR text recognition.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
    parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
    parser.add_argument("--debug", type=int, help="Debug level - 1 dumps bad detection info, 2 writes out images.", default=0)
    parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
    parser.add_argument("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
    parser.add_argument("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
    args = parser.parse_args()

    rec_model = load_recognition_model()
    rec_processor = load_recognition_processor()

    # Dataset rows are pre-shuffled, so a prefix slice is a fair sample.
    split = "train"
    if args.max:
        split = f"train[:{args.max}]"

    dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)

    if args.langs:
        langs = args.langs.split(",")
        dataset = dataset.filter(lambda x: x["language"] in langs)

    images = list(dataset["image"])
    images = [i.convert("RGB") for i in images]
    bboxes = list(dataset["bboxes"])
    line_text = list(dataset["text"])
    languages = list(dataset["language"])

    print(f"Loaded {len(images)} images. Running OCR...")

    # Normalize languages: each entry becomes a list of language codes.
    lang_list = []
    for l in languages:
        if not isinstance(l, list):
            lang_list.append([l])
        else:
            lang_list.append(l)

    start = time.time()
    predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
    surya_time = time.time() - start

    # Score each image; an image's score is credited to every language it uses.
    surya_scores = defaultdict(list)
    img_surya_scores = []
    for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
        image_score = overlap_score(pred["text_lines"], ref_text)
        img_surya_scores.append(image_score)
        for l in lang:
            surya_scores[CODE_TO_LANGUAGE[l]].append(image_score)

    flat_surya_scores = [s for l in surya_scores for s in surya_scores[l]]
    benchmark_stats = {
        "surya": {
            "avg_score": sum(flat_surya_scores) / len(flat_surya_scores),
            "lang_scores": {l: sum(scores) / len(scores) for l, scores in surya_scores.items()},
            "time_per_img": surya_time / len(images)
        }
    }

    result_path = os.path.join(args.results_dir, "rec_bench")
    os.makedirs(result_path, exist_ok=True)

    with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
        json.dump(surya_scores, f)

    if args.tesseract:
        # Tesseract does not support all languages; keep only convertible ones.
        tess_valid = []
        tess_langs = []
        for idx, lang in enumerate(lang_list):
            tess_lang = surya_lang_to_tesseract(lang[0])
            if tess_lang is None:
                continue

            tess_valid.append(idx)
            tess_langs.append(tess_lang)

        tess_imgs = [images[i] for i in tess_valid]
        tess_bboxes = [bboxes[i] for i in tess_valid]
        tess_reference = [line_text[i] for i in tess_valid]
        start = time.time()
        tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs, cpus=args.tess_cpus)
        tesseract_time = time.time() - start

        tess_scores = defaultdict(list)
        for idx, (pred, ref_text, lang) in enumerate(zip(tess_predictions, tess_reference, tess_langs)):
            image_score = overlap_score(pred, ref_text)
            tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)

        flat_tess_scores = [s for l in tess_scores for s in tess_scores[l]]
        benchmark_stats["tesseract"] = {
            "avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
            "lang_scores": {l: sum(scores) / len(scores) for l, scores in tess_scores.items()},
            "time_per_img": tesseract_time / len(tess_imgs)
        }

        with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
            json.dump(tess_scores, f)

    with open(os.path.join(result_path, "results.json"), "w+") as f:
        json.dump(benchmark_stats, f)

    # Only show key languages that actually appear in the results; headers must
    # use the same filtered list as the data rows so column counts match.
    key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
    table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages
    table_data = [
        ["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages],
    ]
    if args.tesseract:
        table_data.append(
            ["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print("Only a few major languages are displayed. See the result path for additional languages.")

    if args.debug >= 1:
        # Dump images whose score falls below the quality threshold.
        bad_detections = []
        for idx, (score, lang) in enumerate(zip(flat_surya_scores, lang_list)):
            if score < .8:
                bad_detections.append((idx, lang, score))
        print(f"Found {len(bad_detections)} bad detections. Writing to file...")
        with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
            json.dump(bad_detections, f)

    if args.debug == 2:
        # Render prediction/reference text side by side for visual inspection.
        for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
            pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
            ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
            pred_image = draw_text_on_image(bbox, pred["text_lines"], image.size)
            pred_image.save(os.path.join(result_path, pred_image_name))
            ref_image = draw_text_on_image(bbox, ref_text, image.size)
            ref_image.save(os.path.join(result_path, ref_image_name))
            image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
22 changes: 22 additions & 0 deletions benchmark/scoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import math
from typing import List

from rapidfuzz import fuzz


def overlap_score(pred_lines: List[str], reference_lines: List[str]) -> float:
    """Score predicted OCR lines against reference lines with fuzzy matching.

    Each predicted line is matched to the reference line it most resembles
    (rapidfuzz ratio, normalized to [0, 1]). Matches are weighted by
    sqrt(len(ref_line)) so longer reference lines contribute more; a predicted
    line with no match above the cutoff keeps weight 1 and score 0, which
    penalizes spurious output.

    Args:
        pred_lines: Text lines produced by the OCR model.
        reference_lines: Ground-truth text lines.

    Returns:
        Weighted average match score in [0, 1]; 0.0 when there is nothing
        to score (previously this raised ZeroDivisionError on empty input).
    """
    line_scores = []
    line_weights = []
    for pred_line in pred_lines:
        max_score = 0.0
        line_weight = 1  # default weight when no reference line matches
        for ref_line in reference_lines:
            # score_cutoff=20 lets rapidfuzz short-circuit very poor matches to 0.
            score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
            if score > max_score:
                max_score = score
                line_weight = math.sqrt(len(ref_line))
        line_scores.append(max_score * line_weight)
        line_weights.append(line_weight)

    total_weight = sum(line_weights)
    if total_weight == 0:
        # No predicted lines (or all weights zero) -> nothing to score.
        return 0.0
    return sum(line_scores) / total_weight
2 changes: 1 addition & 1 deletion benchmark/tesseract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from surya.benchmark.tesseract import tesseract_bboxes
from surya.postprocessing.heatmap import draw_bboxes_on_image

from surya.model.processing import open_pdf, get_page_images
from surya.input.processing import open_pdf, get_page_images
from surya.settings import settings


Expand Down
66 changes: 7 additions & 59 deletions detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,71 +3,19 @@
import json
from collections import defaultdict

from PIL import Image

from surya.model.segformer import load_model, load_processor
from surya.model.processing import open_pdf, get_page_images
from surya.detection import batch_inference
from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.segformer import load_model, load_processor
from surya.detection import batch_detection
from surya.postprocessing.affinity import draw_lines_on_image
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os
import filetype


def get_name_from_path(path):
return os.path.basename(path).split(".")[0]


def load_pdf(pdf_path, max_pages=None):
doc = open_pdf(pdf_path)
page_count = len(doc)
if max_pages:
page_count = min(max_pages, page_count)

page_indices = list(range(page_count))

images = get_page_images(doc, page_indices)
doc.close()
names = [get_name_from_path(pdf_path) for _ in page_indices]
return images, names


def load_image(image_path):
image = Image.open(image_path).convert("RGB")
name = get_name_from_path(image_path)
return [image], [name]


def load_from_file(input_path, max_pages=None):
input_type = filetype.guess(input_path)
if input_type.extension == "pdf":
return load_pdf(input_path, max_pages)
else:
return load_image(input_path)


def load_from_folder(folder_path, max_pages=None):
image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path)]
image_paths = [ip for ip in image_paths if not os.path.isdir(ip) and not ip.startswith(".")]

images = []
names = []
for path in image_paths:
if filetype.guess(path).extension == "pdf":
image, name = load_pdf(path, max_pages)
images.extend(image)
names.extend(name)
else:
image, name = load_image(path)
images.extend(image)
names.extend(name)
return images, names
from tqdm import tqdm


def main():
parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
parser.add_argument("input_path", type=str, help="Path to pdf or image file to detect bboxes in.")
parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
Expand All @@ -84,7 +32,7 @@ def main():
images, names = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]

predictions = batch_inference(images, model, processor)
predictions = batch_detection(images, model, processor)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

Expand Down
Loading
Loading