Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text recognition (OCR) #9

Merged
merged 16 commits into from
Feb 12, 2024
Merged
Prev Previous commit
Next Next commit
Fix slice pad
  • Loading branch information
VikParuchuri committed Feb 2, 2024
commit 07a715cac9d22aa4af64183dfcd1a12b0c7a6e78
7 changes: 6 additions & 1 deletion benchmark/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from surya.ocr import run_ocr, run_recognition
from surya.postprocessing.text import draw_text_on_image
from surya.settings import settings
from surya.languages import CODE_TO_LANGUAGE, is_arabic
import arabic_reshaper
import os
import datasets
import json
Expand Down Expand Up @@ -46,9 +48,12 @@ def main():

image_scores = defaultdict(list)
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
if any(is_arabic(l) for l in lang):
ref_text = [arabic_reshaper.reshape(t) for t in ref_text]
pred["text_lines"] = [arabic_reshaper.reshape(t) for t in pred["text_lines"]]
image_score = overlap_score(pred["text_lines"], ref_text)
for l in lang:
image_scores[l].append(image_score)
image_scores[CODE_TO_LANGUAGE[l]].append(image_score)

image_avgs = {l: sum(scores) / len(scores) for l, scores in image_scores.items()}
print(image_avgs)
Expand Down
8 changes: 8 additions & 0 deletions ocr_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from surya.settings import settings
from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE
import os


Expand All @@ -23,7 +24,14 @@ def main():
parser.add_argument("--lang", type=str, help="Language to use for OCR. Comma separate for multiple.", default="en")
args = parser.parse_args()

# Split and validate language codes
langs = args.lang.split(",")
for i in range(len(langs)):
if langs[i] in LANGUAGE_TO_CODE:
langs[i] = LANGUAGE_TO_CODE[langs[i]]
if langs[i] not in CODE_TO_LANGUAGE:
raise ValueError(f"Language code {langs[i]} not found.")

det_processor = load_detection_processor()
det_model = load_detection_model()

Expand Down
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pymupdf = "^1.23.8"
snakeviz = "^2.2.0"
datasets = "^2.16.1"
rapidfuzz = "^3.6.1"
arabic-reshaper = "^3.0.0"

[tool.poetry.scripts]
surya_detect = "detect_text:main"
Expand Down
32 changes: 24 additions & 8 deletions surya/input/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,31 @@ def slice_polys_from_image(image: Image.Image, polys):


def slice_and_pad_poly(image: Image.Image, coordinates):
coordinates = [(corner[0], corner[1]) for corner in coordinates]

# Create a mask for the polygon
mask = Image.new('L', image.size, 0)

# coordinates must be in tuple form for PIL
coordinates = [(corner[0], corner[1]) for corner in coordinates]
ImageDraw.Draw(mask).polygon(coordinates, outline=1, fill=1)
bbox = mask.getbbox()
mask = mask.crop(bbox)
cropped_image = image.crop(bbox)
mask = mask.convert('1')
rectangle = Image.new('RGB', cropped_image.size, 'white')
rectangle.paste(cropped_image, (0, 0), mask)
mask = np.array(mask)

# Extract the polygonal area from the image
polygon_image = np.array(image)
polygon_image[~mask] = 0
polygon_image = Image.fromarray(polygon_image)

bbox_image = Image.new('L', image.size, 0)
ImageDraw.Draw(bbox_image).polygon(coordinates, outline=1, fill=1)
bbox = bbox_image.getbbox()

rectangle = Image.new('RGB', (bbox[2] - bbox[0], bbox[3] - bbox[1]), 'white')

# Paste the polygon into the rectangle
polygon_center = (bbox[2] + bbox[0]) // 2, (bbox[3] + bbox[1]) // 2
rectangle_center = rectangle.width // 2, rectangle.height // 2
paste_position = (rectangle_center[0] - polygon_center[0] + bbox[0],
rectangle_center[1] - polygon_center[1] + bbox[1])
rectangle.paste(polygon_image.crop(bbox), paste_position)

return rectangle

101 changes: 101 additions & 0 deletions surya/languages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Mapping from ISO 639-1 two-letter language code to English language name.
# Used elsewhere in this PR to validate user-supplied --lang values
# (ocr_text.py) and to key per-language benchmark scores by readable name
# (benchmark/recognition.py).
CODE_TO_LANGUAGE = {
    'af': 'Afrikaans',
    'am': 'Amharic',
    'ar': 'Arabic',
    'as': 'Assamese',
    'az': 'Azerbaijani',
    'be': 'Belarusian',
    'bg': 'Bulgarian',
    'bn': 'Bangla',
    'br': 'Breton',
    'bs': 'Bosnian',
    'ca': 'Catalan',
    'cs': 'Czech',
    'cy': 'Welsh',
    'da': 'Danish',
    'de': 'German',
    'el': 'Greek',
    'en': 'English',
    'eo': 'Esperanto',
    'es': 'Spanish',
    'et': 'Estonian',
    'eu': 'Basque',
    'fa': 'Persian',
    'fi': 'Finnish',
    'fr': 'French',
    'fy': 'Western Frisian',
    'ga': 'Irish',
    'gd': 'Scottish Gaelic',
    'gl': 'Galician',
    'gu': 'Gujarati',
    'ha': 'Hausa',
    'he': 'Hebrew',
    'hi': 'Hindi',
    'hr': 'Croatian',
    'hu': 'Hungarian',
    'hy': 'Armenian',
    'id': 'Indonesian',
    'is': 'Icelandic',
    'it': 'Italian',
    'ja': 'Japanese',
    'jv': 'Javanese',
    'ka': 'Georgian',
    'kk': 'Kazakh',
    'km': 'Khmer',
    'kn': 'Kannada',
    'ko': 'Korean',
    'ku': 'Kurdish',
    'ky': 'Kyrgyz',
    'la': 'Latin',
    'lo': 'Lao',
    'lt': 'Lithuanian',
    'lv': 'Latvian',
    'mg': 'Malagasy',
    'mk': 'Macedonian',
    'ml': 'Malayalam',
    'mn': 'Mongolian',
    'mr': 'Marathi',
    'ms': 'Malay',
    'my': 'Burmese',
    'ne': 'Nepali',
    'nl': 'Dutch',
    'no': 'Norwegian',
    'om': 'Oromo',
    'or': 'Odia',
    'pa': 'Punjabi',
    'pl': 'Polish',
    'ps': 'Pashto',
    'pt': 'Portuguese',
    'ro': 'Romanian',
    'ru': 'Russian',
    'sa': 'Sanskrit',
    'sd': 'Sindhi',
    'si': 'Sinhala',
    'sk': 'Slovak',
    'sl': 'Slovenian',
    'so': 'Somali',
    'sq': 'Albanian',
    'sr': 'Serbian',
    'su': 'Sundanese',
    'sv': 'Swedish',
    'sw': 'Swahili',
    'ta': 'Tamil',
    'te': 'Telugu',
    'th': 'Thai',
    'tl': 'Tagalog',
    'tr': 'Turkish',
    'ug': 'Uyghur',
    'uk': 'Ukrainian',
    'ur': 'Urdu',
    'uz': 'Uzbek',
    'vi': 'Vietnamese',
    'xh': 'Xhosa',
    'yi': 'Yiddish',
    'zh': 'Chinese'
}

# Inverse mapping: English language name -> ISO 639-1 code.
# Safe to invert because every name above is unique.
LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}


def is_arabic(lang_code):
    """Return True if *lang_code* denotes a language written in the Arabic script.

    Covers Arabic, Persian, Pashto, Uyghur, and Urdu. The benchmark uses this
    to decide when text must be run through ``arabic_reshaper`` so predicted
    and reference strings use the same contextual glyph forms before scoring.

    Args:
        lang_code: ISO 639-1 two-letter language code (e.g. ``"ar"``).
    """
    # Set literal: idiomatic O(1) membership test instead of a list scan.
    return lang_code in {"ar", "fa", "ps", "ug", "ur"}
4 changes: 2 additions & 2 deletions surya/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
slice_map = []
all_slices = []
all_langs = []
for idx, (image, det_pred, lang) in tqdm(enumerate(zip(images, det_predictions, langs)), desc="Slicing images"):
for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)):
slices = slice_polys_from_image(image, det_pred["polygons"])
slice_map.append(len(slices))
all_slices.extend(slices)
Expand All @@ -80,4 +80,4 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
"language": lang
})

return predictions_by_image
return predictions_by_image
12 changes: 8 additions & 4 deletions surya/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@ def get_batch_size():

def batch_recognition(images: List, languages: List[List[str]], model, processor):
assert all([isinstance(image, Image.Image) for image in images])
assert len(images) == len(languages)
batch_size = get_batch_size()

images = [image.convert("RGB") for image in images]
model_inputs = processor(text=[""] * len(languages), images=images, lang=languages)

output_text = []
for i in tqdm(range(0, len(model_inputs["pixel_values"]), batch_size), desc="Recognizing Text"):
batch_langs = model_inputs["langs"][i:i+batch_size]
batch_pixel_values = model_inputs["pixel_values"][i:i+batch_size]
for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
batch_langs = languages[i:i+batch_size]
batch_images = images[i:i+batch_size]
model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs)

batch_pixel_values = model_inputs["pixel_values"]
batch_langs = model_inputs["langs"]
batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]

batch_langs = torch.from_numpy(np.array(batch_langs, dtype=np.int64)).to(model.device)
Expand Down
Loading