forked from PaddlePaddle/PaddleOCR
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add satrn * 修复satrn导出问题 * 规范satrn config文件 * 删除SATRNRecResizeImg --------- Co-authored-by: zhiminzhang0830 <[email protected]>
- Loading branch information
1 parent
3ded601
commit 30201ef
Showing
13 changed files
with
976 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
Global: | ||
use_gpu: true | ||
epoch_num: 5 | ||
log_smooth_window: 20 | ||
print_batch_step: 50 | ||
save_model_dir: ./output/rec/rec_satrn/ | ||
save_epoch_step: 1 | ||
# evaluation is run every 5000 iterations | ||
eval_batch_step: [0, 5000] | ||
cal_metric_during_train: False | ||
pretrained_model: | ||
checkpoints: | ||
save_inference_dir: | ||
use_visualdl: False | ||
infer_img: | ||
# for data or label process | ||
character_dict_path: ppocr/utils/dict90.txt | ||
max_text_length: 25 | ||
infer_mode: False | ||
use_space_char: False | ||
rm_symbol: True | ||
save_res_path: ./output/rec/predicts_satrn.txt | ||
|
||
Optimizer: | ||
name: Adam | ||
beta1: 0.9 | ||
beta2: 0.999 | ||
lr: | ||
name: Piecewise | ||
decay_epochs: [3, 4] | ||
values: [0.0003, 0.00003, 0.000003] | ||
regularizer: | ||
name: 'L2' | ||
factor: 0 | ||
|
||
Architecture: | ||
model_type: rec | ||
algorithm: SATRN | ||
Backbone: | ||
name: ShallowCNN | ||
in_channels: 3 | ||
hidden_dim: 256 | ||
Head: | ||
name: SATRNHead | ||
enc_cfg: | ||
n_layers: 6 | ||
n_head: 8 | ||
d_k: 32 | ||
d_v: 32 | ||
d_model: 256 | ||
n_position: 100 | ||
d_inner: 1024 | ||
dropout: 0.1 | ||
dec_cfg: | ||
n_layers: 6 | ||
d_embedding: 256 | ||
n_head: 8 | ||
d_model: 256 | ||
d_inner: 1024 | ||
d_k: 32 | ||
d_v: 32 | ||
max_seq_len: 25 | ||
start_idx: 91 | ||
|
||
Loss: | ||
name: SATRNLoss | ||
|
||
PostProcess: | ||
name: SATRNLabelDecode | ||
|
||
Metric: | ||
name: RecMetric | ||
main_indicator: acc | ||
|
||
Train: | ||
dataset: | ||
name: LMDBDataSet | ||
data_dir: ./train_data/data_lmdb_release/training/ | ||
transforms: | ||
- DecodeImage: # load image | ||
img_mode: BGR | ||
channel_first: False | ||
- SATRNLabelEncode: # Class handling label | ||
- SVTRRecResizeImg: | ||
image_shape: [3, 32, 100] | ||
padding: False | ||
- KeepKeys: | ||
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order | ||
loader: | ||
shuffle: True | ||
batch_size_per_card: 128 | ||
drop_last: True | ||
num_workers: 8 | ||
use_shared_memory: False | ||
|
||
Eval: | ||
dataset: | ||
name: LMDBDataSet | ||
data_dir: ./train_data/data_lmdb_release/evaluation/ | ||
transforms: | ||
- DecodeImage: # load image | ||
img_mode: BGR | ||
channel_first: False | ||
- SATRNLabelEncode: # Class handling label | ||
- SVTRRecResizeImg: | ||
image_shape: [3, 32, 100] | ||
padding: False | ||
- KeepKeys: | ||
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order | ||
|
||
loader: | ||
shuffle: False | ||
drop_last: False | ||
batch_size_per_card: 128 | ||
num_workers: 4 | ||
use_shared_memory: False | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
This code is refer from: | ||
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/module_losses/ce_module_loss.py | ||
""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import paddle | ||
from paddle import nn | ||
|
||
|
||
class SATRNLoss(nn.Layer): | ||
def __init__(self, **kwargs): | ||
super(SATRNLoss, self).__init__() | ||
ignore_index = kwargs.get('ignore_index', 92) # 6626 | ||
self.loss_func = paddle.nn.loss.CrossEntropyLoss( | ||
reduction="none", ignore_index=ignore_index) | ||
|
||
def forward(self, predicts, batch): | ||
predict = predicts[:, : | ||
-1, :] # ignore last index of outputs to be in same seq_len with targets | ||
label = batch[1].astype( | ||
"int64")[:, 1:] # ignore first index of target in loss calculation | ||
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[ | ||
1], predict.shape[2] | ||
assert len(label.shape) == len(list(predict.shape)) - 1, \ | ||
"The target's shape and inputs's shape is [N, d] and [N, num_steps]" | ||
|
||
inputs = paddle.reshape(predict, [-1, num_classes]) | ||
targets = paddle.reshape(label, [-1]) | ||
loss = self.loss_func(inputs, targets) | ||
return {'loss': loss.mean()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
This code is refer from: | ||
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/backbones/shallow_cnn.py | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import math | ||
import numpy as np | ||
import paddle | ||
from paddle import ParamAttr | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
from paddle.nn import MaxPool2D | ||
from paddle.nn.initializer import KaimingNormal, Uniform, Constant | ||
|
||
|
||
class ConvBNLayer(nn.Layer): | ||
def __init__(self, | ||
num_channels, | ||
filter_size, | ||
num_filters, | ||
stride, | ||
padding, | ||
num_groups=1): | ||
super(ConvBNLayer, self).__init__() | ||
|
||
self.conv = nn.Conv2D( | ||
in_channels=num_channels, | ||
out_channels=num_filters, | ||
kernel_size=filter_size, | ||
stride=stride, | ||
padding=padding, | ||
groups=num_groups, | ||
weight_attr=ParamAttr(initializer=KaimingNormal()), | ||
bias_attr=False) | ||
|
||
self.bn = nn.BatchNorm2D( | ||
num_filters, | ||
weight_attr=ParamAttr(initializer=Uniform(0, 1)), | ||
bias_attr=ParamAttr(initializer=Constant(0))) | ||
self.relu = nn.ReLU() | ||
|
||
def forward(self, inputs): | ||
y = self.conv(inputs) | ||
y = self.bn(y) | ||
y = self.relu(y) | ||
return y | ||
|
||
|
||
class ShallowCNN(nn.Layer): | ||
def __init__(self, in_channels=1, hidden_dim=512): | ||
super().__init__() | ||
assert isinstance(in_channels, int) | ||
assert isinstance(hidden_dim, int) | ||
|
||
self.conv1 = ConvBNLayer( | ||
in_channels, 3, hidden_dim // 2, stride=1, padding=1) | ||
self.conv2 = ConvBNLayer( | ||
hidden_dim // 2, 3, hidden_dim, stride=1, padding=1) | ||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) | ||
self.out_channels = hidden_dim | ||
|
||
def forward(self, x): | ||
|
||
x = self.conv1(x) | ||
x = self.pool(x) | ||
|
||
x = self.conv2(x) | ||
x = self.pool(x) | ||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.