Skip to content

Commit

Permalink
add satrn (PaddlePaddle#8433)
Browse files Browse the repository at this point in the history
* add satrn

* 修复satrn导出问题

* 规范satrn config文件

* 删除SATRNRecResizeImg

---------

Co-authored-by: zhiminzhang0830 <[email protected]>
  • Loading branch information
zhiminzhang0830 and zhiminzhang0830 authored Feb 8, 2023
1 parent 3ded601 commit 30201ef
Show file tree
Hide file tree
Showing 13 changed files with 976 additions and 8 deletions.
117 changes: 117 additions & 0 deletions configs/rec/rec_satrn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
Global:
use_gpu: true
epoch_num: 5
log_smooth_window: 20
print_batch_step: 50
save_model_dir: ./output/rec/rec_satrn/
save_epoch_step: 1
# evaluation is run every 5000 iterations
eval_batch_step: [0, 5000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: 25
infer_mode: False
use_space_char: False
rm_symbol: True
save_res_path: ./output/rec/predicts_satrn.txt

Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4]
values: [0.0003, 0.00003, 0.000003]
regularizer:
name: 'L2'
factor: 0

Architecture:
model_type: rec
algorithm: SATRN
Backbone:
name: ShallowCNN
in_channels: 3
hidden_dim: 256
Head:
name: SATRNHead
enc_cfg:
n_layers: 6
n_head: 8
d_k: 32
d_v: 32
d_model: 256
n_position: 100
d_inner: 1024
dropout: 0.1
dec_cfg:
n_layers: 6
d_embedding: 256
n_head: 8
d_model: 256
d_inner: 1024
d_k: 32
d_v: 32
max_seq_len: 25
start_idx: 91

Loss:
name: SATRNLoss

PostProcess:
name: SATRNLabelDecode

Metric:
name: RecMetric
main_indicator: acc

Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SATRNLabelEncode: # Class handling label
- SVTRRecResizeImg:
image_shape: [3, 32, 100]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 128
drop_last: True
num_workers: 8
use_shared_memory: False

Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SATRNLabelEncode: # Class handling label
- SVTRRecResizeImg:
image_shape: [3, 32, 100]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order

loader:
shuffle: False
drop_last: False
batch_size_per_card: 128
num_workers: 4
use_shared_memory: False

56 changes: 56 additions & 0 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,62 @@ def get_ignored_tokens(self):
return [self.padding_idx]


class SATRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
lower=False,
**kwargs):
super(SATRNLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.lower = lower

def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1

return dict_character

def encode(self, text):
if self.lower:
text = text.lower()
text_list = []
for char in text:
text_list.append(self.dict.get(char, self.unknown_idx))
if len(text_list) == 0:
return None
return text_list

def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
data['length'] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
if len(target) > self.max_text_len:
padded_text = target[:self.max_text_len]
else:
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data

def get_ignored_tokens(self):
return [self.padding_idx]


class PRENLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
Expand Down
4 changes: 3 additions & 1 deletion ppocr/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
from .rec_can_loss import CANLoss
from .rec_satrn_loss import SATRNLoss

# cls loss
from .cls_loss import ClsLoss
Expand Down Expand Up @@ -73,7 +74,8 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss'
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
'SATRNLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
Expand Down
46 changes: 46 additions & 0 deletions ppocr/losses/rec_satrn_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/module_losses/ce_module_loss.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn


class SATRNLoss(nn.Layer):
def __init__(self, **kwargs):
super(SATRNLoss, self).__init__()
ignore_index = kwargs.get('ignore_index', 92) # 6626
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
reduction="none", ignore_index=ignore_index)

def forward(self, predicts, batch):
predict = predicts[:, :
-1, :] # ignore last index of outputs to be in same seq_len with targets
label = batch[1].astype(
"int64")[:, 1:] # ignore first index of target in loss calculation
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
1], predict.shape[2]
assert len(label.shape) == len(list(predict.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"

inputs = paddle.reshape(predict, [-1, num_classes])
targets = paddle.reshape(label, [-1])
loss = self.loss_func(inputs, targets)
return {'loss': loss.mean()}
3 changes: 2 additions & 1 deletion ppocr/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ def build_backbone(config, model_type):
from .rec_vitstr import ViTSTR
from .rec_resnet_rfl import ResNetRFL
from .rec_densenet import DenseNet
from .rec_shallow_cnn import ShallowCNN
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
'DenseNet'
'DenseNet', 'ShallowCNN'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
Expand Down
87 changes: 87 additions & 0 deletions ppocr/modeling/backbones/rec_shallow_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/backbones/shallow_cnn.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import MaxPool2D
from paddle.nn.initializer import KaimingNormal, Uniform, Constant


class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
num_groups=1):
super(ConvBNLayer, self).__init__()

self.conv = nn.Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)

self.bn = nn.BatchNorm2D(
num_filters,
weight_attr=ParamAttr(initializer=Uniform(0, 1)),
bias_attr=ParamAttr(initializer=Constant(0)))
self.relu = nn.ReLU()

def forward(self, inputs):
y = self.conv(inputs)
y = self.bn(y)
y = self.relu(y)
return y


class ShallowCNN(nn.Layer):
def __init__(self, in_channels=1, hidden_dim=512):
super().__init__()
assert isinstance(in_channels, int)
assert isinstance(hidden_dim, int)

self.conv1 = ConvBNLayer(
in_channels, 3, hidden_dim // 2, stride=1, padding=1)
self.conv2 = ConvBNLayer(
hidden_dim // 2, 3, hidden_dim, stride=1, padding=1)
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = hidden_dim

def forward(self, x):

x = self.conv1(x)
x = self.pool(x)

x = self.conv2(x)
x = self.pool(x)

return x
3 changes: 2 additions & 1 deletion ppocr/modeling/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def build_head(config):
from .rec_visionlan_head import VLHead
from .rec_rfl_head import RFLHead
from .rec_can_head import CANHead
from .rec_satrn_head import SATRNHead

# cls head
from .cls_head import ClsHead
Expand All @@ -56,7 +57,7 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
'DRRGHead', 'CANHead'
'DRRGHead', 'CANHead', 'SATRNHead'
]

if config['name'] == 'DRRGHead':
Expand Down
Loading

0 comments on commit 30201ef

Please sign in to comment.