Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dygraph' into dygraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Evezerest committed Aug 31, 2021
2 parents 1f9ad0f + 63ed5fc commit 960f7fc
Show file tree
Hide file tree
Showing 34 changed files with 1,464 additions and 51 deletions.
102 changes: 102 additions & 0 deletions configs/rec/rec_mtb_nrtr.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
Global:
use_gpu: True
epoch_num: 21
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/nrtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: EN_symbol
max_text_length: 25
infer_mode: False
use_space_char: True
save_res_path: ./output/rec/predicts_nrtr.txt

Optimizer:
name: Adam
beta1: 0.9
beta2: 0.99
clip_norm: 5.0
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0.

Architecture:
model_type: rec
algorithm: NRTR
in_channels: 1
Transform:
Backbone:
name: MTB
cnn_num: 2
Head:
name: Transformer
d_model: 512
num_encoder_layers: 6
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.


Loss:
name: NRTRLoss
smoothing: True

PostProcess:
name: NRTRLabelDecode

Metric:
name: RecMetric
main_indicator: acc

Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- NRTRDecodeImage: # load image
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 512
drop_last: True
num_workers: 8

Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- NRTRDecodeImage: # load image
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 1
use_shared_memory: False
19 changes: 10 additions & 9 deletions deploy/slim/prune/sensitivity_anal.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture'])

flops = paddle.flops(model, [1, 3, 640, 640])
logger.info(f"FLOPs before pruning: {flops}")
logger.info("FLOPs before pruning: {}".format(flops))

from paddleslim.dygraph import FPGMFilterPruner
model.train()
Expand Down Expand Up @@ -106,8 +106,8 @@ def main(config, device, logger, vdl_writer):

def eval_fn():
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}")
eval_class, False)
logger.info("metric['hmean']: {}".format(metric['hmean']))
return metric['hmean']

params_sensitive = pruner.sensitive(
Expand All @@ -123,16 +123,17 @@ def eval_fn():
# calculate pruned params's ratio
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info(f"{key}, {params_sensitive[key]}")
logger.info("{}, {}".format(key, params_sensitive[key]))

#params_sensitive = {}
#for param in model.parameters():
# if 'transpose' not in param.name and 'linear' not in param.name:
# params_sensitive[param.name] = 0.1

plan = pruner.prune_vars(params_sensitive, [0])
for param in model.parameters():
if ("weights" in param.name and "conv" in param.name) or (
"w_0" in param.name and "conv2d" in param.name):
logger.info(f"{param.name}: {param.shape}")

flops = paddle.flops(model, [1, 3, 640, 640])
logger.info(f"FLOPs after pruning: {flops}")
logger.info("FLOPs after pruning: {}".format(flops))

# start train

Expand Down
2 changes: 2 additions & 0 deletions doc/doc_ch/algorithm_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] STAR-Net([paper](http:https://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))

参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:

Expand All @@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |


PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)
1 change: 1 addition & 0 deletions doc/doc_ch/recognition.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |

训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:

Expand Down
2 changes: 2 additions & 0 deletions doc/doc_en/algorithm_overview_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] STAR-Net([paper](http:https://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))

Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:

Expand All @@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |

Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
2 changes: 1 addition & 1 deletion doc/doc_en/recognition_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |

| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |

For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
Expand Down
Binary file modified doc/table/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified doc/table/table.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
}

SUPPORT_DET_MODEL = ['DB']
VERSION = '2.2'
VERSION = '2.2.0.1'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")

Expand Down
2 changes: 1 addition & 1 deletion ppocr/data/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
Expand Down
28 changes: 28 additions & 0 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,34 @@ def encode(self, text):
return text_list


class NRTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN_symbol',
use_space_char=False,
**kwargs):

super(NRTRLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
data['length'] = np.array(len(text))
text.insert(0, 2)
text.append(3)
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
return dict_character

class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

Expand Down
32 changes: 32 additions & 0 deletions ppocr/data/imaug/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,38 @@ def __call__(self, data):
return data


class NRTRDecodeImage(object):
""" decode image """

def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first

def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')

img = cv2.imdecode(img, 1)

if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data

class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
Expand Down
21 changes: 20 additions & 1 deletion ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import cv2
import numpy as np
import random

from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort


Expand All @@ -43,6 +43,25 @@ def __call__(self, data):
return data


class NRTRRecResizeImg(object):
def __init__(self, image_shape, resize_type, **kwargs):
self.image_shape = image_shape
self.resize_type = resize_type

def __call__(self, data):
img = data['image']
if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
img = np.array(img)
if self.resize_type == 'OpenCV':
img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data


class RecResizeImg(object):
def __init__(self,
image_shape,
Expand Down
5 changes: 3 additions & 2 deletions ppocr/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss

from .rec_nrtr_loss import NRTRLoss
# cls loss
from .cls_loss import ClsLoss

Expand All @@ -44,8 +44,9 @@
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
]

config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
Expand Down
30 changes: 30 additions & 0 deletions ppocr/losses/rec_nrtr_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import paddle
from paddle import nn
import paddle.nn.functional as F


class NRTRLoss(nn.Layer):
def __init__(self, smoothing=True, **kwargs):
super(NRTRLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
self.smoothing = smoothing

def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
max_len = batch[2].max()
tgt = batch[1][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype='int64'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
return {'loss': loss}
1 change: 1 addition & 0 deletions ppocr/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ def reset(self):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0

1 change: 0 additions & 1 deletion ppocr/modeling/architectures/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
Expand Down
3 changes: 2 additions & 1 deletion ppocr/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
support_dict = [
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
Expand Down
Loading

0 comments on commit 960f7fc

Please sign in to comment.