Darwin Bautista and Rowel Atienza
Electrical and Electronics Engineering Institute
University of the Philippines, Diliman
Scene Text Recognition (STR) models use language context to be more robust against noisy or corrupted images. Recent approaches like ABINet use a standalone or external Language Model (LM) for prediction refinement. In this work, we show that the external LM, which requires upfront allocation of dedicated compute capacity, is inefficient for STR due to its poor performance-versus-cost trade-off. We propose a more efficient approach using permuted autoregressive sequence (PARSeq) models. View our ECCV poster and presentation for a brief overview.
NOTE: P-S and P-Ti are shorthands for PARSeq-S and PARSeq-Ti, respectively.
Our main insight is that with an ensemble of autoregressive (AR) models, we could unify the current STR decoding methods (context-aware AR and context-free non-AR) and the bidirectional (cloze) refinement model:
A single Transformer can realize different models by merely varying its attention masks. This characteristic coupled with Permutation Language Modeling allows for a unified STR model capable of context-free and context-aware inference, as well as iterative prediction refinement using bidirectional context without requiring a standalone language model. PARSeq can be considered an ensemble of AR models with shared architecture and weights:
NOTE: Bold letters and underscores indicate wrong and missing character predictions, respectively.
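The claim that one Transformer becomes a different model just by changing its attention mask can be made concrete with a small sketch. The helpers below are hypothetical (they are not part of strhub) and simplify PARSeq considerably: they ignore the [BOS]/[EOS] tokens and the separate position-query stream, and only build a boolean mask in which output position q may attend to position k when k is predicted earlier in a given factorization order (permutation).

```python
from typing import List, Sequence

Mask = List[List[bool]]


def permutation_mask(perm: Sequence[int]) -> Mask:
    """mask[q][k] is True if position q may attend to position k.

    Hypothetical helper: q sees k only when k precedes q in the
    factorization order `perm` (a permutation of 0..n-1).
    """
    n = len(perm)
    if sorted(perm) != list(range(n)):
        raise ValueError("perm must be a permutation of 0..n-1")
    rank = [0] * n  # rank[pos] = step at which `pos` is predicted
    for step, pos in enumerate(perm):
        rank[pos] = step
    return [[rank[k] < rank[q] for k in range(n)] for q in range(n)]


def context_free_mask(n: int) -> Mask:
    """Non-AR decoding: no position sees any other character."""
    return [[False] * n for _ in range(n)]


def cloze_mask(n: int) -> Mask:
    """Bidirectional refinement: each position sees every other position."""
    return [[q != k for k in range(n)] for q in range(n)]


if __name__ == "__main__":
    n = 4
    for name, mask in [
        ("left-to-right AR", permutation_mask(range(n))),
        ("right-to-left AR", permutation_mask(range(n - 1, -1, -1))),
        ("context-free (NAR)", context_free_mask(n)),
        ("cloze (refinement)", cloze_mask(n)),
    ]:
        print(name)
        for row in mask:
            print("".join("x" if visible else "." for visible in row))
```

With the identity order the mask is strictly lower-triangular (a standard left-to-right AR model), the reversed order gives its right-to-left counterpart, and any other permutation yields yet another AR model, all sharing the same weights.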
This repository contains the reference implementation for PARSeq and reproduced models (collectively referred to as Scene Text Recognition Model Hub). See NOTICE for copyright information.
The majority of the code is licensed under the Apache License v2.0 (see LICENSE), while the ABINet and CRNN sources are released under the BSD and MIT licenses, respectively (see the corresponding LICENSE files for details).
An interactive Gradio demo hosted at Hugging Face is available. The pretrained weights released here are used for the demo.
Requires Python 3.7 or newer and PyTorch 1.10 or newer. Tested on Python 3.9 and PyTorch 1.10.
$ pip install -r requirements.txt
$ pip install -e .
Download the datasets from the following links:
- LMDB archives for MJSynth, SynthText, IIIT5k, SVT, SVTP, IC13, IC15, CUTE80, ArT, RCTW17, ReCTS, LSVT, MLT19, COCO-Text, and Uber-Text.
- LMDB archives for TextOCR and OpenVINO.
Available models are: abinet, crnn, trba, vitstr, parseq_tiny, and parseq.
import torch
from PIL import Image
from strhub.data.module import SceneTextDataModule
# Load model and image transforms
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
img_transform = SceneTextDataModule.get_transform(parseq.hparams.img_size)
img = Image.open('/path/to/image.png').convert('RGB')
# Preprocess. Model expects a batch of images with shape: (B, C, H, W)
img = img_transform(img).unsqueeze(0)
logits = parseq(img)
logits.shape # torch.Size([1, 26, 95]), 94 characters + [EOS] symbol
# Greedy decoding
pred = logits.softmax(-1)
label, confidence = parseq.tokenizer.decode(pred)
print('Decoded label = {}'.format(label[0]))
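For intuition about what greedy decoding of the (B, T, 95) probability tensor involves, here is a framework-free sketch for a single sample. It is not strhub's tokenizer.decode; it assumes, hypothetically, that class 0 of the output head is the [EOS] symbol and that classes 1..94 map to the charset in order, and it reports the product of the per-step maximum probabilities (including the [EOS] step) as the sequence confidence.

```python
from typing import Sequence, Tuple

EOS_INDEX = 0  # assumption: [EOS] is the first class of the output head


def greedy_decode(probs: Sequence[Sequence[float]], charset: str) -> Tuple[str, float]:
    """Decode one sequence of per-step class probabilities.

    probs[t][c] is the probability of class c at step t; class 0 is assumed
    to be [EOS] and class i >= 1 is assumed to be charset[i - 1].
    """
    chars = []
    confidence = 1.0
    for step in probs:
        # Greedy choice: the single most probable class at this step.
        best = max(range(len(step)), key=step.__getitem__)
        confidence *= step[best]
        if best == EOS_INDEX:
            break  # steps after [EOS] are ignored
        chars.append(charset[best - 1])
    return "".join(chars), confidence


if __name__ == "__main__":
    toy_probs = [
        [0.1, 0.8, 0.1],    # most likely 'a'
        [0.2, 0.1, 0.7],    # most likely 'b'
        [0.9, 0.05, 0.05],  # most likely [EOS]
        [0.1, 0.8, 0.1],    # ignored: comes after [EOS]
    ]
    print(greedy_decode(toy_probs, "ab"))
```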
- How do I train on a new language? See Issues #5 and #9.
- Can you export to TorchScript or ONNX? Yes, see Issue #12.
- How do I test on my own dataset? See Issue #27.
- How do I finetune and/or create a custom dataset? See Issue #7.
- What is val_NED? See Issue #10.
The training script can train any supported model. You can override any configuration from the command line; please refer to the Hydra documentation for more information about the syntax. Use ./train.py --help to see the default configuration.
Sample commands for different training configurations
./train.py pretrained=parseq-tiny # Not all experiments have pretrained weights
The base model configurations are in configs/model/, while variations are stored in configs/experiment/.
./train.py +experiment=parseq-tiny # Some examples: abinet-sv, trbc
./train.py charset=94_full # Other options: 36_lowercase or 62_mixed-case. See configs/charset/
./train.py dataset=real # Other option: synth. See configs/dataset/
./train.py 'model.img_size=[32,128]' model.max_label_length=25 model.batch_size=384  # quote list values so the shell passes them as a single argument
./train.py data.root_dir=data data.num_workers=2 data.augment=true
./train.py trainer.max_epochs=20 trainer.gpus=2 +trainer.accelerator=gpu
Note that you can pass any Trainer parameter; you just need to prefix it with + if it is not originally specified in configs/main.yaml.
./train.py +experiment=<model_exp> ckpt_path=outputs/<model>/<timestamp>/checkpoints/<checkpoint>.ckpt
The test script, test.py, can be used to evaluate any model trained with this project. For more info, see ./test.py --help.
PARSeq runtime parameters can be passed using the format param:type=value. For example, PARSeq non-autoregressive (NAR) decoding with two refinement iterations can be invoked via:
./test.py parseq.ckpt refine_iters:int=2 decode_ar:bool=false
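The param:type=value convention carries the target type alongside each value, so command-line strings can be converted before they reach the model. A hypothetical parser for this convention might look like the sketch below; the supported type names (int, float, str, bool) are an assumption, and strhub's own parsing may differ.

```python
from typing import Any, Callable, Dict, Iterable


def _to_bool(value: str) -> bool:
    # bool("false") is True in Python, so booleans need explicit handling.
    lowered = value.lower()
    if lowered not in ("true", "false"):
        raise ValueError(f"not a boolean: {value!r}")
    return lowered == "true"


# Assumed whitelist of type names; a lookup table avoids calling eval() on user input.
CASTS: Dict[str, Callable[[str], Any]] = {"int": int, "float": float, "str": str, "bool": _to_bool}


def parse_runtime_args(args: Iterable[str]) -> Dict[str, Any]:
    """Turn e.g. ['refine_iters:int=2', 'decode_ar:bool=false'] into keyword arguments."""
    kwargs: Dict[str, Any] = {}
    for arg in args:
        name_and_type, sep, value = arg.partition("=")
        name, colon, type_name = name_and_type.partition(":")
        if not sep or not colon or type_name not in CASTS:
            raise ValueError(f"expected param:type=value, got {arg!r}")
        kwargs[name] = CASTS[type_name](value)
    return kwargs


if __name__ == "__main__":
    print(parse_runtime_args(["refine_iters:int=2", "decode_ar:bool=false"]))
    # → {'refine_iters': 2, 'decode_ar': False}
```

The resulting dictionary could then be forwarded as keyword arguments when the model is loaded; how test.py actually forwards them is not shown here.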
Sample commands for reproducing results
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt # or use the released weights: ./test.py pretrained=parseq
Sample output:
Dataset | # samples | Accuracy | 1 - NED | Confidence | Label Length |
---|---|---|---|---|---|
IIIT5k | 3000 | 99.00 | 99.79 | 97.09 | 5.09 |
SVT | 647 | 97.84 | 99.54 | 95.87 | 5.86 |
IC13_1015 | 1015 | 98.13 | 99.43 | 97.19 | 5.31 |
IC15_2077 | 2077 | 89.22 | 96.43 | 91.91 | 5.33 |
SVTP | 645 | 96.90 | 99.36 | 94.37 | 5.86 |
CUTE80 | 288 | 98.61 | 99.80 | 96.43 | 5.53 |
Combined | 7672 | 95.95 | 98.78 | 95.34 | 5.33 |
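Accuracy is word-level exact match, and 1 - NED is one minus the normalized edit distance. The sketch below assumes the common convention of normalizing the edit distance by the longer of the two strings (an assumption about this project's NED) and checks that the Combined row is consistent with a sample-weighted average of the per-dataset rows: (3000 × 99.00 + 647 × 97.84 + 1015 × 98.13 + 2077 × 89.22 + 645 × 96.90 + 288 × 98.61) / 7672 ≈ 95.95.

```python
from typing import Sequence, Tuple


def edit_distance(a: str, b: str) -> int:
    """Levenshtein distance: insertions, deletions and substitutions each cost 1."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution (0 if equal)
        prev = curr
    return prev[-1]


def one_minus_ned(pred: str, gt: str) -> float:
    """1 - NED, normalizing by the longer string (assumed convention)."""
    longest = max(len(pred), len(gt))
    return 1.0 if longest == 0 else 1.0 - edit_distance(pred, gt) / longest


def combined(rows: Sequence[Tuple[int, float]]) -> float:
    """Sample-weighted mean of per-dataset percentages given as (num_samples, value)."""
    total = sum(n for n, _ in rows)
    return sum(n * value for n, value in rows) / total


if __name__ == "__main__":
    accuracy_rows = [(3000, 99.00), (647, 97.84), (1015, 98.13),
                     (2077, 89.22), (645, 96.90), (288, 98.61)]
    print(f"{combined(accuracy_rows):.2f}")  # → 95.95
```

Applying the same weighting to the 1 - NED column gives 98.78, also matching the Combined row.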
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt # lowercase alphanumeric (36-character set)
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --cased # mixed-case alphanumeric (62-character set)
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --cased --punctuation # mixed-case alphanumeric + punctuation (94-character set)
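The 36-, 62-, and 94-character sets line up with Python's string constants: 10 digits + 26 lowercase letters = 36, adding 26 uppercase letters gives 62, and adding the 32 ASCII punctuation characters gives 94 (the 94 characters in the logits shape from the usage example). The sketch below builds them that way and shows one plausible way to map a ground-truth label onto the evaluation charset chosen by --cased and --punctuation: lowercase it when case is ignored, then drop unsupported characters. Both the character ordering and this normalization rule are assumptions, not necessarily what test.py does.

```python
import string

CHARSET_36 = string.digits + string.ascii_lowercase
CHARSET_62 = CHARSET_36 + string.ascii_uppercase
CHARSET_94 = CHARSET_62 + string.punctuation


def eval_charset(cased: bool = False, punctuation: bool = False) -> str:
    """Charset selected by the two (assumed) evaluation switches."""
    charset = CHARSET_36
    if cased:
        charset += string.ascii_uppercase
    if punctuation:
        charset += string.punctuation
    return charset


def normalize_label(label: str, cased: bool = False, punctuation: bool = False) -> str:
    """Hypothetical normalization: lowercase if uncased, then drop out-of-charset chars."""
    allowed = set(eval_charset(cased, punctuation))
    if not cased:
        label = label.lower()
    return "".join(ch for ch in label if ch in allowed)


if __name__ == "__main__":
    for flags in [{}, {"cased": True}, {"cased": True, "punctuation": True}]:
        print(flags, normalize_label("3rd Ave.", **flags))
```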
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --new
./bench.py model=parseq model.decode_ar=false model.refine_iters=3
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8fcae67ee0>
model(x)
Median: 14.87 ms
IQR: 0.33 ms (14.78 to 15.12)
7 measurements, 10 runs per measurement, 1 thread
| module | #parameters | #flops | #activations |
|:----------------------|:--------------|:---------|:---------------|
| model | 23.833M | 3.255G | 8.214M |
| encoder | 21.381M | 2.88G | 7.127M |
| decoder | 2.368M | 0.371G | 1.078M |
| head | 36.575K | 3.794M | 9.88K |
| text_embed.embedding | 37.248K | 0 | 0 |
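The Median and IQR lines summarize the distribution of per-run timings: the IQR is the spread between the 25th and 75th percentiles, which are printed in parentheses, and "7 measurements, 10 runs per measurement" means each of 7 measurements times 10 consecutive calls. The standard-library sketch below produces a similar summary for an arbitrary callable using timeit.repeat; it is not how bench.py works internally, and statistics.quantiles uses its own interpolation method, so its percentiles may differ slightly from PyTorch's.

```python
import statistics
import timeit
from typing import Callable, List, Tuple


def summarize_ms(per_run_ms: List[float]) -> Tuple[float, float, float, float]:
    """Return (median, IQR, 25th percentile, 75th percentile); needs >= 2 values."""
    q1, _, q3 = statistics.quantiles(per_run_ms, n=4)
    return statistics.median(per_run_ms), q3 - q1, q1, q3


def benchmark(fn: Callable[[], object], measurements: int = 7, runs: int = 10) -> str:
    # Each measurement times `runs` back-to-back calls; convert to ms per call.
    totals = timeit.repeat(fn, repeat=measurements, number=runs)
    per_run_ms = [t / runs * 1000 for t in totals]
    median, iqr, q1, q3 = summarize_ms(per_run_ms)
    return (f"Median: {median:.2f} ms\n"
            f"IQR: {iqr:.2f} ms ({q1:.2f} to {q3:.2f})\n"
            f"{measurements} measurements, {runs} runs per measurement")


if __name__ == "__main__":
    print(benchmark(lambda: sum(range(100_000))))
```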
./bench.py model=parseq model.decode_ar=false model.refine_iters=3 +range=true
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --cased --punctuation # no rotation
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --cased --punctuation --rotation 90
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --cased --punctuation --rotation 180
./test.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --cased --punctuation --rotation 270
./read.py outputs/<model>/<timestamp>/checkpoints/last.ckpt --images demo_images/* # Or use ./read.py pretrained=parseq
Additional keyword arguments: {}
demo_images/art-01107.jpg: CHEWBACCA
demo_images/coco-1166773.jpg: Chevrol
demo_images/cute-184.jpg: SALMON
demo_images/ic13_word_256.png: Verbandsteffe
demo_images/ic15_word_26.png: Kaopa
demo_images/uber-27491.jpg: 3rdAve
# use NAR decoding + 2 refinement iterations for PARSeq
./read.py pretrained=parseq refine_iters:int=2 decode_ar:bool=false --images demo_images/*
We use Ray Tune for automated tuning of the learning rate. See ./tune.py --help. Extend tune.py to support tuning of other hyperparameters.
./tune.py tune.num_samples=20 # find optimum LR for PARSeq's default config using 20 trials
./tune.py +experiment=tune_abinet-lm # find the optimum learning rate for ABINet's language model
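Learning-rate searches usually sample candidates on a logarithmic scale, because plausible values span several orders of magnitude. The sketch below is plain random search with log-uniform sampling; it only illustrates the idea and does not use Ray Tune's API. The search bounds and the toy objective are made-up placeholders; in practice each trial would be scored by an actual (possibly shortened) training run.

```python
import math
import random
from typing import Callable, Tuple


def log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Sample so that log10(value) is uniform between log10(low) and log10(high)."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))


def random_lr_search(objective: Callable[[float], float], num_samples: int = 20,
                     low: float = 1e-5, high: float = 1e-2,
                     seed: int = 0) -> Tuple[float, float]:
    """Return (best_lr, best_score); lower scores are better."""
    rng = random.Random(seed)
    best_lr, best_score = math.nan, math.inf
    for _ in range(num_samples):
        lr = log_uniform(rng, low, high)
        score = objective(lr)
        if score < best_score:
            best_lr, best_score = lr, score
    return best_lr, best_score


def toy_loss(lr: float) -> float:
    # Hypothetical objective: pretend validation loss is lowest at lr = 7e-4.
    return (math.log10(lr) - math.log10(7e-4)) ** 2


if __name__ == "__main__":
    lr, loss = random_lr_search(toy_loss, num_samples=20)
    print(f"best lr ~ {lr:.2e} (toy loss {loss:.3f})")
```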
@InProceedings{bautista2022parseq,
title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
author={Bautista, Darwin and Atienza, Rowel},
booktitle={European Conference on Computer Vision},
pages={178--196},
month={10},
year={2022},
publisher={Springer Nature Switzerland},
address={Cham},
doi={10.1007/978-3-031-19815-1_11},
url={https://doi.org/10.1007/978-3-031-19815-1_11}
}