Skip to content

Commit

Permalink
add_rec_sar, test=dygraph
Browse files Browse the repository at this point in the history
  • Loading branch information
andyjiang1116 committed Aug 24, 2021
1 parent ffa9441 commit 8a95b33
Show file tree
Hide file tree
Showing 15 changed files with 207 additions and 13 deletions.
3 changes: 2 additions & 1 deletion doc/doc_ch/algorithm_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] STAR-Net([paper](http:https://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))

参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:

Expand All @@ -58,6 +59,6 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |

|SAR|Resnet31| 87.1% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |

PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)
1 change: 1 addition & 0 deletions doc/doc_ch/recognition.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |

训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:

Expand Down
2 changes: 2 additions & 0 deletions doc/doc_en/algorithm_overview_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] STAR-Net([paper](http:https://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))

Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:

Expand All @@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|SAR|Resnet31| 87.1% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |

Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
1 change: 1 addition & 0 deletions doc/doc_en/recognition_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |


For training Chinese data, it is recommended to use
Expand Down
2 changes: 1 addition & 1 deletion ppocr/data/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SARRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
Expand Down
46 changes: 46 additions & 0 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,49 @@ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx


class SARLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SARLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)

def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1

return dict_character

def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
data['length'] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]

padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data

def get_ignored_tokens(self):
return [self.padding_idx]
50 changes: 50 additions & 0 deletions ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,56 @@ def __call__(self, data):
return data


class SARRecResizeImg(object):
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio

def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio)
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
data['valid_ratio'] = valid_ratio
return data


def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape

return padding_im, resize_shape, pad_shape, valid_ratio


def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
Expand Down
3 changes: 2 additions & 1 deletion ppocr/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
from .rec_sar_loss import SARLoss

# cls loss
from .cls_loss import ClsLoss
Expand All @@ -44,7 +45,7 @@
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'SARLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
Expand Down
3 changes: 2 additions & 1 deletion ppocr/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_resnet_31 import ResNet31
support_dict = [
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN", "ResNet31"
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
Expand Down
3 changes: 2 additions & 1 deletion ppocr/modeling/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ def build_head(config):
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
from .rec_sar_head import SARHead

# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead', 'PGHead', 'TableAttentionHead']
'SRNHead', 'PGHead', 'TableAttentionHead', 'SARHead']

#table head
from .table_att_head import TableAttentionHead
Expand Down
4 changes: 2 additions & 2 deletions ppocr/postprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode
TableLabelDecode, SARLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess

Expand All @@ -35,7 +35,7 @@ def build_post_process(config, global_config=None):
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess'
'DistillationDBPostProcess', 'SARLabelDecode'
]

config = copy.deepcopy(config)
Expand Down
77 changes: 77 additions & 0 deletions ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import string
import paddle
from paddle.nn import functional as F
import re


class BaseRecLabelDecode(object):
Expand Down Expand Up @@ -454,3 +455,79 @@ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx


class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """

def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SARLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)

def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1
return dict_character

def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()

batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(self.end_idx):
if text_prob is None and idx ==0:
continue
else:
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
text = text.lower()
text = comp.sub('', text)
result_list.append((text, np.mean(conf_list)))
return result_list

def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)

text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label

def get_ignored_tokens(self):
return [self.padding_idx]
3 changes: 2 additions & 1 deletion tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def main():

model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
use_sar = config['Architecture']['algorithm'] == "SAR"
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
Expand All @@ -71,7 +72,7 @@ def main():

# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn)
eval_class, model_type, use_srn, use_sar)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
Expand Down
9 changes: 9 additions & 0 deletions tools/infer_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def main():
'image', 'encoder_word_pos', 'gsrm_word_pos',
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = [
'image', 'valid_ratio'
]
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
Expand Down Expand Up @@ -106,11 +110,16 @@ def main():
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
if config['Architecture']['algorithm'] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]

images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
elif config['Architecture']['algorithm'] == "SAR":
preds = model(images, img_metas)
else:
preds = model(images)
post_result = post_process_class(preds)
Expand Down
Loading

0 comments on commit 8a95b33

Please sign in to comment.