Skip to content

Commit

Permalink
V4Rec code pr (PaddlePaddle#9725)
Browse files Browse the repository at this point in the history
* v4rec code

* v4rec add nrtrloss

* Add V4rec backbone file

* Add V4Rec config file.

* Fix V4rec reparameters when export_model

* convert lvnetv3

* fix codestyle

* fix infer_rec v4rec
  • Loading branch information
Topdu committed Apr 19, 2023
1 parent 385a1f9 commit 43abe2f
Show file tree
Hide file tree
Showing 15 changed files with 787 additions and 71 deletions.
131 changes: 131 additions & 0 deletions configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
Global:
debug: false
use_gpu: true
epoch_num: 200
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_ppocr_v4
save_epoch_step: 10
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3.txt


Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 5
regularizer:
name: L2
factor: 3.0e-05


Architecture:
model_type: rec
algorithm: SVTR_LCNet
Transform:
Backbone:
name: LCNetv3
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length

Loss:
name: MultiLoss
loss_config_list:
- CTCLoss:
- NRTRLoss:

PostProcess:
name: CTCLabelDecode

Metric:
name: RecMetric
main_indicator: acc

Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecConAug:
prob: 0.5
ext_data_num: 2
image_shape: [48, 320, 3]
max_text_length: *max_text_length
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4
21 changes: 15 additions & 6 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,27 +1241,36 @@ def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
gtc_encode=None,
**kwargs):
super(MultiLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)

self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
self.gtc_encode_type = gtc_encode
if gtc_encode is None:
self.gtc_encode = SARLabelEncode(
max_text_length, character_dict_path, use_space_char, **kwargs)
else:
self.gtc_encode = eval(gtc_encode)(
max_text_length, character_dict_path, use_space_char, **kwargs)

def __call__(self, data):
data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data)
data_gtc = copy.deepcopy(data)
data_out = dict()
data_out['img_path'] = data.get('img_path', None)
data_out['image'] = data['image']
ctc = self.ctc_encode.__call__(data_ctc)
sar = self.sar_encode.__call__(data_sar)
if ctc is None or sar is None:
gtc = self.gtc_encode.__call__(data_gtc)
if ctc is None or gtc is None:
return None
data_out['label_ctc'] = ctc['label']
data_out['label_sar'] = sar['label']
if self.gtc_encode_type is not None:
data_out['label_gtc'] = gtc['label']
else:
data_out['label_sar'] = gtc['label']
data_out['length'] = ctc['length']
return data_out

Expand Down
3 changes: 2 additions & 1 deletion ppocr/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .rec_rfl_loss import RFLLoss
from .rec_can_loss import CANLoss
from .rec_satrn_loss import SATRNLoss
from .rec_nrtr_loss import NRTRLoss

# cls loss
from .cls_loss import ClsLoss
Expand Down Expand Up @@ -75,7 +76,7 @@ def build_loss(config):
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
'SATRNLoss'
'SATRNLoss', 'NRTRLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
Expand Down
5 changes: 4 additions & 1 deletion ppocr/losses/rec_multi_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .rec_nrtr_loss import NRTRLoss


class MultiLoss(nn.Layer):
Expand All @@ -30,7 +31,6 @@ def __init__(self, **kwargs):
self.loss_list = kwargs.pop('loss_config_list')
self.weight_1 = kwargs.get('weight_1', 1.0)
self.weight_2 = kwargs.get('weight_2', 1.0)
self.gtc_loss = kwargs.get('gtc_loss', 'sar')
for loss_info in self.loss_list:
for name, param in loss_info.items():
if param is not None:
Expand All @@ -49,6 +49,9 @@ def forward(self, predicts, batch):
elif name == 'SARLoss':
loss = loss_func(predicts['sar'],
batch[:1] + batch[2:])['loss'] * self.weight_2
elif name == 'NRTRLoss':
loss = loss_func(predicts['nrtr'],
batch[:1] + batch[2:])['loss'] * self.weight_2
else:
raise NotImplementedError(
'{} is not supported in MultiLoss yet'.format(name))
Expand Down
32 changes: 32 additions & 0 deletions ppocr/losses/rec_nrtr_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import paddle
from paddle import nn
import paddle.nn.functional as F


class NRTRLoss(nn.Layer):
def __init__(self, smoothing=True, ignore_index=0, **kwargs):
super(NRTRLoss, self).__init__()
if ignore_index >= 0 and not smoothing:
self.loss_func = nn.CrossEntropyLoss(
reduction='mean', ignore_index=ignore_index)
self.smoothing = smoothing

def forward(self, pred, batch):
max_len = batch[2].max()
tgt = batch[1][:, 1:2 + max_len]
pred = pred.reshape([-1, pred.shape[2]])
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype=tgt.dtype))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
return {'loss': loss}
3 changes: 2 additions & 1 deletion ppocr/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def build_backbone(config, model_type):
from .rec_resnet_rfl import ResNetRFL
from .rec_densenet import DenseNet
from .rec_shallow_cnn import ShallowCNN
from .rec_lcnetv3 import LCNetv3
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
'DenseNet', 'ShallowCNN'
'DenseNet', 'ShallowCNN', 'LCNetv3'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
Expand Down
Loading

0 comments on commit 43abe2f

Please sign in to comment.