add center loss code and cfg (#4165)
* add center loss code and cfg

* fix name
littletomatodonkey committed Sep 29, 2021
1 parent 5613e21 commit 27b6346
Showing 9 changed files with 298 additions and 6 deletions.
126 changes: 126 additions & 0 deletions configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
@@ -0,0 +1,126 @@
Global:
debug: false
use_gpu: true
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_mobile_pp-OCRv2_enhanced_ctc_loss
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_mobile_pp-OCRv2_enhanced_ctc_loss.txt


Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs : [700, 800]
values : [0.001, 0.0001]
warmup_epoch: 5
regularizer:
name: L2
factor: 2.0e-05


Architecture:
model_type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
return_feats: true

Loss:
name: CombinedLoss
loss_config_list:
- CTCLoss:
use_focal_loss: false
weight: 1.0
- CenterLoss:
weight: 0.05
num_classes: 6625
feat_dim: 96
init_center: false
center_file_path: "./train_center.pkl"
# you can also try to add ace loss on your own dataset
# - ACELoss:
# weight: 0.1

PostProcess:
name: CTCLabelDecode

Metric:
name: RecMetric
main_indicator: acc

Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
- label_ace
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 8
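
Note on the CenterLoss entry above: init_center is false here, so the centers start from random values. If init_center is switched on, center_file_path must point to a pickled dict mapping class index to a feat_dim-length vector (see center_loss.py below, which loads the file with pickle and copies each entry into self.centers). The helper below is a hypothetical sketch, not part of this commit; the name save_center_file and its arguments are illustrative only, and it simply averages per-class features into that format.

import pickle
import numpy as np

def save_center_file(feats, labels, path="./train_center.pkl"):
    # feats: (N, feat_dim) features collected from the CTC head of a trained
    # model; labels: (N,) class index assigned to each feature vector.
    # Centers are per-class mean features, keyed by class index, which is the
    # layout CenterLoss.__init__ reads when init_center is true.
    centers = {}
    for cls in np.unique(labels):
        centers[int(cls)] = feats[labels == cls].mean(axis=0).astype("float64")
    with open(path, "wb") as f:
        pickle.dump(centers, f)
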
5 changes: 5 additions & 0 deletions ppocr/data/imaug/label_ops.py
@@ -215,6 +215,11 @@ def __call__(self, data):
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)

label = [0] * len(self.character)
for x in text:
label[x] += 1
data['label_ace'] = np.array(label)
return data

def add_special_char(self, dict_character):
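
The five lines added to CTCLabelEncode above build label_ace, a per-class count of how often each character index appears in the padded label. Index 0 is the CTC blank, so the padding zeros land in slot 0 at first; ACELoss later overwrites that slot with the number of blank time steps. A small worked example with a toy five-character dictionary, used here only for illustration:

import numpy as np

character = ["blank", "a", "b", "c", "d"]   # toy dict; index 0 is the blank
max_text_len = 6
text = [3, 1, 3]                            # encoded label "cac"
text = text + [0] * (max_text_len - len(text))
label = [0] * len(character)
for x in text:
    label[x] += 1
label_ace = np.array(label)                 # array([3, 1, 0, 2, 0])
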
1 change: 0 additions & 1 deletion ppocr/losses/__init__.py
@@ -52,7 +52,6 @@ def build_loss(config):
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
'TableAttentionLoss', 'SARLoss', 'AsterLoss'
]

config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
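
The only change in this file is a removed blank line; build_loss still validates the top-level name against support_dict, and since CenterLoss and ACELoss are constructed inside CombinedLoss they do not need to appear in that list. The dict below mirrors the Loss section of the config above; the build_loss call is the existing PaddleOCR entry point and is shown commented out because it needs the repo on the Python path.

# Config dict mirroring the Loss section of ch_PP-OCRv2_rec_enhanced_ctc_loss.yml;
# only the top-level name ("CombinedLoss") has to be present in support_dict.
config = {
    "name": "CombinedLoss",
    "loss_config_list": [
        {"CTCLoss": {"use_focal_loss": False, "weight": 1.0}},
        {"CenterLoss": {"weight": 0.05, "num_classes": 6625, "feat_dim": 96,
                        "init_center": False,
                        "center_file_path": "./train_center.pkl"}},
    ],
}
# from ppocr.losses import build_loss
# loss_fn = build_loss(config)
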
50 changes: 50 additions & 0 deletions ppocr/losses/ace_loss.py
@@ -0,0 +1,50 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn


class ACELoss(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
self.loss_func = nn.CrossEntropyLoss(
weight=None,
ignore_index=0,
reduction='none',
soft_label=True,
axis=-1)

def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32')

predicts = nn.functional.softmax(predicts, axis=-1)
aggregation_preds = paddle.sum(predicts, axis=1)
aggregation_preds = paddle.divide(aggregation_preds, div)

length = batch[2].astype("float32")
batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length)

batch = paddle.divide(batch, div)

loss = self.loss_func(aggregation_preds, batch)

return {"loss_ace": loss}
89 changes: 89 additions & 0 deletions ppocr/losses/center_loss.py
@@ -0,0 +1,89 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class CenterLoss(nn.Layer):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""

def __init__(self,
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None):
super().__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype(
"float64") #random center

if init_center:
assert os.path.exists(
center_file_path
), f"center path({center_file_path}) must exist when init_center is set as True."
with open(center_file_path, 'rb') as f:
char_dict = pickle.load(f)
for key in char_dict.keys():
self.centers[key] = paddle.to_tensor(char_dict[key])

def __call__(self, predicts, batch):
assert isinstance(predicts, (list, tuple))
features, predicts = predicts

feats_reshape = paddle.reshape(
features, [-1, features.shape[-1]]).astype("float64")
label = paddle.argmax(predicts, axis=2)
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])

batch_size = feats_reshape.shape[0]

#calc feat * feat
dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])

#dist2 of centers
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
keepdim=True) #num_classes
dist2 = paddle.expand(dist2,
[self.num_classes, batch_size]).astype("float64")
dist2 = paddle.transpose(dist2, [1, 0])

#first x * x + y * y
distmat = paddle.add(dist1, dist2)
tmp = paddle.matmul(feats_reshape,
paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * tmp

#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
label = paddle.expand(
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
label).astype("float64") #get mask
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'loss_center': loss}
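
Two details worth noting in __call__ above: the labels are the argmax of the CTC logits (pseudo labels over every time step), not the ground truth, and distmat is the full matrix of squared distances between every flattened feature and every class center, computed with the expansion ||f - c||^2 = ||f||^2 + ||c||^2 - 2 f.c rather than an explicit broadcast. A quick numpy check of that expansion, with small sizes chosen only to keep the brute-force comparison cheap:

import numpy as np

feats = np.random.randn(4, 96)                  # flattened B*T features
centers = np.random.randn(100, 96)              # a slice of class centers
d1 = (feats ** 2).sum(axis=1, keepdims=True)    # ||f||^2, shape (4, 1)
d2 = (centers ** 2).sum(axis=1)                 # ||c||^2, shape (100,)
distmat = d1 + d2 - 2.0 * feats @ centers.T     # shape (4, 100)
brute = ((feats[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
assert np.allclose(distmat, brute)
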
4 changes: 4 additions & 0 deletions ppocr/losses/combined_loss.py
@@ -15,6 +15,10 @@
import paddle
import paddle.nn as nn

from .rec_ctc_loss import CTCLoss
from .center_loss import CenterLoss
from .ace_loss import ACELoss

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
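
These imports make CTCLoss, CenterLoss and ACELoss visible to CombinedLoss, which builds each entry of loss_config_list by class name and reduces the weighted results to a single training loss. The sketch below shows the general pattern only; it is not the PaddleOCR implementation, and the names CombinedLossSketch and registry are invented for illustration.

import paddle.nn as nn

class CombinedLossSketch(nn.Layer):
    def __init__(self, loss_config_list, registry):
        super().__init__()
        self.losses, self.weights = [], []
        for item in loss_config_list:                # e.g. {"CTCLoss": {...}}
            (name, params), = item.items()
            params = dict(params)
            self.weights.append(params.pop("weight", 1.0))
            self.losses.append(registry[name](**params))

    def forward(self, predicts, batch):
        out, total = {}, None
        for loss_fn, weight in zip(self.losses, self.weights):
            for key, value in loss_fn(predicts, batch).items():
                weighted = value * weight
                out[key] = weighted
                total = weighted if total is None else total + weighted
        out["loss"] = total
        return out
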
10 changes: 9 additions & 1 deletion ppocr/losses/rec_ctc_loss.py
@@ -21,16 +21,24 @@


class CTCLoss(nn.Layer):
def __init__(self, **kwargs):
def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
self.use_focal_loss = use_focal_loss

def forward(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
labels = batch[1].astype("int32")
label_lengths = batch[2].astype('int64')
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
if self.use_focal_loss:
weight = paddle.exp(-loss)
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
            weight = paddle.square(weight)
loss = paddle.multiply(loss, weight)
        loss = loss.mean()
return {'loss': loss}
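
With use_focal_loss enabled, each per-sample CTC loss is rescaled by (1 - exp(-loss))^2 before the mean, so nearly-solved samples contribute almost nothing while hard samples keep close to their full gradient. A few representative values (illustrative numbers only):

import paddle

per_sample_loss = paddle.to_tensor([0.05, 0.5, 2.0, 5.0])
weight = paddle.square(1.0 - paddle.exp(-per_sample_loss))
print(weight.numpy())    # approx. [0.002, 0.155, 0.748, 0.987]
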
17 changes: 13 additions & 4 deletions ppocr/modeling/heads/rec_ctc_head.py
@@ -38,6 +38,7 @@ def __init__(self,
out_channels,
fc_decay=0.0004,
mid_channels=None,
return_feats=False,
**kwargs):
super(CTCHead, self).__init__()
if mid_channels is None:
@@ -66,14 +67,22 @@ def __init__(self,
bias_attr=bias_attr2)
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats

def forward(self, x, targets=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
predicts = self.fc1(x)
predicts = self.fc2(predicts)

x = self.fc1(x)
predicts = self.fc2(x)

if self.return_feats:
result = (x, predicts)
else:
result = predicts

if not self.training:
predicts = F.softmax(predicts, axis=2)
return predicts
result = predicts

return result
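
In training mode with return_feats: true the head now returns a (feats, logits) tuple: CenterLoss unpacks the pair and works on feats (mid_channels wide, hence feat_dim: 96 in the config), while CTCLoss, ACELoss and CTCLabelDecode all take the last element. A shape-only sketch of that contract, with an assumed sequence length of 80:

import paddle

B, T, mid_channels, num_classes = 2, 80, 96, 6625
feats = paddle.randn([B, T, mid_channels])       # consumed by CenterLoss (feat_dim)
logits = paddle.randn([B, T, num_classes])       # consumed by CTCLoss / decoder
head_output = (feats, logits)                    # forward() result in training mode

features, predicts = head_output                 # CenterLoss: features, predicts = predicts
ctc_input = head_output[-1]                      # CTCLoss / CTCLabelDecode: predicts[-1]
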
2 changes: 2 additions & 0 deletions ppocr/postprocess/rec_postprocess.py
@@ -111,6 +111,8 @@ def __init__(self,
character_type, use_space_char)

def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
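
The two-line guard above keeps CTCLabelDecode working with the new head output: when it receives the (feats, logits) tuple it simply decodes the logits. A minimal self-contained illustration of the same guard, with numpy stand-ins so no PaddleOCR install is needed:

import numpy as np

preds = (np.zeros((1, 80, 96)), np.random.rand(1, 80, 6625))  # (feats, logits)
if isinstance(preds, tuple):
    preds = preds[-1]
preds_idx = preds.argmax(axis=2)     # per-step class ids, as in CTCLabelDecode
preds_prob = preds.max(axis=2)
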
