Skip to content

Commit

Permalink
add rec_nrtr
Browse files Browse the repository at this point in the history
  • Loading branch information
Topdu committed Aug 17, 2021
1 parent b6f0a90 commit 1623c17
Show file tree
Hide file tree
Showing 23 changed files with 638 additions and 23 deletions.
16 changes: 16 additions & 0 deletions configs/rec/rec_mtb_nrtr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,38 @@ Global:
epoch_num: 21
log_smooth_window: 20
print_batch_step: 10
<<<<<<< HEAD
save_model_dir: ./output/rec/nrtr_final/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
=======
save_model_dir: ./output/rec/piloptimnrtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: False
>>>>>>> 9c67a7f... add rec_nrtr
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
<<<<<<< HEAD
character_dict_path:
character_type: EN_symbol
max_text_length: 25
infer_mode: False
use_space_char: True
=======
character_dict_path: ppocr/utils/dict_99.txt
character_type: dict_99
max_text_length: 25
infer_mode: False
use_space_char: False
>>>>>>> 9c67a7f... add rec_nrtr
save_res_path: ./output/rec/predicts_nrtr.txt

Optimizer:
Expand Down
2 changes: 2 additions & 0 deletions doc/doc_ch/algorithm_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] STAR-Net([paper](https://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)

参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:

Expand All @@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |


PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)
1 change: 1 addition & 0 deletions doc/doc_ch/recognition.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |

训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:

Expand Down
1 change: 1 addition & 0 deletions doc/doc_en/algorithm_overview_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |

Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
2 changes: 1 addition & 1 deletion doc/doc_en/recognition_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |

| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |

For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
Expand Down
2 changes: 1 addition & 1 deletion ppocr/data/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
from .randaugment import RandAugment
from .operators import *
from .label_ops import *
Expand Down
2 changes: 1 addition & 1 deletion ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self,
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
Expand Down
32 changes: 32 additions & 0 deletions ppocr/data/imaug/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,38 @@ def __call__(self, data):
return data


class NRTRDecodeImage(object):
""" decode image """

def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first

def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')

img = cv2.imdecode(img, 1)

if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data

class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
Expand Down
30 changes: 29 additions & 1 deletion ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import cv2
import numpy as np
import random

from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort


Expand All @@ -42,6 +42,34 @@ def __call__(self, data):
data['image'] = norm_img
return data

class PILResize(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape

def __call__(self, data):
img = data['image']
image_pil = Image.fromarray(np.uint8(img))
norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
norm_img = np.array(norm_img)
norm_img = np.expand_dims(norm_img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data


class CVResize(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape

def __call__(self, data):
img = data['image']
#print(img)
norm_img = cv2.resize(img,self.image_shape)
norm_img = np.expand_dims(norm_img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data


class RecResizeImg(object):
def __init__(self,
Expand Down
6 changes: 3 additions & 3 deletions ppocr/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss

from .rec_nrtr_loss import NRTRLoss
# cls loss
from .cls_loss import ClsLoss

Expand All @@ -42,8 +42,8 @@
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss'
]
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss']

config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
Expand Down
38 changes: 38 additions & 0 deletions ppocr/losses/rec_nrtr_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import paddle
from paddle import nn
import paddle.nn.functional as F


def cal_performance(pred, tgt):

pred = pred.max(1)[1]
tgt = tgt.contiguous().view(-1)
non_pad_mask = tgt.ne(0)
n_correct = pred.eq(tgt)
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
return n_correct


class NRTRLoss(nn.Layer):
def __init__(self,smoothing=True, **kwargs):
super(NRTRLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
self.smoothing = smoothing

def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
max_len = batch[2].max()
tgt = batch[1][:,1:2+max_len]
tgt = tgt.reshape([-1] )
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
return {'loss': loss}
7 changes: 4 additions & 3 deletions ppocr/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __call__(self, pred_label, *args, **kwargs):
target = target.replace(" ", "")
norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1)
if pred == target:
if pred.lower() == target.lower():
correct_num += 1
all_num += 1
self.correct_num += correct_num
Expand All @@ -48,12 +48,13 @@ def get_metric(self):
'norm_edit_dis': 0,
}
"""
acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
acc = 1.0 * self.correct_num / (self.all_num)
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num)
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

def reset(self):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0

2 changes: 1 addition & 1 deletion ppocr/modeling/architectures/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
Expand Down
5 changes: 4 additions & 1 deletion ppocr/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def build_backbone(config, model_type):
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
from .rec_nrtr_mtb import MTB
from .rec_swin import SwinTransformer
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer']

elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
Expand Down
Loading

0 comments on commit 1623c17

Please sign in to comment.