From c1fd46641eff86b8e7ee2aa8d1943f7aced5b3a7 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 30 Dec 2020 16:15:49 +0800 Subject: [PATCH 01/10] add srn for dygraph --- configs/rec/rec_mv3_none_bilstm_ctc.yml | 6 +- configs/rec/rec_mv3_none_none_ctc.yml | 4 +- configs/rec/rec_mv3_tps_bilstm_ctc.yml | 4 +- configs/rec/rec_r34_vd_none_bilstm_ctc.yml | 4 +- configs/rec/rec_r34_vd_none_none_ctc.yml | 4 +- configs/rec/rec_r34_vd_tps_bilstm_ctc.yml | 4 +- configs/rec/rec_r50_fpn_srn.yml | 106 ++++++ ppocr/data/__init__.py | 4 +- ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/label_ops.py | 48 +++ ppocr/data/imaug/rec_img_aug.py | 90 ++++- ppocr/data/lmdb_dataset.py | 4 +- ppocr/losses/__init__.py | 5 +- ppocr/losses/rec_srn_loss.py | 47 +++ ppocr/metrics/__init__.py | 1 + ppocr/metrics/rec_metric.py | 4 +- ppocr/modeling/architectures/base_model.py | 7 +- ppocr/modeling/backbones/__init__.py | 3 +- ppocr/modeling/backbones/rec_resnet_fpn.py | 307 ++++++++++++++++ ppocr/modeling/heads/__init__.py | 5 +- ppocr/modeling/heads/rec_srn_head.py | 279 ++++++++++++++ ppocr/modeling/heads/self_attention.py | 408 +++++++++++++++++++++ ppocr/postprocess/__init__.py | 5 +- ppocr/postprocess/rec_postprocess.py | 84 ++++- tools/export_model.py | 41 ++- tools/infer/predict_rec.py | 149 +++++++- tools/infer_rec.py | 25 +- tools/program.py | 14 +- 28 files changed, 1594 insertions(+), 70 deletions(-) create mode 100644 configs/rec/rec_r50_fpn_srn.yml create mode 100644 ppocr/losses/rec_srn_loss.py create mode 100644 ppocr/modeling/backbones/rec_resnet_fpn.py create mode 100644 ppocr/modeling/heads/rec_srn_head.py create mode 100644 ppocr/modeling/heads/self_attention.py diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 38f1e8691e..00c1db885e 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,5 +1,5 @@ Global: - use_gpu: true + use_gpu: True epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -59,7 +59,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -78,7 +78,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index 33079ad48c..6711b1d23f 100644 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -58,7 +58,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -77,7 +77,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 08f68939d4..1b9fb0a08d 100644 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -63,7 +63,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -82,7 +82,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml 
b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 4ad2ff89ef..e4d301a6a1 100644 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -58,7 +58,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -77,7 +77,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index 9c1eeb304f..4a17a00422 100644 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -56,7 +56,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -75,7 +75,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index aeded4926a..62edf84379 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -62,7 +62,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -81,7 +81,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml new file mode 100644 index 0000000000..78f8d55102 --- /dev/null +++ b/configs/rec/rec_r50_fpn_srn.yml @@ -0,0 +1,106 @@ +Global: + use_gpu: True + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/rec/srn + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 5000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + num_heads: 8 + infer_mode: False + use_space_char: False + + +Optimizer: + name: Adam + lr: + name: Cosine + learning_rate: 0.0001 + +Architecture: + model_type: rec + algorithm: SRN + in_channels: 1 + Transform: + Backbone: + name: ResNetFPN + Head: + name: SRNHead + max_text_length: 25 + num_heads: 8 + num_encoder_TUs: 2 + num_decoder_TUs: 4 + hidden_dims: 512 + +Loss: + name: SRNLoss + +PostProcess: + name: SRNLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/srn_train_data_duiqi + #label_file_list: ["./train_data/ic15_data/1.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SRNLabelEncode: # Class handling label + - SRNRecResizeImg: + image_shape: [1, 64, 256] + - KeepKeys: + keep_keys: ['image', + 'label', + 'length', + 'encoder_word_pos', + 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', + 'gsrm_slf_attn_bias2'] # dataloader will return list in this order + loader: + shuffle: False + batch_size_per_card: 64 + drop_last: True + num_workers: 4 + +Eval: + dataset: + 
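    # The four extra keep_keys are the SRN transformer inputs that
    # SRNRecResizeImg precomputes per sample (see srn_other_inputs):
    #   encoder_word_pos      - (64/8)*(256/8) = 256 visual position ids
    #   gsrm_word_pos         - max_text_length = 25 character position ids
    #   gsrm_slf_attn_bias1/2 - [num_heads, 25, 25] masks of -1e9 hiding
    #                           future / past tokens for the two GSRM streams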
name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SRNLabelEncode: # Class handling label + - SRNRecResizeImg: + image_shape: [1, 64, 256] + - KeepKeys: + keep_keys: ['image', + 'label', + 'length', + 'encoder_word_pos', + 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', + 'gsrm_slf_attn_bias2'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 32 + num_workers: 4 diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 7b0faf1260..eb461ffa03 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -33,7 +33,7 @@ from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet -from ppocr.data.lmdb_dataset import LMDBDateSet +from ppocr.data.lmdb_dataset import LMDBDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -54,7 +54,7 @@ def term_mp(sig_num, frame): def build_dataloader(config, mode, device, logger): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDateSet'] + support_dict = ['SimpleDataSet', 'LMDBDataSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 6ea4dd8ed6..250ac75e76 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg from .randaugment import RandAugment from .operators import * from .label_ops import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index af3308a553..986cec3dba 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -98,6 +98,8 @@ def __init__(self, support_character_type, character_type) self.max_text_len = max_text_length + self.beg_str = "sos" + self.end_str = "eos" if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -213,3 +215,49 @@ def get_beg_end_flag_idx(self, beg_or_end): assert False, "Unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class SRNLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length=25, + character_dict_path=None, + character_type='en', + use_space_char=False, + **kwargs): + super(SRNLabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + char_num = len(self.character_str) + if text is None: + return None + if len(text) > self.max_text_len: + return None + data['length'] = np.array(len(text)) + text = text + [char_num] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + 
elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 2ccb2d1d2b..28e6bd0bce 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -12,20 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import math import cv2 import numpy as np @@ -77,6 +63,26 @@ def __call__(self, data): return data +class SRNRecResizeImg(object): + def __init__(self, image_shape, num_heads, max_text_length, **kwargs): + self.image_shape = image_shape + self.num_heads = num_heads + self.max_text_length = max_text_length + + def __call__(self, data): + img = data['image'] + norm_img = resize_norm_img_srn(img, self.image_shape) + data['image'] = norm_img + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length) + + data['encoder_word_pos'] = encoder_word_pos + data['gsrm_word_pos'] = gsrm_word_pos + data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1 + data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2 + return data + + def resize_norm_img(img, image_shape): imgC, imgH, imgW = image_shape h = img.shape[0] @@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape): def resize_norm_img_chinese(img, image_shape): imgC, imgH, imgW = image_shape # todo: change to 0 and modified image shape - max_wh_ratio = 0 + max_wh_ratio = imgW * 1.0 / imgH h, w = img.shape[0], img.shape[1] ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, ratio) @@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape): return padding_im +def resize_norm_img_srn(img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + +def srn_other_inputs(image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + 
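    # np.triu(., 1) keeps the strictly upper triangle (future positions) and
    # np.tril(., -1) the strictly lower triangle (past positions); scaling by
    # -1e9 turns the kept entries into additive biases that zero out those
    # attention weights after softmax. Each mask is tiled per head to
    # [num_heads, max_text_length, max_text_length].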
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, + [num_heads, 1, 1]) * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, + [num_heads, 1, 1]) * [-1e9] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def flag(): """ flag diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e7bb6dd3c9..515279fb73 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -20,9 +20,9 @@ from .imaug import transform, create_operators -class LMDBDateSet(Dataset): +class LMDBDataSet(Dataset): def __init__(self, config, mode, logger): - super(LMDBDateSet, self).__init__() + super(LMDBDataSet, self).__init__() global_config = config['Global'] dataset_config = config[mode]['dataset'] diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 4673d35cec..b280eb333e 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -23,11 +23,14 @@ def build_loss(config): # rec loss from .rec_ctc_loss import CTCLoss + from .rec_srn_loss import SRNLoss # cls loss from .cls_loss import ClsLoss - support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] + support_dict = [ + 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss' + ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py new file mode 100644 index 0000000000..d722ee0f22 --- /dev/null +++ b/ppocr/losses/rec_srn_loss.py @@ -0,0 +1,47 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
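# SRNLoss sums three CrossEntropyLoss(reduction="sum") terms computed against
# the same padded label: word_out (PVAM branch, weight 1.0), gsrm_out
# (semantic reasoning branch, weight 0.15) and predict (fused VSFD output,
# weight 2.0).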
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class SRNLoss(nn.Layer): + def __init__(self, **kwargs): + super(SRNLoss, self).__init__() + self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum") + + def forward(self, predicts, batch): + predict = predicts['predict'] + word_predict = predicts['word_out'] + gsrm_predict = predicts['gsrm_out'] + label = batch[1] + + casted_label = paddle.cast(x=label, dtype='int64') + casted_label = paddle.reshape(x=casted_label, shape=[-1, 1]) + + cost_word = self.loss_func(word_predict, label=casted_label) + cost_gsrm = self.loss_func(gsrm_predict, label=casted_label) + cost_vsfd = self.loss_func(predict, label=casted_label) + + cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1]) + cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) + cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) + + sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 + + return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index a0e7d91207..41828f510a 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,6 +26,7 @@ def build_metric(config): from .det_metric import DetMetric from .rec_metric import RecMetric from .cls_metric import ClsMetric + from .rec_metric import RecMetric support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index bd0f92e0d7..459fe8e403 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -31,8 +31,6 @@ def __call__(self, pred_label, *args, **kwargs): if pred == target: correct_num += 1 all_num += 1 - # if all_num < 10 and kwargs.get('show_str', False): - # print('{} -> {}'.format(pred, target)) self.correct_num += correct_num self.all_num += all_num self.norm_edit_dis += norm_edit_dis @@ -48,7 +46,7 @@ def get_metric(self): 'norm_edit_dis': 0, } """ - acc = self.correct_num / self.all_num + acc = 1.0 * self.correct_num / self.all_num norm_edit_dis = 1 - self.norm_edit_dis / self.all_num self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index ab44b53a2b..09b6e0346d 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -68,11 +68,14 @@ def __init__(self, config): config["Head"]['in_channels'] = in_channels self.head = build_head(config["Head"]) - def forward(self, x): + def forward(self, x, data=None): if self.use_transform: x = self.transform(x) x = self.backbone(x) if self.use_neck: x = self.neck(x) - x = self.head(x) + if data is None: + x = self.head(x) + else: + x = self.head(x, data) return x diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 43103e53d2..03c15508a5 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -24,7 +24,8 @@ def build_backbone(config, model_type): elif model_type == 'rec' or model_type == 'cls': from .rec_mobilenet_v3 import MobileNetV3 from .rec_resnet_vd import ResNet - support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN'] + from .rec_resnet_fpn import ResNetFPN + support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN'] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py 
b/ppocr/modeling/backbones/rec_resnet_fpn.py new file mode 100644 index 0000000000..a7e876a2bd --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_fpn.py @@ -0,0 +1,307 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import paddle +import numpy as np + +__all__ = ["ResNetFPN"] + + +class ResNetFPN(nn.Layer): + def __init__(self, in_channels=1, layers=50, **kwargs): + super(ResNetFPN, self).__init__() + supported_layers = { + 18: { + 'depth': [2, 2, 2, 2], + 'block_class': BasicBlock + }, + 34: { + 'depth': [3, 4, 6, 3], + 'block_class': BasicBlock + }, + 50: { + 'depth': [3, 4, 6, 3], + 'block_class': BottleneckBlock + }, + 101: { + 'depth': [3, 4, 23, 3], + 'block_class': BottleneckBlock + }, + 152: { + 'depth': [3, 8, 36, 3], + 'block_class': BottleneckBlock + } + } + stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)] + num_filters = [64, 128, 256, 512] + self.depth = supported_layers[layers]['depth'] + self.F = [] + self.conv = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=7, + stride=2, + act="relu", + name="conv1") + self.block_list = [] + in_ch = 64 + if layers >= 50: + for block in range(len(self.depth)): + for i in range(self.depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + block_list = self.add_sublayer( + "bottleneckBlock_{}_{}".format(block, i), + BottleneckBlock( + in_channels=in_ch, + out_channels=num_filters[block], + stride=stride_list[block] if i == 0 else 1, + name=conv_name)) + in_ch = num_filters[block] * 4 + self.block_list.append(block_list) + self.F.append(block_list) + else: + for block in range(len(self.depth)): + for i in range(self.depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + if i == 0 and block != 0: + stride = (2, 1) + else: + stride = (1, 1) + basic_block = self.add_sublayer( + conv_name, + BasicBlock( + in_channels=in_ch, + out_channels=num_filters[block], + stride=stride_list[block] if i == 0 else 1, + is_first=block == i == 0, + name=conv_name)) + in_ch = basic_block.out_channels + self.block_list.append(basic_block) + out_ch_list = [in_ch // 4, in_ch // 2, in_ch] + self.base_block = [] + self.conv_trans = [] + self.bn_block = [] + for i in [-2, -3]: + in_channels = out_ch_list[i + 1] + out_ch_list[i] + + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_0".format(i), + nn.Conv2D( + in_channels=in_channels, + out_channels=out_ch_list[i], + kernel_size=1, + weight_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_1".format(i), + nn.Conv2D( + 
in_channels=out_ch_list[i], + out_channels=out_ch_list[i], + kernel_size=3, + padding=1, + weight_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_2".format(i), + nn.BatchNorm( + num_channels=out_ch_list[i], + act="relu", + param_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_3".format(i), + nn.Conv2D( + in_channels=out_ch_list[i], + out_channels=512, + kernel_size=1, + bias_attr=ParamAttr(trainable=True), + weight_attr=ParamAttr(trainable=True)))) + self.out_channels = 512 + + def __call__(self, x): + x = self.conv(x) + fpn_list = [] + F = [] + for i in range(len(self.depth)): + fpn_list.append(np.sum(self.depth[:i + 1])) + + for i, block in enumerate(self.block_list): + x = block(x) + for number in fpn_list: + if i + 1 == number: + F.append(x) + base = F[-1] + + j = 0 + for i, block in enumerate(self.base_block): + if i % 3 == 0 and i < 6: + j = j + 1 + b, c, w, h = F[-j - 1].shape + if [w, h] == list(base.shape[2:]): + base = base + else: + base = self.conv_trans[j - 1](base) + base = self.bn_block[j - 1](base) + base = paddle.concat([base, F[-j - 1]], axis=1) + base = block(base) + return base + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 if stride == (1, 1) else kernel_size, + dilation=2 if stride == (1, 1) else 1, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'), + bias_attr=False, ) + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name=name + '.output.1.w_0'), + bias_attr=ParamAttr(name=name + '.output.1.b_0'), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") + + def __call__(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class ShortCut(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name, is_first=False): + super(ShortCut, self).__init__() + self.use_conv = True + + if in_channels != out_channels or stride != 1 or is_first == True: + if stride == (1, 1): + self.conv = ConvBNLayer( + in_channels, out_channels, 1, 1, name=name) + else: # stride==(2,2) + self.conv = ConvBNLayer( + in_channels, out_channels, 1, stride, name=name) + else: + self.use_conv = False + + def forward(self, x): + if self.use_conv: + x = self.conv(x) + return x + + +class BottleneckBlock(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name): + super(BottleneckBlock, self).__init__() + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + self.short = ShortCut( + in_channels=in_channels, + out_channels=out_channels * 4, + stride=stride, + is_first=False, + name=name + "_branch1") + self.out_channels = out_channels * 4 + + 
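    # Bottleneck residual unit: 1x1 reduce -> 3x3 (carries the stage stride)
    # -> 1x1 expand to 4x channels, plus a projected shortcut when shapes
    # change. With stride_list [(2, 2), (2, 2), (1, 1), (1, 1)] after the
    # stride-2 stem, a 1x64x256 SRN input ends up as an 8x32 map, i.e. the
    # 256 positions indexed by encoder_word_pos.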
def forward(self, x): + y = self.conv0(x) + y = self.conv1(y) + y = self.conv2(y) + y = y + self.short(x) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name, is_first): + super(BasicBlock, self).__init__() + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + self.short = ShortCut( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + is_first=is_first, + name=name + "_branch1") + self.out_channels = out_channels + + def forward(self, x): + y = self.conv0(x) + y = self.conv1(y) + y = y + self.short(x) + return F.relu(y) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 7807470905..1a39ca412a 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -23,10 +23,13 @@ def build_head(config): # rec head from .rec_ctc_head import CTCHead + from .rec_srn_head import SRNHead # cls head from .cls_head import ClsHead - support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead'] + support_dict = [ + 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead' + ] module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py new file mode 100644 index 0000000000..8aaf65e1ae --- /dev/null +++ b/ppocr/modeling/heads/rec_srn_head.py @@ -0,0 +1,279 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
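# SRN head in three stages, following "Towards Accurate Scene Text Recognition
# with Semantic Reasoning Networks" (CVPR 2020):
#   PVAM - parallel visual attention: aligns the T=256 FPN features into one
#          feature per character slot (max_text_length = 25);
#   GSRM - forward/backward transformer streams over the predicted character
#          ids, masked by gsrm_slf_attn_bias1/2 and sharing one embedding;
#   VSFD - a sigmoid gate that fuses visual and semantic features before the
#          final classifier.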
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import numpy as np +from .self_attention import WrapEncoderForFeature +from .self_attention import WrapEncoder +from paddle.static import Program +from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN +import paddle.fluid.framework as framework + +from collections import OrderedDict +gradient_clip = 10 + + +class PVAM(nn.Layer): + def __init__(self, in_channels, char_num, max_text_length, num_heads, + num_encoder_tus, hidden_dims): + super(PVAM, self).__init__() + self.char_num = char_num + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_tus + self.hidden_dims = hidden_dims + # Transformer encoder + t = 256 + c = 512 + self.wrap_encoder_for_feature = WrapEncoderForFeature( + src_vocab_size=1, + max_length=t, + n_layer=self.num_encoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + # PVAM + self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1) + self.fc0 = paddle.nn.Linear( + in_features=in_channels, + out_features=in_channels, ) + self.emb = paddle.nn.Embedding( + num_embeddings=self.max_length, embedding_dim=in_channels) + self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2) + self.fc1 = paddle.nn.Linear( + in_features=in_channels, out_features=1, bias_attr=False) + + def forward(self, inputs, encoder_word_pos, gsrm_word_pos): + b, c, h, w = inputs.shape + conv_features = paddle.reshape(inputs, shape=[-1, c, h * w]) + conv_features = paddle.transpose(conv_features, perm=[0, 2, 1]) + # transformer encoder + b, t, c = conv_features.shape + + enc_inputs = [conv_features, encoder_word_pos, None] + word_features = self.wrap_encoder_for_feature(enc_inputs) + + # pvam + b, t, c = word_features.shape + word_features = self.fc0(word_features) + word_features_ = paddle.reshape(word_features, [-1, 1, t, c]) + word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1]) + word_pos_feature = self.emb(gsrm_word_pos) + word_pos_feature_ = paddle.reshape(word_pos_feature, + [-1, self.max_length, 1, c]) + word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1]) + y = word_pos_feature_ + word_features_ + y = F.tanh(y) + attention_weight = self.fc1(y) + attention_weight = paddle.reshape( + attention_weight, shape=[-1, self.max_length, t]) + attention_weight = F.softmax(attention_weight, axis=-1) + pvam_features = paddle.matmul(attention_weight, + word_features) #[b, max_length, c] + return pvam_features + + +class GSRM(nn.Layer): + def __init__(self, in_channels, char_num, max_text_length, num_heads, + num_encoder_tus, num_decoder_tus, hidden_dims): + super(GSRM, self).__init__() + self.char_num = char_num + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_tus + self.num_decoder_TUs = num_decoder_tus + self.hidden_dims = hidden_dims + + self.fc0 = paddle.nn.Linear( + in_features=in_channels, out_features=self.char_num) + self.wrap_encoder0 = WrapEncoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + 
n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + self.wrap_encoder1 = WrapEncoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + self.mul = lambda x: paddle.matmul(x=x, + y=self.wrap_encoder0.prepare_decoder.emb0.weight, + transpose_y=True) + + def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2): + # ===== GSRM Visual-to-semantic embedding block ===== + b, t, c = inputs.shape + pvam_features = paddle.reshape(inputs, [-1, c]) + word_out = self.fc0(pvam_features) + word_ids = paddle.argmax(F.softmax(word_out), axis=1) + word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1]) + + #===== GSRM Semantic reasoning block ===== + """ + This module is achieved through bi-transformers, + ngram_feature1 is the froward one, ngram_fetaure2 is the backward one + """ + pad_idx = self.char_num + + word1 = paddle.cast(word_ids, "float32") + word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC") + word1 = paddle.cast(word1, "int64") + word1 = word1[:, :-1, :] + word2 = word_ids + + enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1] + enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2] + + gsrm_feature1 = self.wrap_encoder0(enc_inputs_1) + gsrm_feature2 = self.wrap_encoder1(enc_inputs_2) + + gsrm_feature2 = F.pad(gsrm_feature2, [0, 1], + value=0., + data_format="NLC") + gsrm_feature2 = gsrm_feature2[:, 1:, ] + gsrm_features = gsrm_feature1 + gsrm_feature2 + + gsrm_out = self.mul(gsrm_features) + + b, t, c = gsrm_out.shape + gsrm_out = paddle.reshape(gsrm_out, [-1, c]) + + return gsrm_features, word_out, gsrm_out + + +class VSFD(nn.Layer): + def __init__(self, in_channels=512, pvam_ch=512, char_num=38): + super(VSFD, self).__init__() + self.char_num = char_num + self.fc0 = paddle.nn.Linear( + in_features=in_channels * 2, out_features=pvam_ch) + self.fc1 = paddle.nn.Linear( + in_features=pvam_ch, out_features=self.char_num) + + def forward(self, pvam_feature, gsrm_feature): + b, t, c1 = pvam_feature.shape + b, t, c2 = gsrm_feature.shape + combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2) + img_comb_feature_ = paddle.reshape( + combine_feature_, shape=[-1, c1 + c2]) + img_comb_feature_map = self.fc0(img_comb_feature_) + img_comb_feature_map = F.sigmoid(img_comb_feature_map) + img_comb_feature_map = paddle.reshape( + img_comb_feature_map, shape=[-1, t, c1]) + combine_feature = img_comb_feature_map * pvam_feature + ( + 1.0 - img_comb_feature_map) * gsrm_feature + img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1]) + + out = self.fc1(img_comb_feature) + return out + + +class SRNHead(nn.Layer): + def __init__(self, in_channels, out_channels, max_text_length, num_heads, + num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs): + super(SRNHead, self).__init__() + self.char_num = out_channels + self.max_length = max_text_length + 
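        # out_channels is expected to be the character-set size plus the two
        # sos/eos markers added by SRNLabelEncode (38 for the default 36-char
        # English set, matching VSFD's char_num default).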
self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_TUs + self.num_decoder_TUs = num_decoder_TUs + self.hidden_dims = hidden_dims + + self.pvam = PVAM( + in_channels=in_channels, + char_num=self.char_num, + max_text_length=self.max_length, + num_heads=self.num_heads, + num_encoder_tus=self.num_encoder_TUs, + hidden_dims=self.hidden_dims) + + self.gsrm = GSRM( + in_channels=in_channels, + char_num=self.char_num, + max_text_length=self.max_length, + num_heads=self.num_heads, + num_encoder_tus=self.num_encoder_TUs, + num_decoder_tus=self.num_decoder_TUs, + hidden_dims=self.hidden_dims) + self.vsfd = VSFD(in_channels=in_channels) + + self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 + + def forward(self, inputs, others): + encoder_word_pos = others[0] + gsrm_word_pos = others[1] + gsrm_slf_attn_bias1 = others[2] + gsrm_slf_attn_bias2 = others[3] + + pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos) + + gsrm_feature, word_out, gsrm_out = self.gsrm( + pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + + final_out = self.vsfd(pvam_feature, gsrm_feature) + if not self.training: + final_out = F.softmax(final_out, axis=1) + + _, decoded_out = paddle.topk(final_out, k=1) + + predicts = OrderedDict([ + ('predict', final_out), + ('pvam_feature', pvam_feature), + ('decoded_out', decoded_out), + ('word_out', word_out), + ('gsrm_out', gsrm_out), + ]) + + return predicts diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py new file mode 100644 index 0000000000..6aeb8f0ccf --- /dev/null +++ b/ppocr/modeling/heads/self_attention.py @@ -0,0 +1,408 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
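# Dygraph port of the static-graph (fluid) Transformer encoder used by SRN:
# WrapEncoderForFeature feeds visual features plus learned position embeddings
# into the encoder (for PVAM), while WrapEncoder embeds character ids first
# (for GSRM). Sublayer wiring follows the cmd strings handled by
# PrePostProcessLayer below ("n" before each sublayer, "da" after it).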
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +from paddle import ParamAttr, nn +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import numpy as np +gradient_clip = 10 + + +class WrapEncoderForFeature(nn.Layer): + def __init__(self, + src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + bos_idx=0): + super(WrapEncoderForFeature, self).__init__() + + self.prepare_encoder = PrepareEncoder( + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name="src_word_emb_table") + self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model, + d_inner_hid, prepostprocess_dropout, + attention_dropout, relu_dropout, preprocess_cmd, + postprocess_cmd) + + def forward(self, enc_inputs): + conv_features, src_pos, src_slf_attn_bias = enc_inputs + enc_input = self.prepare_encoder(conv_features, src_pos) + enc_output = self.encoder(enc_input, src_slf_attn_bias) + return enc_output + + +class WrapEncoder(nn.Layer): + """ + embedder + encoder + """ + + def __init__(self, + src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + bos_idx=0): + super(WrapEncoder, self).__init__() + + self.prepare_decoder = PrepareDecoder( + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx) + self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model, + d_inner_hid, prepostprocess_dropout, + attention_dropout, relu_dropout, preprocess_cmd, + postprocess_cmd) + + def forward(self, enc_inputs): + src_word, src_pos, src_slf_attn_bias = enc_inputs + enc_input = self.prepare_decoder(src_word, src_pos) + enc_output = self.encoder(enc_input, src_slf_attn_bias) + return enc_output + + +class Encoder(nn.Layer): + """ + encoder + """ + + def __init__(self, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + + super(Encoder, self).__init__() + + self.encoder_layers = list() + for i in range(n_layer): + self.encoder_layers.append( + self.add_sublayer( + "layer_%d" % i, + EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid, + prepostprocess_dropout, attention_dropout, + relu_dropout, preprocess_cmd, + postprocess_cmd))) + self.processer = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + + def forward(self, enc_input, attn_bias): + for encoder_layer in self.encoder_layers: + enc_output = encoder_layer(enc_input, attn_bias) + enc_input = enc_output + enc_output = self.processer(enc_output) + return enc_output + + +class EncoderLayer(nn.Layer): + """ + EncoderLayer + """ + + def __init__(self, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + + super(EncoderLayer, self).__init__() + self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head, + attention_dropout) + self.postprocesser1 = 
PrePostProcessLayer(postprocess_cmd, d_model, + prepostprocess_dropout) + + self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + self.ffn = FFN(d_inner_hid, d_model, relu_dropout) + self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model, + prepostprocess_dropout) + + def forward(self, enc_input, attn_bias): + attn_output = self.self_attn( + self.preprocesser1(enc_input), None, None, attn_bias) + attn_output = self.postprocesser1(attn_output, enc_input) + ffn_output = self.ffn(self.preprocesser2(attn_output)) + ffn_output = self.postprocesser2(ffn_output, attn_output) + return ffn_output + + +class MultiHeadAttention(nn.Layer): + """ + Multi-Head Attention + """ + + def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.): + super(MultiHeadAttention, self).__init__() + self.n_head = n_head + self.d_key = d_key + self.d_value = d_value + self.d_model = d_model + self.dropout_rate = dropout_rate + self.q_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_key * n_head, bias_attr=False) + self.k_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_key * n_head, bias_attr=False) + self.v_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_value * n_head, bias_attr=False) + self.proj_fc = paddle.nn.Linear( + in_features=d_value * n_head, out_features=d_model, bias_attr=False) + + def _prepare_qkv(self, queries, keys, values, cache=None): + if keys is None: # self-attention + keys, values = queries, queries + static_kv = False + else: # cross-attention + static_kv = True + + q = self.q_fc(queries) + q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key]) + q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) + + if cache is not None and static_kv and "static_k" in cache: + # for encoder-decoder attention in inference and has cached + k = cache["static_k"] + v = cache["static_v"] + else: + k = self.k_fc(keys) + v = self.v_fc(values) + k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key]) + k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) + v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value]) + v = paddle.transpose(x=v, perm=[0, 2, 1, 3]) + + if cache is not None: + if static_kv and not "static_k" in cache: + # for encoder-decoder attention in inference and has not cached + cache["static_k"], cache["static_v"] = k, v + elif not static_kv: + # for decoder self-attention in inference + cache_k, cache_v = cache["k"], cache["v"] + k = paddle.concat([cache_k, k], axis=2) + v = paddle.concat([cache_v, v], axis=2) + cache["k"], cache["v"] = k, v + + return q, k, v + + def forward(self, queries, keys, values, attn_bias, cache=None): + # compute q ,k ,v + keys = queries if keys is None else keys + values = keys if values is None else values + q, k, v = self._prepare_qkv(queries, keys, values, cache) + + # scale dot product attention + product = paddle.matmul(x=q, y=k, transpose_y=True) + product = product * self.d_model**-0.5 + if attn_bias is not None: + product += attn_bias + weights = F.softmax(product) + if self.dropout_rate: + weights = F.dropout( + weights, p=self.dropout_rate, mode="downscale_in_infer") + out = paddle.matmul(weights, v) + + # combine heads + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.proj_fc(out) + + return out + + +class PrePostProcessLayer(nn.Layer): + """ + PrePostProcessLayer + """ + + def __init__(self, process_cmd, d_model, dropout_rate): + 
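        # process_cmd is a tiny op string, one action per character:
        #   "a" -> add residual, "n" -> LayerNorm, "d" -> dropout.
        # preprocess_cmd="n" with postprocess_cmd="da" yields the pre-LN
        # layout: normalize, run the sublayer, dropout, then residual add.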
super(PrePostProcessLayer, self).__init__() + self.process_cmd = process_cmd + self.functors = [] + for cmd in self.process_cmd: + if cmd == "a": # add residual connection + self.functors.append(lambda x, y: x + y if y is not None else x) + elif cmd == "n": # add layer normalization + self.functors.append( + self.add_sublayer( + "layer_norm_%d" % len( + self.sublayers(include_sublayers=False)), + paddle.nn.LayerNorm( + normalized_shape=d_model, + weight_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(1.)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.))))) + elif cmd == "d": # add dropout + self.functors.append(lambda x: F.dropout( + x, p=dropout_rate, mode="downscale_in_infer") + if dropout_rate else x) + + def forward(self, x, residual=None): + for i, cmd in enumerate(self.process_cmd): + if cmd == "a": + x = self.functors[i](x, residual) + else: + x = self.functors[i](x) + return x + + +class PrepareEncoder(nn.Layer): + def __init__(self, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0, + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + super(PrepareEncoder, self).__init__() + self.src_emb_dim = src_emb_dim + self.src_max_len = src_max_len + self.emb = paddle.nn.Embedding( + num_embeddings=self.src_max_len, + embedding_dim=self.src_emb_dim, + sparse=True) + self.dropout_rate = dropout_rate + + def forward(self, src_word, src_pos): + src_word_emb = src_word + src_word_emb = fluid.layers.cast(src_word_emb, 'float32') + src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) + src_pos = paddle.squeeze(src_pos, axis=-1) + src_pos_enc = self.emb(src_pos) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + if self.dropout_rate: + out = F.dropout( + x=enc_input, p=self.dropout_rate, mode="downscale_in_infer") + else: + out = enc_input + return out + + +class PrepareDecoder(nn.Layer): + def __init__(self, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0, + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + super(PrepareDecoder, self).__init__() + self.src_emb_dim = src_emb_dim + """ + self.emb0 = Embedding(num_embeddings=src_vocab_size, + embedding_dim=src_emb_dim) + """ + self.emb0 = paddle.nn.Embedding( + num_embeddings=src_vocab_size, + embedding_dim=self.src_emb_dim, + weight_attr=paddle.ParamAttr( + name=word_emb_param_name, + initializer=nn.initializer.Normal(0., src_emb_dim**-0.5))) + self.emb1 = paddle.nn.Embedding( + num_embeddings=src_max_len, + embedding_dim=self.src_emb_dim, + weight_attr=paddle.ParamAttr(name=pos_enc_param_name)) + self.dropout_rate = dropout_rate + + def forward(self, src_word, src_pos): + src_word = fluid.layers.cast(src_word, 'int64') + src_word = paddle.squeeze(src_word, axis=-1) + src_word_emb = self.emb0(src_word) + src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) + src_pos = paddle.squeeze(src_pos, axis=-1) + src_pos_enc = self.emb1(src_pos) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + if self.dropout_rate: + out = F.dropout( + x=enc_input, p=self.dropout_rate, mode="downscale_in_infer") + else: + out = enc_input + return out + + +class FFN(nn.Layer): + """ + Feed-Forward Network + """ + + def __init__(self, d_inner_hid, d_model, dropout_rate): + super(FFN, self).__init__() + self.dropout_rate = dropout_rate + self.fc1 = paddle.nn.Linear( + in_features=d_model, out_features=d_inner_hid) + self.fc2 = paddle.nn.Linear( + in_features=d_inner_hid, 
out_features=d_model) + + def forward(self, x): + hidden = self.fc1(x) + hidden = F.relu(hidden) + if self.dropout_rate: + hidden = F.dropout( + hidden, p=self.dropout_rate, mode="downscale_in_infer") + out = self.fc2(hidden) + return out diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index c9b42e0839..0156e438e9 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -26,11 +26,12 @@ def build_post_process(config, global_config=None): from .db_postprocess import DBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode + from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode from .cls_postprocess import ClsPostProcess support_dict = [ - 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' + 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index a18e101bf4..c2303cead2 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -29,6 +29,9 @@ def __init__(self, assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, character_type) + self.beg_str = "sos" + self.end_str = "eos" + if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -104,7 +107,6 @@ def __init__(self, def __call__(self, preds, label=None, *args, **kwargs): if isinstance(preds, paddle.Tensor): preds = preds.numpy() - preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) text = self.decode(preds_idx, preds_prob) @@ -153,3 +155,83 @@ def get_beg_end_flag_idx(self, beg_or_end): assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class SRNLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='en', + use_space_char=False, + **kwargs): + super(SRNLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + pred = preds['predict'] + char_num = len(self.character_str) + 2 + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = np.reshape(pred, [-1, char_num]) + + preds_idx = np.argmax(pred, axis=1) + preds_prob = np.max(pred, axis=1) + + preds_idx = np.reshape(preds_idx, [-1, 25]) + + preds_prob = np.reshape(preds_prob, [-1, 25]) + + text = self.decode(preds_idx, preds_prob) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def decode(self, text_index, text_prob=None, is_remove_duplicate=True): + """ convert text-index into text-label. 
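        Skips the ignored sos/eos indices; when is_remove_duplicate is True,
        consecutive duplicate indices are collapsed. Returns a list of
        (text, mean confidence) tuples.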
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx diff --git a/tools/export_model.py b/tools/export_model.py index 74357d58ec..58dc0defa8 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -31,6 +31,14 @@ from tools.program import load_config, merge_config, ArgsParser +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", help="configuration file to use") + parser.add_argument( + "-o", "--output_path", type=str, default='./output/infer/') + return parser.parse_args() + + def main(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -51,14 +59,33 @@ def main(): model.eval() save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - infer_shape = [3, 32, 100] if config['Architecture'][ - 'model_type'] != "det" else [3, 640, 640] - model = to_static( - model, - input_spec=[ + + if config['Architecture']['algorithm'] == "SRN": + other_shape = [ paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') - ]) + shape=[None, 1, 64, 256], dtype='float32'), [ + paddle.static.InputSpec( + shape=[None, 256, 1], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 25, 1], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 8, 25, 25], dtype="int64"), + paddle.static.InputSpec( + shape=[None, 8, 25, 25], dtype="int64") + ] + ] + model = to_static(model, input_spec=other_shape) + + else: + infer_shape = [3, 32, 100] if config['Architecture'][ + 'model_type'] != "det" else [3, 640, 640] + model = to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + infer_shape, dtype='float32') + ]) + paddle.jit.save(model, save_path) logger.info('inference model is saved to {}'.format(save_path)) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 974fdbb6c7..fd895e5071 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -25,6 +25,7 @@ import math import time import traceback +import paddle import tools.infer.utility as utility from ppocr.postprocess import build_post_process @@ -46,6 +47,13 @@ def __init__(self, args): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + if self.rec_algorithm == "SRN": + postprocess_params = { + 'name': 'SRNLabelDecode', + "character_type": args.rec_char_type, + 
"character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors = \ utility.create_predictor(args, 'rec', logger) @@ -70,6 +78,78 @@ def resize_norm_img(self, img, max_wh_ratio): padding_im[:, :, 0:resized_w] = resized_image return padding_im + def resize_norm_img_srn(self, img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + def srn_other_inputs(self, image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile( + gsrm_slf_attn_bias1, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile( + gsrm_slf_attn_bias2, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def process_image_srn(self, img, image_shape, num_heads, max_text_length): + norm_img = self.resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + self.srn_other_inputs(image_shape, num_heads, max_text_length) + + gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) + gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) + encoder_word_pos = encoder_word_pos.astype(np.int64) + gsrm_word_pos = gsrm_word_pos.astype(np.int64) + + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -93,21 +173,64 @@ def __call__(self, img_list): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) + if self.rec_algorithm != "SRN": + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.process_image_srn( + img_list[indices[ino]], 
self.rec_image_shape, 8, 25) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - preds = outputs[0] + + if self.rec_algorithm == "SRN": + starttime = time.time() + encoder_word_pos_list = np.concatenate(encoder_word_pos_list) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list) + + inputs = [ + norm_img_batch, + encoder_word_pos_list, + gsrm_word_pos_list, + gsrm_slf_attn_bias1_list, + gsrm_slf_attn_bias2_list, + ] + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[ + i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = {"predict": outputs[2]} + else: + starttime = time.time() + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = outputs[0] + rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 7e4b081140..075ec261e4 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -62,7 +62,13 @@ def main(): elif op_name in ['RecResizeImg']: op[op_name]['infer_mode'] = True elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image'] + if config['Architecture']['algorithm'] == "SRN": + op[op_name]['keep_keys'] = [ + 'image', 'encoder_word_pos', 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' + ] + else: + op[op_name]['keep_keys'] = ['image'] transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) @@ -74,10 +80,25 @@ def main(): img = f.read() data = {'image': img} batch = transform(data, ops) + if config['Architecture']['algorithm'] == "SRN": + encoder_word_pos_list = np.expand_dims(batch[1], axis=0) + gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) + gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) + gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) + + others = [ + paddle.to_tensor(encoder_word_pos_list), + paddle.to_tensor(gsrm_word_pos_list), + paddle.to_tensor(gsrm_slf_attn_bias1_list), + paddle.to_tensor(gsrm_slf_attn_bias2_list) + ] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) - preds = model(images) + if config['Architecture']['algorithm'] == "SRN": + preds = model(images, others) + else: + preds = model(images) post_result = post_process_class(preds) for rec_reuslt in post_result: logger.info('\t result: {}'.format(rec_reuslt)) diff --git a/tools/program.py b/tools/program.py index c291542685..ce52a61049 
100755 --- a/tools/program.py +++ b/tools/program.py @@ -179,9 +179,9 @@ def train(config, if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: - start_epoch = 1 + start_epoch = 0 - for epoch in range(start_epoch, epoch_num + 1): + for epoch in range(start_epoch, epoch_num): if epoch > 0: train_dataloader = build_dataloader(config, 'Train', device, logger) train_batch_cost = 0.0 @@ -194,7 +194,11 @@ def train(config, break lr = optimizer.get_lr() images = batch[0] - preds = model(images) + if config['Architecture']['algorithm'] == "SRN": + others = batch[-4:] + preds = model(images, others) + else: + preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -212,6 +216,7 @@ def train(config, stats['lr'] = lr train_stats.update(stats) + #cal_metric_during_train = False if cal_metric_during_train: # onlt rec and cls need batch = [item.numpy() for item in batch] post_result = post_process_class(preds, batch[1]) @@ -312,8 +317,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class): if idx >= len(valid_dataloader): break images = batch[0] + others = batch[-4:] start = time.time() - preds = model(images) + preds = model(images, others) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods From 297871d4be965b760da6ed1535fad82354cfd366 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 30 Dec 2020 19:54:16 +0800 Subject: [PATCH 02/10] fix bugs --- ppocr/metrics/__init__.py | 1 - tools/program.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 41828f510a..a0e7d91207 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,7 +26,6 @@ def build_metric(config): from .det_metric import DetMetric from .rec_metric import RecMetric from .cls_metric import ClsMetric - from .rec_metric import RecMetric support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] diff --git a/tools/program.py b/tools/program.py index ce52a61049..08bc4c81a2 100755 --- a/tools/program.py +++ b/tools/program.py @@ -179,9 +179,9 @@ def train(config, if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: - start_epoch = 0 + start_epoch = 1 - for epoch in range(start_epoch, epoch_num): + for epoch in range(start_epoch, epoch_num + 1): if epoch > 0: train_dataloader = build_dataloader(config, 'Train', device, logger) train_batch_cost = 0.0 @@ -216,7 +216,6 @@ def train(config, stats['lr'] = lr train_stats.update(stats) - #cal_metric_during_train = False if cal_metric_during_train: # onlt rec and cls need batch = [item.numpy() for item in batch] post_result = post_process_class(preds, batch[1]) From 841adff934bbef4967b64bbf029a9bc454578adf Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Thu, 31 Dec 2020 07:05:06 +0000 Subject: [PATCH 03/10] Fix syntax warning over comparison of literals using is. 
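For context on the warning this commit silences: `is` and `is not` compare object identity, while `==` and `!=` compare values. Whether two equal string or integer literals share one object is a CPython interning detail, which is why Python 3.8+ emits `SyntaxWarning: "is" with a literal` for such checks. A minimal sketch (the exact identity result may vary between builds, which is the point):

```
s = ''.join(['a', 'b'])  # value 'ab', built at runtime, not interned
print(s == 'ab')         # True: value equality, always well-defined
print(s is 'ab')         # usually False, and a SyntaxWarning since 3.8
```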
---
 PPOCRLabel/PPOCRLabel.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index b4c73083c4..4d2108e67e 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -1032,7 +1032,7 @@ def format_shape(s):
         for box in self.result_dic:
             trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
-            if trans_dic["label"] is "" and mode == 'Auto':
+            if trans_dic["label"] == "" and mode == 'Auto':
                 continue
             shapes.append(trans_dic)

@@ -1791,7 +1791,7 @@ def reRecognition(self):
                     QMessageBox.information(self, "Information", msg)
                     return
                 result = self.ocr.ocr(img_crop, cls=True, det=False)
-                if result[0][0] is not '':
+                if result[0][0] != '':
                     result.insert(0, box)
                     print('result in reRec is ', result)
                     self.result_dic.append(result)
@@ -1822,7 +1822,7 @@ def singleRerecognition(self):
             QMessageBox.information(self, "Information", msg)
             return
         result = self.ocr.ocr(img_crop, cls=True, det=False)
-        if result[0][0] is not '':
+        if result[0][0] != '':
             result.insert(0, box)
             print('result in reRec is ', result)
             if result[1][0] == shape.label:
@@ -2008,7 +2008,7 @@ def main():
     resource_file = './libs/resources.py'
     if not os.path.exists(resource_file):
         output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
-        assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
-                            "directory resources.py "
+        assert output == 0, "running pyrcc5 failed, please check whether " \
+                            "resources.py exists in the libs directory"
     import libs.resources
 sys.exit(main())

From 93670ab5a2dc59d589f82e0c1a952e295ef3c86e Mon Sep 17 00:00:00 2001
From: tink2123
Date: Tue, 19 Jan 2021 06:48:52 +0000
Subject: [PATCH 04/10] all ready

---
 configs/rec/rec_r50_fpn_srn.yml        | 9 +++++----
 ppocr/modeling/heads/self_attention.py | 1 +
 ppocr/postprocess/rec_postprocess.py   | 7 ++++---
 tools/program.py                       | 7 +++++++
 4 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml
index 78f8d55102..ec7f170560 100644
--- a/configs/rec/rec_r50_fpn_srn.yml
+++ b/configs/rec/rec_r50_fpn_srn.yml
@@ -3,7 +3,7 @@ Global:
   epoch_num: 72
   log_smooth_window: 20
   print_batch_step: 5
-  save_model_dir: ./output/rec/srn
+  save_model_dir: ./output/rec/srn_new
   save_epoch_step: 3
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [0, 5000]
@@ -25,8 +25,10 @@ Global:

 Optimizer:
   name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  clip_norm: 10.0
   lr:
-    name: Cosine
     learning_rate: 0.0001

 Architecture:
@@ -58,7 +60,6 @@ Train:
   dataset:
     name: LMDBDataSet
     data_dir: ./train_data/srn_train_data_duiqi
-    #label_file_list: ["./train_data/ic15_data/1.txt"]
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
@@ -77,7 +78,7 @@ Train:
   loader:
     shuffle: False
     batch_size_per_card: 64
-    drop_last: True
+    drop_last: False
    num_workers: 4

 Eval:
diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py
index 6aeb8f0ccf..51d5198f55 100644
--- a/ppocr/modeling/heads/self_attention.py
+++ b/ppocr/modeling/heads/self_attention.py
@@ -359,6 +359,7 @@ def __init__(self,
         self.emb0 = paddle.nn.Embedding(
             num_embeddings=src_vocab_size,
             embedding_dim=self.src_emb_dim,
+            padding_idx=bos_idx,
             weight_attr=paddle.ParamAttr(
                 name=word_emb_param_name,
                 initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index c2303cead2..867f920a3c 100644
---
a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -182,14 +182,15 @@ def __call__(self, preds, label=None, *args, **kwargs): preds_prob = np.reshape(preds_prob, [-1, 25]) - text = self.decode(preds_idx, preds_prob) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) if label is None: + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) return text - label = self.decode(label, is_remove_duplicate=False) + label = self.decode(label, is_remove_duplicate=True) return text, label - def decode(self, text_index, text_prob=None, is_remove_duplicate=True): + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): """ convert text-index into text-label. """ result_list = [] ignored_tokens = self.get_ignored_tokens() diff --git a/tools/program.py b/tools/program.py index 08bc4c81a2..885d45f5e9 100755 --- a/tools/program.py +++ b/tools/program.py @@ -242,6 +242,12 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: + model_average = paddle.optimizer.ModelAverage( + 0.15, + parameters=model.parameters(), + min_average_window=10000, + max_average_window=15625) + model_average.apply() cur_metirc = eval(model, valid_dataloader, post_process_class, eval_class) cur_metirc_str = 'cur metirc, {}'.format(', '.join( @@ -277,6 +283,7 @@ def train(config, best_model_dict[main_indicator], global_step) global_step += 1 + optimizer.clear_grad() batch_start = time.time() if dist.get_rank() == 0: save_model( From ed2f0de95e58298ee733ee83976ef43079a613a0 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 22 Jan 2021 03:15:56 +0000 Subject: [PATCH 05/10] mv model_average to incubate --- ppocr/losses/rec_srn_loss.py | 2 +- ppocr/postprocess/rec_postprocess.py | 4 ++-- tools/program.py | 15 +++++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py index d722ee0f22..7d5b65ebaf 100644 --- a/ppocr/losses/rec_srn_loss.py +++ b/ppocr/losses/rec_srn_loss.py @@ -42,6 +42,6 @@ def forward(self, predicts, batch): cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) - sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 + sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15 return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 867f920a3c..8c972a143b 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -182,12 +182,12 @@ def __call__(self, preds, label=None, *args, **kwargs): preds_prob = np.reshape(preds_prob, [-1, 25]) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + text = self.decode(preds_idx, preds_prob) if label is None: text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) return text - label = self.decode(label, is_remove_duplicate=True) + label = self.decode(label) return text, label def decode(self, text_index, text_prob=None, is_remove_duplicate=False): diff --git a/tools/program.py b/tools/program.py index 885d45f5e9..f329dcd578 100755 --- a/tools/program.py +++ b/tools/program.py @@ -174,6 +174,7 @@ def train(config, best_model_dict = {main_indicator: 0} best_model_dict.update(pre_best_model_dict) train_stats = TrainingStats(log_smooth_window, ['lr']) + model_average = False 
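+    # model_average stays False for CTC-style models; the SRN branch below
+    # sets it to True so paddle.incubate.optimizer.ModelAverage can be
+    # applied to the weights right before each evaluation pass.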
model.train() if 'start_epoch' in best_model_dict: @@ -197,6 +198,7 @@ def train(config, if config['Architecture']['algorithm'] == "SRN": others = batch[-4:] preds = model(images, others) + model_average = True else: preds = model(images) loss = loss_class(preds, batch) @@ -242,12 +244,13 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: - model_average = paddle.optimizer.ModelAverage( - 0.15, - parameters=model.parameters(), - min_average_window=10000, - max_average_window=15625) - model_average.apply() + if model_average: + Model_Average = paddle.incubate.optimizer.ModelAverage( + 0.15, + parameters=model.parameters(), + min_average_window=10000, + max_average_window=15625) + Model_Average.apply() cur_metirc = eval(model, valid_dataloader, post_process_class, eval_class) cur_metirc_str = 'cur metirc, {}'.format(', '.join( From 647db30f6f3bd4d8e3693d6ba83b2d0fea355076 Mon Sep 17 00:00:00 2001 From: Leif <4603009@qq.com> Date: Fri, 29 Jan 2021 14:51:40 +0800 Subject: [PATCH 06/10] Fix bugs during save recognition results --- PPOCRLabel/PPOCRLabel.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py index 4d9c52740a..92d80c8aa2 100644 --- a/PPOCRLabel/PPOCRLabel.py +++ b/PPOCRLabel/PPOCRLabel.py @@ -1450,7 +1450,7 @@ def importDirImages(self, dirpath, isDelete = False): item = QListWidgetItem(closeicon, filename) self.fileListWidget.addItem(item) - print('dirPath in importDirImages is', dirpath) + print('DirPath in importDirImages is', dirpath) self.iconlist.clear() self.additems5(dirpath) self.changeFileFolder = True @@ -1459,7 +1459,6 @@ def importDirImages(self, dirpath, isDelete = False): self.reRecogButton.setEnabled(True) self.actions.AutoRec.setEnabled(True) self.actions.reRec.setEnabled(True) - self.actions.saveLabel.setEnabled(True) def openPrevImg(self, _value=False): @@ -1862,6 +1861,8 @@ def loadFilestate(self, saveDir): for each in states: file, state = each.split('\t') self.fileStatedict[file] = 1 + self.actions.saveLabel.setEnabled(True) + self.actions.saveRec.setEnabled(True) def saveFilestate(self): @@ -1919,22 +1920,29 @@ def saveRecResult(self): rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt' crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/' + ques_img = [] if not os.path.exists(crop_img_dir): os.mkdir(crop_img_dir) with open(rec_gt_dir, 'w', encoding='utf-8') as f: for key in self.fileStatedict: idx = self.getImglabelidx(key) - for i, label in enumerate(self.PPlabel[idx]): - if label['difficult']: continue + try: img = cv2.imread(key) - img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32)) - img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg' - cv2.imwrite(crop_img_dir+img_name, img_crop) - f.write('crop_img/'+ img_name + '\t') - f.write(label['transcription'] + '\n') - - QMessageBox.information(self, "Information", "Cropped images has been saved in "+str(crop_img_dir)) + for i, label in enumerate(self.PPlabel[idx]): + if label['difficult']: continue + img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32)) + img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg' + cv2.imwrite(crop_img_dir+img_name, img_crop) + f.write('crop_img/'+ img_name + '\t') + f.write(label['transcription'] + '\n') + except Exception as e: + ques_img.append(key) + print("Can not 
read image ",e) + if ques_img: + QMessageBox.information(self, "Information", "The following images can not be saved, " + "please check the image path and labels.\n" + "".join(str(i)+'\n' for i in ques_img)) + QMessageBox.information(self, "Information", "Cropped images have been saved in "+str(crop_img_dir)) def speedChoose(self): if self.labelDialogOption.isChecked(): From b3a451da2672a4ca9f0825f8c720057a91fe35f6 Mon Sep 17 00:00:00 2001 From: Leif <4603009@qq.com> Date: Fri, 29 Jan 2021 15:03:41 +0800 Subject: [PATCH 07/10] Fix a spelling mistake --- ppocr/data/lmdb_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index bd0630f635..e2d6dc9327 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -20,9 +20,9 @@ from .imaug import transform, create_operators -class LMDBDateSet(Dataset): +class LMDBDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): - super(LMDBDateSet, self).__init__() + super(LMDBDataSet, self).__init__() global_config = config['Global'] dataset_config = config[mode]['dataset'] From 42fe741ff18381df2fc00b665f0b4585ab065fd7 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 29 Jan 2021 15:08:58 +0800 Subject: [PATCH 08/10] add srn doc --- doc/doc_ch/algorithm_overview.md | 4 ++-- doc/doc_ch/inference.md | 21 +++++++++++++++++---- doc/doc_ch/recognition.md | 2 ++ doc/doc_en/algorithm_overview_en.md | 4 ++-- doc/doc_en/inference_en.md | 20 ++++++++++++++++++-- doc/doc_en/recognition_en.md | 1 + 6 files changed, 42 insertions(+), 10 deletions(-) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 59d1bc8c44..f076569509 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -41,7 +41,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon -- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon +- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -53,5 +53,5 @@ PaddleOCR基于动态图开源的文本识别算法列表: |CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| - +|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md index c4601e1526..0daddd9bb0 100755 --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -22,8 +22,9 @@ inference 模型(`paddle.jit.save`保存的模型) - [三、文本识别模型推理](#文本识别模型推理) - [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理) - [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理) - - [3. 自定义文本识别字典的推理](#自定义文本识别字典的推理) - - [4. 多语言模型的推理](#多语言模型的推理) + - [3. 基于SRN损失的识别模型推理](#基于SRN损失的识别模型推理) + - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理) + - [5. 
多语言模型的推理](#多语言模型的推理) - [四、方向分类模型推理](#方向识别模型推理) - [1. 方向分类模型推理](#方向分类模型推理) @@ -295,8 +296,20 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073) self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) ``` + +### 3. 基于SRN损失的识别模型推理 +基于SRN损失的识别模型,需要额外设置识别算法参数 --rec_algorithm="SRN"。 +同时需要保证预测shape与训练时一致,如: --rec_image_shape="1, 64, 256" -### 3. 自定义文本识别字典的推理 +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \ + --rec_model_dir="./inference/srn/" \ + --rec_image_shape="1, 64, 256" \ + --rec_char_type="en" \ + --rec_algorithm="SRN" +``` + +### 4. 自定义文本识别字典的推理 如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch` ``` @@ -304,7 +317,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png ``` -### 4. 多语言模型的推理 +### 5. 多语言模型的推理 如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果, 需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别: diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index c5f459bdb8..bc877ab78c 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -36,6 +36,7 @@ ln -sf /train_data/dataset * 数据下载 若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。 +如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。 * 使用自己数据集 @@ -200,6 +201,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc | | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | +| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 68bfd52997..5016223f25 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -43,7 +43,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon -- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon +- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -55,5 +55,5 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| 
 |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
-
+|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
 Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index ccbb71847d..c8ce1424f5 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -25,6 +25,7 @@ Next, we first introduce how to convert a trained model into an inference model
 - [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
     - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
     - [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
+    - [3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE](#SRN-BASED_RECOGNITION)
-    - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
-    - [4. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE)
+    - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
+    - [5. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)

@@ -304,8 +305,23 @@
 self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
 dict_character = list(self.character_str)
 ```
+
+### 3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE
+
+The recognition model based on SRN requires the additional recognition algorithm parameter
+--rec_algorithm="SRN". The prediction shape must also stay consistent with
+training, such as: --rec_image_shape="1, 64, 256"
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
+                                   --rec_model_dir="./inference/srn/" \
+                                   --rec_image_shape="1, 64, 256" \
+                                   --rec_char_type="en" \
+                                   --rec_algorithm="SRN"
+```
+
-### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
+### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
 If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`

 ```
@@ -313,7 +329,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
 ```

-### 4. MULTILINGAUL MODEL INFERENCE
+### 5. MULTILINGUAL MODEL INFERENCE
 If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
 You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 22f89cdef0..f29703d144 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -195,6 +195,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
 | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
 | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
 | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
+| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |

 For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml).
If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file: From 6781d55df4a705b1d0d7201e5fc6b484d4912a9b Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 29 Jan 2021 15:23:11 +0800 Subject: [PATCH 09/10] format doc --- doc/doc_ch/algorithm_overview.md | 1 + doc/doc_en/algorithm_overview_en.md | 1 + 2 files changed, 2 insertions(+) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index f076569509..abbc5da4c2 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -54,4 +54,5 @@ PaddleOCR基于动态图开源的文本识别算法列表: |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | + PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 5016223f25..7d7896e714 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -56,4 +56,5 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| + Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) From 2a0c3d4dac67cfd49e432303443bb9a50e75071f Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Sun, 31 Jan 2021 22:37:30 +0800 Subject: [PATCH 10/10] fix eval mode without srn (#1889) * fix base model * fix start time --- tools/program.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tools/program.py b/tools/program.py index 694d64152f..f3ba49450a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -326,9 +326,12 @@ def eval(model, valid_dataloader, post_process_class, eval_class): if idx >= len(valid_dataloader): break images = batch[0] - others = batch[-4:] start = time.time() - preds = model(images, others) + if "SRN" in str(model.head): + others = batch[-4:] + preds = model(images, others) + else: + preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods
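As a closing sketch for readers who want to inspect the SRN auxiliary inputs outside the repo: the helper below mirrors the logic of `srn_other_inputs` in tools/infer/predict_rec.py under the default configuration (imgH=64, imgW=256, num_heads=8, max_text_length=25). The function name and its defaults are this sketch's own choices, not part of the patch.

```
import numpy as np

def build_srn_extras(image_shape=(1, 64, 256), num_heads=8, max_text_length=25):
    # The /8 matches the stride of the ResNetFPN feature map, giving
    # (64 // 8) * (256 // 8) = 256 addressable visual positions.
    _, imgH, imgW = image_shape
    feature_dim = (imgH // 8) * (imgW // 8)

    # Position ids for the visual encoder and the GSRM word stream,
    # each with a leading batch dimension of 1.
    encoder_word_pos = np.arange(feature_dim, dtype='int64').reshape(
        (1, feature_dim, 1))
    gsrm_word_pos = np.arange(max_text_length, dtype='int64').reshape(
        (1, max_text_length, 1))

    # Two triangular self-attention masks: bias1 hides future positions
    # (left-to-right pass), bias2 hides past positions (right-to-left pass);
    # -1e9 drives masked logits to near-zero probability after softmax.
    ones = np.ones((1, max_text_length, max_text_length))
    bias1 = np.triu(ones, 1).reshape([-1, 1, max_text_length, max_text_length])
    bias1 = (np.tile(bias1, [1, num_heads, 1, 1]) * -1e9).astype('float32')
    bias2 = np.tril(ones, -1).reshape([-1, 1, max_text_length, max_text_length])
    bias2 = (np.tile(bias2, [1, num_heads, 1, 1]) * -1e9).astype('float32')
    return encoder_word_pos, gsrm_word_pos, bias1, bias2

pos_enc, pos_gsrm, bias1, bias2 = build_srn_extras()
print(pos_enc.shape, pos_gsrm.shape, bias1.shape, bias2.shape)
# -> (1, 256, 1) (1, 25, 1) (1, 8, 25, 25) (1, 8, 25, 25)
```

The printed shapes line up with the `paddle.static.InputSpec` list in tools/export_model.py; because all four tensors depend only on the configuration, predict_rec.py can concatenate one identical copy per image to build a batch.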