diff --git a/configs/rec/rec_r31_sar.yml b/configs/rec/rec_r31_sar.yml new file mode 100644 index 0000000000..053b1ae835 --- /dev/null +++ b/configs/rec/rec_r31_sar.yml @@ -0,0 +1,99 @@ +Global: + use_gpu: true + epoch_num: 5 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./sar_rec + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + # for data or label process + character_dict_path: ppocr/utils/dict90.txt + character_type: EN_symbol + max_text_length: 30 + infer_mode: False + use_space_char: False + rm_symbol: True + save_res_path: ./output/rec/predicts_sar.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4] + values: [0.001, 0.0001, 0.00001] + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: SAR + Transform: + Backbone: + name: ResNet31 + Head: + name: SARHead + +Loss: + name: SARLoss + +PostProcess: + name: SARLabelDecode + +Metric: + name: RecMetric + + +Train: + dataset: + name: SimpleDataSet + label_file_list: ['./train_data/train_list.txt'] + data_dir: ./train_data/ + ratio_list: 1.0 + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - SARRecResizeImg: + image_shape: [3, 48, 48, 160] # h:48 w:[48,160] + width_downsample_ratio: 0.25 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 64 + drop_last: True + num_workers: 8 + use_shared_memory: False + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./eval_data/evaluation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - SARRecResizeImg: + image_shape: [3, 48, 48, 160] + width_downsample_ratio: 0.25 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 64 + num_workers: 4 + use_shared_memory: False + diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index e8f23b54e5..d465539166 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -45,6 +45,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] - [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) +- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -60,6 +61,6 @@ PaddleOCR基于动态图开源的文本识别算法列表: |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | |NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | - +|SAR|Resnet31| 87.2% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index cd1a64bd7b..812e2fbd3e 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -88,7 +88,10 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单 若您本地没有数据集,可以在官网下载 [ICDAR2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,下载 benchmark 所需的lmdb格式数据集。 +如果希望复现SAR的论文指标,需要下载[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg), 提取码:627x。此外,真实数据集icdar2013, icdar2015, cocotext, IIIT5也作为训练数据的一部分。具体数据细节可以参考论文SAR。 + 如果你使用的是icdar2015的公开数据集,PaddleOCR 提供了一份用于训练 ICDAR2015 数据集的标签文件,通过以下方式下载: + ``` # 训练集标签 wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt @@ -232,6 +235,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | | rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | +| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 8e8f0d3f8c..0572815323 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -47,6 +47,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] - [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) +- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -62,5 +63,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| |NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | +|SAR|Resnet31| 87.2% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index 21233c80a2..fdc5103a42 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -91,6 +91,8 @@ Similar to the training set, the test set also needs to be provided a folder con If you do not have a dataset locally, you can download it on the official website [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads). Also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,download the lmdb format dataset required for benchmark +If you want to reproduce the paper SAR, you need to download extra dataset [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg), extraction code: 627x. Besides, icdar2013, icdar2015, cocotext, IIIT5k datasets are also used to train. For specific details, please refer to the paper SAR. + PaddleOCR provides label files for training the icdar2015 dataset, which can be downloaded in the following ways: ``` @@ -235,6 +237,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | | rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | +| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder | + For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file: diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 4418d075cb..8bfc175083 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .operators import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index f626395095..643ec70503 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -549,3 +549,49 @@ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): assert False, "Unsupport type %s in char_or_elem" \ % char_or_elem return idx + + +class SARLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='ch', + use_space_char=False, + **kwargs): + super(SARLabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + + def add_special_char(self, dict_character): + beg_end_str = "" + unknown_str = "" + padding_str = "" + dict_character = dict_character + [unknown_str] + self.unknown_idx = len(dict_character) - 1 + dict_character = dict_character + [beg_end_str] + self.start_idx = len(dict_character) - 1 + self.end_idx = len(dict_character) - 1 + dict_character = dict_character + [padding_str] + self.padding_idx = len(dict_character) - 1 + + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len - 1: + return None + data['length'] = np.array(len(text)) + target = [self.start_idx] + text + [self.end_idx] + padded_text = [self.padding_idx for _ in range(self.max_text_len)] + + padded_text[:len(target)] = target + data['label'] = np.array(padded_text) + return data + + def get_ignored_tokens(self): + return [self.padding_idx] diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index e914d38446..51f5855ac3 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -102,6 +102,56 @@ def __call__(self, data): return data +class SARRecResizeImg(object): + def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs): + self.image_shape = image_shape + self.width_downsample_ratio = width_downsample_ratio + + def __call__(self, data): + img = data['image'] + norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio) + data['image'] = norm_img + data['resized_shape'] = resize_shape + data['pad_shape'] = pad_shape + data['valid_ratio'] = valid_ratio + return data + + +def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. + width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype('float32') + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + + def resize_norm_img(img, image_shape): imgC, imgH, imgW = image_shape h = img.shape[0] diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index eed5a46efc..0484542f0b 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -26,6 +26,7 @@ from .rec_att_loss import AttentionLoss from .rec_srn_loss import SRNLoss from .rec_nrtr_loss import NRTRLoss +from .rec_sar_loss import SARLoss # cls loss from .cls_loss import ClsLoss @@ -44,7 +45,7 @@ def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss' + 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss' ] config = copy.deepcopy(config) diff --git a/ppocr/losses/rec_sar_loss.py b/ppocr/losses/rec_sar_loss.py new file mode 100644 index 0000000000..9e1c6495fb --- /dev/null +++ b/ppocr/losses/rec_sar_loss.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class SARLoss(nn.Layer): + def __init__(self, **kwargs): + super(SARLoss, self).__init__() + self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96) + + def forward(self, predicts, batch): + predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets + label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation + batch_size, num_steps, num_classes = predict.shape[0], predict.shape[ + 1], predict.shape[2] + assert len(label.shape) == len(list(predict.shape)) - 1, \ + "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + + inputs = paddle.reshape(predict, [-1, num_classes]) + targets = paddle.reshape(label, [-1]) + loss = self.loss_func(inputs, targets) + return {'loss': loss} diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index f8ca7e408a..50438893c7 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -27,8 +27,9 @@ def build_backbone(config, model_type): from .rec_resnet_fpn import ResNetFPN from .rec_mv1_enhance import MobileNetV1Enhance from .rec_nrtr_mtb import MTB + from .rec_resnet_31 import ResNet31 support_dict = [ - 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB' + 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_resnet_31.py b/ppocr/modeling/backbones/rec_resnet_31.py new file mode 100644 index 0000000000..f60729cdcc --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_31.py @@ -0,0 +1,176 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + +__all__ = ["ResNet31"] + + +def conv3x3(in_channel, out_channel, stride=1): + return nn.Conv2D( + in_channel, + out_channel, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False + ) + + +class BasicBlock(nn.Layer): + expansion = 1 + def __init__(self, in_channels, channels, stride=1, downsample=False): + super().__init__() + self.conv1 = conv3x3(in_channels, channels, stride) + self.bn1 = nn.BatchNorm2D(channels) + self.relu = nn.ReLU() + self.conv2 = conv3x3(channels, channels) + self.bn2 = nn.BatchNorm2D(channels) + self.downsample = downsample + if downsample: + self.downsample = nn.Sequential( + nn.Conv2D(in_channels, channels * self.expansion, 1, stride, bias_attr=False), + nn.BatchNorm2D(channels * self.expansion), + ) + else: + self.downsample = nn.Sequential() + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet31(nn.Layer): + ''' + Args: + in_channels (int): Number of channels of input image tensor. + layers (list[int]): List of BasicBlock number for each stage. + channels (list[int]): List of out_channels of Conv2d layer. + out_indices (None | Sequence[int]): Indices of output stages. + last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage. + ''' + def __init__(self, + in_channels=3, + layers=[1, 2, 5, 3], + channels=[64, 128, 256, 256, 512, 512, 512], + out_indices=None, + last_stage_pool=False): + super(ResNet31, self).__init__() + assert isinstance(in_channels, int) + assert isinstance(last_stage_pool, bool) + + self.out_indices = out_indices + self.last_stage_pool = last_stage_pool + + # conv 1 (Conv Conv) + self.conv1_1 = nn.Conv2D(in_channels, channels[0], kernel_size=3, stride=1, padding=1) + self.bn1_1 = nn.BatchNorm2D(channels[0]) + self.relu1_1 = nn.ReLU() + + self.conv1_2 = nn.Conv2D(channels[0], channels[1], kernel_size=3, stride=1, padding=1) + self.bn1_2 = nn.BatchNorm2D(channels[1]) + self.relu1_2 = nn.ReLU() + + # conv 2 (Max-pooling, Residual block, Conv) + self.pool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block2 = self._make_layer(channels[1], channels[2], layers[0]) + self.conv2 = nn.Conv2D(channels[2], channels[2], kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2D(channels[2]) + self.relu2 = nn.ReLU() + + # conv 3 (Max-pooling, Residual block, Conv) + self.pool3 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block3 = self._make_layer(channels[2], channels[3], layers[1]) + self.conv3 = nn.Conv2D(channels[3], channels[3], kernel_size=3, stride=1, padding=1) + self.bn3 = nn.BatchNorm2D(channels[3]) + self.relu3 = nn.ReLU() + + # conv 4 (Max-pooling, Residual block, Conv) + self.pool4 = nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True) + self.block4 = self._make_layer(channels[3], channels[4], layers[2]) + self.conv4 = nn.Conv2D(channels[4], channels[4], kernel_size=3, stride=1, padding=1) + self.bn4 = nn.BatchNorm2D(channels[4]) + self.relu4 = nn.ReLU() + + # conv 5 ((Max-pooling), Residual block, Conv) + self.pool5 = None + if self.last_stage_pool: + self.pool5 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block5 = self._make_layer(channels[4], channels[5], layers[3]) + self.conv5 = nn.Conv2D(channels[5], channels[5], kernel_size=3, stride=1, padding=1) + self.bn5 = nn.BatchNorm2D(channels[5]) + self.relu5 = nn.ReLU() + + self.out_channels = channels[-1] + + def _make_layer(self, input_channels, output_channels, blocks): + layers = [] + for _ in range(blocks): + downsample = None + if input_channels != output_channels: + downsample = nn.Sequential( + nn.Conv2D( + input_channels, + output_channels, + kernel_size=1, + stride=1, + bias_attr=False), + nn.BatchNorm2D(output_channels), + ) + + layers.append(BasicBlock(input_channels, output_channels, downsample=downsample)) + input_channels = output_channels + return nn.Sequential(*layers) + + + def forward(self, x): + x = self.conv1_1(x) + x = self.bn1_1(x) + x = self.relu1_1(x) + + x = self.conv1_2(x) + x = self.bn1_2(x) + x = self.relu1_2(x) + + outs = [] + for i in range(4): + layer_index = i + 2 + pool_layer = getattr(self, f'pool{layer_index}') + block_layer = getattr(self, f'block{layer_index}') + conv_layer = getattr(self, f'conv{layer_index}') + bn_layer = getattr(self, f'bn{layer_index}') + relu_layer = getattr(self, f'relu{layer_index}') + + if pool_layer is not None: + x = pool_layer(x) + x = block_layer(x) + x = conv_layer(x) + x = bn_layer(x) + x= relu_layer(x) + + outs.append(x) + + if self.out_indices is not None: + return tuple([outs[i] for i in self.out_indices]) + + return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 572ec4aa8a..80311f92a6 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -27,12 +27,13 @@ def build_head(config): from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead from .rec_nrtr_head import Transformer + from .rec_sar_head import SARHead # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead' + 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead' ] #table head diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py new file mode 100644 index 0000000000..647f58200f --- /dev/null +++ b/ppocr/modeling/heads/rec_sar_head.py @@ -0,0 +1,383 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + + +class SAREncoder(nn.Layer): + """ + Args: + enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. + enc_drop_rnn (float): Dropout probability of RNN layer in encoder. + enc_gru (bool): If True, use GRU, else LSTM in encoder. + d_model (int): Dim of channels from backbone. + d_enc (int): Dim of encoder RNN layer. + mask (bool): If True, mask padding in RNN sequence. + """ + + def __init__(self, + enc_bi_rnn=False, + enc_drop_rnn=0.1, + enc_gru=False, + d_model=512, + d_enc=512, + mask=True, + **kwargs): + super().__init__() + assert isinstance(enc_bi_rnn, bool) + assert isinstance(enc_drop_rnn, (int, float)) + assert 0 <= enc_drop_rnn < 1.0 + assert isinstance(enc_gru, bool) + assert isinstance(d_model, int) + assert isinstance(d_enc, int) + assert isinstance(mask, bool) + + self.enc_bi_rnn = enc_bi_rnn + self.enc_drop_rnn = enc_drop_rnn + self.mask = mask + + # LSTM Encoder + if enc_bi_rnn: + direction = 'bidirectional' + else: + direction = 'forward' + kwargs = dict( + input_size=d_model, + hidden_size=d_enc, + num_layers=2, + time_major=False, + dropout=enc_drop_rnn, + direction=direction) + if enc_gru: + self.rnn_encoder = nn.GRU(**kwargs) + else: + self.rnn_encoder = nn.LSTM(**kwargs) + + # global feature transformation + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) + + def forward(self, feat, img_metas=None): + if img_metas is not None: + assert len(img_metas[0]) == feat.shape[0] + + valid_ratios = None + if img_metas is not None and self.mask: + valid_ratios = img_metas[-1] + + h_feat = feat.shape[2] # bsz c h w + feat_v = F.max_pool2d( + feat, kernel_size=(h_feat, 1), stride=1, padding=0) + feat_v = feat_v.squeeze(2) # bsz * C * W + feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C + holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C + + if valid_ratios is not None: + valid_hf = [] + T = holistic_feat.shape[1] + for i, valid_ratio in enumerate(valid_ratios): + valid_step = min(T, math.ceil(T * valid_ratio)) - 1 + valid_hf.append(holistic_feat[i, valid_step, :]) + valid_hf = paddle.stack(valid_hf, axis=0) + else: + valid_hf = holistic_feat[:, -1, :] # bsz * C + holistic_feat = self.linear(valid_hf) # bsz * C + + return holistic_feat + + +class BaseDecoder(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + + def forward_train(self, feat, out_enc, targets, img_metas): + raise NotImplementedError + + def forward_test(self, feat, out_enc, img_metas): + raise NotImplementedError + + def forward(self, + feat, + out_enc, + label=None, + img_metas=None, + train_mode=True): + self.train_mode = train_mode + + if train_mode: + return self.forward_train(feat, out_enc, label, img_metas) + return self.forward_test(feat, out_enc, img_metas) + + +class ParallelSARDecoder(BaseDecoder): + """ + Args: + out_channels (int): Output class number. + enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. + dec_bi_rnn (bool): If True, use bidirectional RNN in decoder. + dec_drop_rnn (float): Dropout of RNN layer in decoder. + dec_gru (bool): If True, use GRU, else LSTM in decoder. + d_model (int): Dim of channels from backbone. + d_enc (int): Dim of encoder RNN layer. + d_k (int): Dim of channels of attention module. + pred_dropout (float): Dropout probability of prediction layer. + max_seq_len (int): Maximum sequence length for decoding. + mask (bool): If True, mask padding in feature map. + start_idx (int): Index of start token. + padding_idx (int): Index of padding token. + pred_concat (bool): If True, concat glimpse feature from + attention with holistic feature and hidden state. + """ + + def __init__( + self, + out_channels, # 90 + unknown + start + padding + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_drop_rnn=0.0, + dec_gru=False, + d_model=512, + d_enc=512, + d_k=64, + pred_dropout=0.1, + max_text_length=30, + mask=True, + pred_concat=True, + **kwargs): + super().__init__() + + self.num_classes = out_channels + self.enc_bi_rnn = enc_bi_rnn + self.d_k = d_k + self.start_idx = out_channels - 2 + self.padding_idx = out_channels - 1 + self.max_seq_len = max_text_length + self.mask = mask + self.pred_concat = pred_concat + + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1) + + # 2D attention layer + self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k) + self.conv3x3_1 = nn.Conv2D( + d_model, d_k, kernel_size=3, stride=1, padding=1) + self.conv1x1_2 = nn.Linear(d_k, 1) + + # Decoder RNN layer + if dec_bi_rnn: + direction = 'bidirectional' + else: + direction = 'forward' + + kwargs = dict( + input_size=encoder_rnn_out_size, + hidden_size=encoder_rnn_out_size, + num_layers=2, + time_major=False, + dropout=dec_drop_rnn, + direction=direction) + if dec_gru: + self.rnn_decoder = nn.GRU(**kwargs) + else: + self.rnn_decoder = nn.LSTM(**kwargs) + + # Decoder input embedding + self.embedding = nn.Embedding( + self.num_classes, + encoder_rnn_out_size, + padding_idx=self.padding_idx) + + # Prediction layer + self.pred_dropout = nn.Dropout(pred_dropout) + pred_num_classes = self.num_classes - 1 + if pred_concat: + fc_in_channel = decoder_rnn_out_size + d_model + d_enc + else: + fc_in_channel = d_model + self.prediction = nn.Linear(fc_in_channel, pred_num_classes) + + def _2d_attention(self, + decoder_input, + feat, + holistic_feat, + valid_ratios=None): + + y = self.rnn_decoder(decoder_input)[0] + # y: bsz * (seq_len + 1) * hidden_size + + attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size + bsz, seq_len, attn_size = attn_query.shape + attn_query = paddle.unsqueeze(attn_query, axis=[3, 4]) + # (bsz, seq_len + 1, attn_size, 1, 1) + + attn_key = self.conv3x3_1(feat) + # bsz * attn_size * h * w + attn_key = attn_key.unsqueeze(1) + # bsz * 1 * attn_size * h * w + + attn_weight = paddle.tanh(paddle.add(attn_key, attn_query)) + + # bsz * (seq_len + 1) * attn_size * h * w + attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2]) + # bsz * (seq_len + 1) * h * w * attn_size + attn_weight = self.conv1x1_2(attn_weight) + # bsz * (seq_len + 1) * h * w * 1 + bsz, T, h, w, c = attn_weight.shape + assert c == 1 + + if valid_ratios is not None: + # cal mask of attention weight + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + attn_weight[i, :, :, valid_width:, :] = float('-inf') + + attn_weight = paddle.reshape(attn_weight, [bsz, T, -1]) + attn_weight = F.softmax(attn_weight, axis=-1) + + attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c]) + attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3]) + # attn_weight: bsz * T * c * h * w + # feat: bsz * c * h * w + attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), + (3, 4), + keepdim=False) + # bsz * (seq_len + 1) * C + + # Linear transformation + if self.pred_concat: + hf_c = holistic_feat.shape[-1] + holistic_feat = paddle.expand( + holistic_feat, shape=[bsz, seq_len, hf_c]) + y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2)) + else: + y = self.prediction(attn_feat) + # bsz * (seq_len + 1) * num_classes + if self.train_mode: + y = self.pred_dropout(y) + + return y + + def forward_train(self, feat, out_enc, label, img_metas): + ''' + img_metas: [label, valid_ratio] + ''' + if img_metas is not None: + assert len(img_metas[0]) == feat.shape[0] + + valid_ratios = None + if img_metas is not None and self.mask: + valid_ratios = img_metas[-1] + + label = label.cuda() + lab_embedding = self.embedding(label) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + in_dec = paddle.concat((out_enc, lab_embedding), axis=1) + # bsz * (seq_len + 1) * C + out_dec = self._2d_attention( + in_dec, feat, out_enc, valid_ratios=valid_ratios) + # bsz * (seq_len + 1) * num_classes + + return out_dec[:, 1:, :] # bsz * seq_len * num_classes + + def forward_test(self, feat, out_enc, img_metas): + if img_metas is not None: + assert len(img_metas[0]) == feat.shape[0] + + valid_ratios = None + if img_metas is not None and self.mask: + valid_ratios = img_metas[-1] + + seq_len = self.max_seq_len + bsz = feat.shape[0] + start_token = paddle.full( + (bsz, ), fill_value=self.start_idx, dtype='int64') + # bsz + start_token = self.embedding(start_token) + # bsz * emb_dim + emb_dim = start_token.shape[1] + start_token = start_token.unsqueeze(1) + start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim]) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + decoder_input = paddle.concat((out_enc, start_token), axis=1) + # bsz * (seq_len + 1) * emb_dim + + outputs = [] + for i in range(1, seq_len + 1): + decoder_output = self._2d_attention( + decoder_input, feat, out_enc, valid_ratios=valid_ratios) + char_output = decoder_output[:, i, :] # bsz * num_classes + char_output = F.softmax(char_output, -1) + outputs.append(char_output) + max_idx = paddle.argmax(char_output, axis=1, keepdim=False) + char_embedding = self.embedding(max_idx) # bsz * emb_dim + if i < seq_len: + decoder_input[:, i + 1, :] = char_embedding + + outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes + + return outputs + + +class SARHead(nn.Layer): + def __init__(self, + out_channels, + enc_bi_rnn=False, + enc_drop_rnn=0.1, + enc_gru=False, + dec_bi_rnn=False, + dec_drop_rnn=0.0, + dec_gru=False, + d_k=512, + pred_dropout=0.1, + max_text_length=30, + pred_concat=True, + **kwargs): + super(SARHead, self).__init__() + + # encoder module + self.encoder = SAREncoder( + enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru) + + # decoder module + self.decoder = ParallelSARDecoder( + out_channels=out_channels, + enc_bi_rnn=enc_bi_rnn, + dec_bi_rnn=dec_bi_rnn, + dec_drop_rnn=dec_drop_rnn, + dec_gru=dec_gru, + d_k=d_k, + pred_dropout=pred_dropout, + max_text_length=max_text_length, + pred_concat=pred_concat) + + def forward(self, feat, targets=None): + ''' + img_metas: [label, valid_ratio] + ''' + holistic_feat = self.encoder(feat, targets) # bsz c + + if self.training: + label = targets[0] # label + label = paddle.to_tensor(label, dtype='int64') + final_out = self.decoder( + feat, holistic_feat, label, img_metas=targets) + if not self.training: + final_out = self.decoder( + feat, + holistic_feat, + label=None, + img_metas=targets, + train_mode=False) + # (bsz, seq_len, num_classes) + + return final_out diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 6159398770..77081abeb6 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -25,7 +25,7 @@ from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \ - TableLabelDecode + TableLabelDecode, SARLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess @@ -33,7 +33,8 @@ def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', - 'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess' + 'DistillationCTCLabelDecode', 'TableLabelDecode', + 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 9f23b5495f..6ff375eb43 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -15,6 +15,7 @@ import string import paddle from paddle.nn import functional as F +import re class BaseRecLabelDecode(object): @@ -165,21 +166,21 @@ def __init__(self, use_space_char=True, **kwargs): super(NRTRLabelDecode, self).__init__(character_dict_path, - character_type, use_space_char) + character_type, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): if preds.dtype == paddle.int64: if isinstance(preds, paddle.Tensor): preds = preds.numpy() - if preds[0][0]==2: - preds_idx = preds[:,1:] + if preds[0][0] == 2: + preds_idx = preds[:, 1:] else: preds_idx = preds text = self.decode(preds_idx) if label is None: return text - label = self.decode(label[:,1:]) + label = self.decode(label[:, 1:]) else: if isinstance(preds, paddle.Tensor): preds = preds.numpy() @@ -188,13 +189,13 @@ def __call__(self, preds, label=None, *args, **kwargs): text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) if label is None: return text - label = self.decode(label[:,1:]) + label = self.decode(label[:, 1:]) return text, label def add_special_char(self, dict_character): - dict_character = ['blank','','',''] + dict_character + dict_character = ['blank', '', '', ''] + dict_character return dict_character - + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): """ convert text-index into text-label. """ result_list = [] @@ -203,10 +204,11 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): char_list = [] conf_list = [] for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] == 3: # end + if text_index[batch_idx][idx] == 3: # end break try: - char_list.append(self.character[int(text_index[batch_idx][idx])]) + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) except: continue if text_prob is not None: @@ -218,7 +220,6 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): return result_list - class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ @@ -256,7 +257,8 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ batch_idx][idx]: continue - char_list.append(self.character[int(text_index[batch_idx][idx])]) + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) if text_prob is not None: conf_list.append(text_prob[batch_idx][idx]) else: @@ -386,10 +388,9 @@ def get_beg_end_flag_idx(self, beg_or_end): class TableLabelDecode(object): """ """ - def __init__(self, - character_dict_path, - **kwargs): - list_character, list_elem = self.load_char_elem_dict(character_dict_path) + def __init__(self, character_dict_path, **kwargs): + list_character, list_elem = self.load_char_elem_dict( + character_dict_path) list_character = self.add_special_char(list_character) list_elem = self.add_special_char(list_elem) self.dict_character = {} @@ -408,7 +409,8 @@ def load_char_elem_dict(self, character_dict_path): list_elem = [] with open(character_dict_path, "rb") as fin: lines = fin.readlines() - substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t") + substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split( + "\t") character_num = int(substr[0]) elem_num = int(substr[1]) for cno in range(1, 1 + character_num): @@ -428,14 +430,14 @@ def add_special_char(self, list_character): def __call__(self, preds): structure_probs = preds['structure_probs'] loc_preds = preds['loc_preds'] - if isinstance(structure_probs,paddle.Tensor): + if isinstance(structure_probs, paddle.Tensor): structure_probs = structure_probs.numpy() - if isinstance(loc_preds,paddle.Tensor): + if isinstance(loc_preds, paddle.Tensor): loc_preds = loc_preds.numpy() structure_idx = structure_probs.argmax(axis=2) structure_probs = structure_probs.max(axis=2) - structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx, - structure_probs, 'elem') + structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode( + structure_idx, structure_probs, 'elem') res_html_code_list = [] res_loc_list = [] batch_num = len(structure_str) @@ -450,8 +452,13 @@ def __call__(self, preds): res_loc = np.array(res_loc) res_html_code_list.append(res_html_code) res_loc_list.append(res_loc) - return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list, - 'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str} + return { + 'res_html_code': res_html_code_list, + 'res_loc': res_loc_list, + 'res_score_list': result_score_list, + 'res_elem_idx_list': result_elem_idx_list, + 'structure_str_list': structure_str + } def decode(self, text_index, structure_probs, char_or_elem): """convert text-label into text-index. @@ -516,3 +523,82 @@ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): assert False, "Unsupport type %s in char_or_elem" \ % char_or_elem return idx + + +class SARLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + **kwargs): + super(SARLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + self.rm_symbol = kwargs.get('rm_symbol', False) + + def add_special_char(self, dict_character): + beg_end_str = "" + unknown_str = "" + padding_str = "" + dict_character = dict_character + [unknown_str] + self.unknown_idx = len(dict_character) - 1 + dict_character = dict_character + [beg_end_str] + self.start_idx = len(dict_character) - 1 + self.end_idx = len(dict_character) - 1 + dict_character = dict_character + [padding_str] + self.padding_idx = len(dict_character) - 1 + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(self.end_idx): + if text_prob is None and idx == 0: + continue + else: + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + if self.rm_symbol: + comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') + text = text.lower() + text = comp.sub('', text) + result_list.append((text, np.mean(conf_list))) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + return [self.padding_idx] diff --git a/tools/eval.py b/tools/eval.py index 7d6fb94f38..39a26ffeff 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -55,6 +55,7 @@ def main(): model = build_model(config['Architecture']) use_srn = config['Architecture']['algorithm'] == "SRN" + use_sar = config['Architecture']['algorithm'] == "SAR" if "model_type" in config['Architecture'].keys(): model_type = config['Architecture']['model_type'] else: @@ -71,7 +72,7 @@ def main(): # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, model_type, use_srn) + eval_class, model_type, use_srn, use_sar) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/infer_rec.py b/tools/infer_rec.py index cf49348fa6..29d4b530df 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -74,6 +74,10 @@ def main(): 'image', 'encoder_word_pos', 'gsrm_word_pos', 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' ] + elif config['Architecture']['algorithm'] == "SAR": + op[op_name]['keep_keys'] = [ + 'image', 'valid_ratio' + ] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) @@ -106,11 +110,16 @@ def main(): paddle.to_tensor(gsrm_slf_attn_bias1_list), paddle.to_tensor(gsrm_slf_attn_bias2_list) ] + if config['Architecture']['algorithm'] == "SAR": + valid_ratio = np.expand_dims(batch[-1], axis=0) + img_metas = [paddle.to_tensor(valid_ratio)] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) if config['Architecture']['algorithm'] == "SRN": preds = model(images, others) + elif config['Architecture']['algorithm'] == "SAR": + preds = model(images, img_metas) else: preds = model(images) post_result = post_process_class(preds) diff --git a/tools/program.py b/tools/program.py index e7742a8f60..d6d47d047b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -187,7 +187,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" use_nrtr = config['Architecture']['algorithm'] == "NRTR" - + use_sar = config['Architecture']['algorithm'] == 'SAR' try: model_type = config['Architecture']['model_type'] except: @@ -215,7 +215,7 @@ def train(config, images = batch[0] if use_srn: model_average = True - if use_srn or model_type == 'table' or use_nrtr: + if use_srn or model_type == 'table' or use_nrtr or use_sar: preds = model(images, data=batch[1:]) else: preds = model(images) @@ -279,7 +279,8 @@ def train(config, post_process_class, eval_class, model_type, - use_srn=use_srn) + use_srn=use_srn, + use_sar=use_sar) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) @@ -351,7 +352,8 @@ def eval(model, post_process_class, eval_class, model_type, - use_srn=False): + use_srn=False, + use_sar=False): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -364,7 +366,7 @@ def eval(model, break images = batch[0] start = time.time() - if use_srn or model_type == 'table': + if use_srn or model_type == 'table' or use_sar: preds = model(images, data=batch[1:]) else: preds = model(images) @@ -400,7 +402,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn' + 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'