
fix sr_telescope (PaddlePaddle#10004)
tink2123 committed May 23, 2023
1 parent 2c0664b commit 096fd27
Showing 3 changed files with 50 additions and 30 deletions.
2 changes: 1 addition & 1 deletion ppocr/modeling/heads/sr_rensnet_transformer.py
@@ -78,7 +78,7 @@ def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
     def forward(self, query, key, value, mask=None, attention_map=None):
         if mask is not None:
             mask = mask.unsqueeze(1)
-        nbatches = query.shape[0]
+        nbatches = paddle.shape(query)[0]
 
         query, key, value = \
             [paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3])
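Note: the one-line change above is the pattern this commit applies throughout. Shape lookups move from the Python .shape attribute to paddle.shape, so the batch dimension stays a runtime tensor and is still resolved correctly after the model is exported to a static graph with an unspecified batch size. A minimal standalone sketch of the difference (assumes Paddle 2.x dynamic mode; not part of the patched files):

import paddle

x = paddle.rand([8, 3, 32, 128])

# Attribute access returns plain Python ints; under static-graph export with a
# dynamic batch dimension this slot is typically a placeholder (-1/None).
batch_attr = x.shape[0]

# paddle.shape builds the shape at run time, so the value is correct even when
# the graph was traced with shape=[None, 3, 32, 128].
batch_runtime = paddle.shape(x)[0]

print(batch_attr)             # 8
print(batch_runtime.numpy())  # 8, carried as a Tensor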
75 changes: 47 additions & 28 deletions ppocr/modeling/transforms/tbsrn.py
@@ -45,21 +45,24 @@ def positionalencoding2d(d_model, height, width):
     pe = paddle.zeros([d_model, height, width])
     # Each dimension use half of d_model
     d_model = int(d_model / 2)
-    div_term = paddle.exp(paddle.arange(0., d_model, 2) *
-                          -(math.log(10000.0) / d_model))
+    div_term = paddle.exp(
+        paddle.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
     pos_w = paddle.arange(0., width, dtype='float32').unsqueeze(1)
     pos_h = paddle.arange(0., height, dtype='float32').unsqueeze(1)
 
-    pe[0:d_model:2, :, :] = paddle.sin(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
-    pe[1:d_model:2, :, :] = paddle.cos(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
-    pe[d_model::2, :, :] = paddle.sin(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
-    pe[d_model + 1::2, :, :] = paddle.cos(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
+    pe[0:d_model:2, :, :] = paddle.sin(pos_w * div_term).transpose(
+        [1, 0]).unsqueeze(1).tile([1, height, 1])
+    pe[1:d_model:2, :, :] = paddle.cos(pos_w * div_term).transpose(
+        [1, 0]).unsqueeze(1).tile([1, height, 1])
+    pe[d_model::2, :, :] = paddle.sin(pos_h * div_term).transpose(
+        [1, 0]).unsqueeze(2).tile([1, 1, width])
+    pe[d_model + 1::2, :, :] = paddle.cos(pos_h * div_term).transpose(
+        [1, 0]).unsqueeze(2).tile([1, 1, width])
 
     return pe
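For context, positionalencoding2d builds a standard 2-D sinusoidal encoding: the first half of the channels encodes the horizontal position with interleaved sine/cosine terms and the second half encodes the vertical position. A small usage sketch based on the call site in FeatureEnhancer.forward further down (assumes the function above is in scope):

import paddle

# 64 channels over a 16 x 64 feature map, the sizes used in FeatureEnhancer
pe = positionalencoding2d(64, 16, 64)   # Tensor of shape [64, 16, 64]
pe = pe.cast('float32').unsqueeze(0)    # [1, 64, 16, 64]
pe = pe.reshape([1, 64, 1024])          # H*W flattened, ready to concat
print(pe.shape)                         # [1, 64, 1024]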


 class FeatureEnhancer(nn.Layer):
-
     def __init__(self):
         super(FeatureEnhancer, self).__init__()

@@ -77,13 +80,16 @@ def forward(self, conv_feature):
         global_info: (batch, embedding_size, 1, 1)
         conv_feature: (batch, channel, H, W)
         '''
-        batch = conv_feature.shape[0]
-        position2d = positionalencoding2d(64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
+        batch = paddle.shape(conv_feature)[0]
+        position2d = positionalencoding2d(
+            64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
         position2d = position2d.tile([batch, 1, 1])
-        conv_feature = paddle.concat([conv_feature, position2d], 1)  # batch, 128(64+64), 32, 128
+        conv_feature = paddle.concat([conv_feature, position2d],
+                                     1)  # batch, 128(64+64), 32, 128
         result = conv_feature.transpose([0, 2, 1])
         origin_result = result
-        result = self.mul_layernorm1(origin_result + self.multihead(result, result, result, mask=None)[0])
+        result = self.mul_layernorm1(origin_result + self.multihead(
+            result, result, result, mask=None)[0])
         origin_result = result
         result = self.mul_layernorm3(origin_result + self.pff(result))
         result = self.linear(result)
@@ -124,23 +130,35 @@ def __init__(self,
         assert math.log(scale_factor, 2) % 1 == 0
         upsample_block_num = int(math.log(scale_factor, 2))
         self.block1 = nn.Sequential(
-            nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4),
+            nn.Conv2D(
+                in_planes, 2 * hidden_units, kernel_size=9, padding=4),
             nn.PReLU()
             # nn.ReLU()
         )
         self.srb_nums = srb_nums
         for i in range(srb_nums):
-            setattr(self, 'block%d' % (i + 2), RecurrentResidualBlock(2 * hidden_units))
-
-        setattr(self, 'block%d' % (srb_nums + 2),
-                nn.Sequential(
-                    nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1),
-                    nn.BatchNorm2D(2 * hidden_units)
-                ))
+            setattr(self, 'block%d' % (i + 2),
+                    RecurrentResidualBlock(2 * hidden_units))
+
+        setattr(
+            self,
+            'block%d' % (srb_nums + 2),
+            nn.Sequential(
+                nn.Conv2D(
+                    2 * hidden_units,
+                    2 * hidden_units,
+                    kernel_size=3,
+                    padding=1),
+                nn.BatchNorm2D(2 * hidden_units)))
 
         # self.non_local = NonLocalBlock2D(64, 64)
-        block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)]
-        block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4))
+        block_ = [
+            UpsampleBLock(2 * hidden_units, 2)
+            for _ in range(upsample_block_num)
+        ]
+        block_.append(
+            nn.Conv2D(
+                2 * hidden_units, in_planes, kernel_size=9, padding=4))
         setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
         self.tps_inputsize = [height // scale_factor, width // scale_factor]
         tps_outputsize = [height // scale_factor, width // scale_factor]
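The setattr calls above register numbered sub-blocks (block2 through block{srb_nums + 3}) on the layer so they can later be fetched by name. A minimal sketch of that register-then-getattr pattern, with an illustrative layer and hypothetical names only, not the TBSRN module itself:

import paddle
import paddle.nn as nn

class TinyStack(nn.Layer):
    """Illustrative only: register sub-layers named block2..block{n+1}."""

    def __init__(self, n=3, channels=8):
        super().__init__()
        self.n = n
        for i in range(n):
            # assigning through setattr registers each Conv2D as a sub-layer
            setattr(self, 'block%d' % (i + 2),
                    nn.Conv2D(channels, channels, kernel_size=3, padding=1))

    def forward(self, x):
        for i in range(self.n):
            x = getattr(self, 'block%d' % (i + 2))(x)
        return x

y = TinyStack()(paddle.rand([1, 8, 16, 16]))
print(y.shape)  # [1, 8, 16, 16]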
@@ -164,7 +182,8 @@ def __init__(self,
         self.english_dict = {}
         for index in range(len(self.english_alphabet)):
             self.english_dict[self.english_alphabet[index]] = index
-        transformer = Transformer(alphabet='-0123456789abcdefghijklmnopqrstuvwxyz')
+        transformer = Transformer(
+            alphabet='-0123456789abcdefghijklmnopqrstuvwxyz')
         self.transformer = transformer
         for param in self.transformer.parameters():
             param.trainable = False
@@ -219,10 +238,10 @@ def forward(self, x):
             # add transformer
             label = [str_filt(i, 'lower') + '-' for i in x[2]]
             length_tensor, input_tensor, text_gt = self.label_encoder(label)
-            hr_pred, word_attention_map_gt, hr_correct_list = self.transformer(hr_img, length_tensor,
-                                                                               input_tensor)
-            sr_pred, word_attention_map_pred, sr_correct_list = self.transformer(sr_img, length_tensor,
-                                                                                 input_tensor)
+            hr_pred, word_attention_map_gt, hr_correct_list = self.transformer(
+                hr_img, length_tensor, input_tensor)
+            sr_pred, word_attention_map_pred, sr_correct_list = self.transformer(
+                sr_img, length_tensor, input_tensor)
             output["hr_img"] = hr_img
             output["hr_pred"] = hr_pred
             output["text_gt"] = text_gt
@@ -257,8 +276,8 @@ def forward(self, x):
         residual = self.conv2(residual)
         residual = self.bn2(residual)
 
-        size = residual.shape
+        size = paddle.shape(residual)
         residual = residual.reshape([size[0], size[1], -1])
         residual = self.feature_enhancer(residual)
         residual = residual.reshape([size[0], size[1], size[2], size[3]])
-        return x + residual
+        return x + residual
3 changes: 2 additions & 1 deletion tools/export_model.py
@@ -187,7 +187,8 @@ def export_single_model(model,
                     shape=[None] + infer_shape, dtype="float32")
             ])
 
-    if arch_config["Backbone"]["name"] == "PPLCNetV3":
+    if arch_config["model_type"] != "sr" and arch_config["Backbone"][
+            "name"] == "PPLCNetV3":
         # for rep lcnetv3
         for layer in model.sublayers():
             if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
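The reordered condition works because Python evaluates "and" left to right: for super-resolution configs, which need not define a Backbone section at all, the model_type check fails first and the Backbone lookup is never reached, so no KeyError is raised. A small standalone illustration with a hypothetical SR config dict:

# Hypothetical minimal config for an SR model; note there is no "Backbone" key.
arch_config = {"model_type": "sr", "Transform": {"name": "TBSRN"}}

# The left-hand check fails first, so the dict lookup on the right is never
# evaluated and no KeyError can occur.
if arch_config["model_type"] != "sr" and arch_config["Backbone"][
        "name"] == "PPLCNetV3":
    print("re-parameterize PPLCNetV3 blocks")
else:
    print("skip re-parameterization")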

