Skip to content

Commit

Permalink
Fix grid_sample data type bug when using fp16 (PaddlePaddle#9930)
Browse files Browse the repository at this point in the history
* fix grid_sample data type bug when using fp16

* fix grid_sample data type bug when using fp16

* fix v4rec batchsize
  • Loading branch information
Topdu committed May 15, 2023
1 parent 24ff4de commit 4251664
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 3 deletions.
4 changes: 2 additions & 2 deletions configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Global:
save_epoch_step: 10
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
Expand Down Expand Up @@ -101,7 +101,7 @@ Train:
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 128
first_bs: &bs 192
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
Expand Down
8 changes: 8 additions & 0 deletions ppocr/modeling/transforms/gaspin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,5 +280,13 @@ def forward(self, x, return_weight=False):

x = self.sp_net(x, sp_weight, offsets, lambda_color)
if self.stn:
is_fp16 = False
if build_P_prime_reshape.dtype != paddle.float32:
data_type = build_P_prime_reshape.dtype
x = x.cast(paddle.float32)
build_P_prime_reshape = build_P_prime_reshape.cast(paddle.float32)
is_fp16 = True
x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border')
if is_fp16:
x = x.cast(data_type)
return x
9 changes: 9 additions & 0 deletions ppocr/modeling/transforms/tps.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,14 @@ def forward(self, image):
batch_P_prime = self.grid_generator(batch_C_prime, image.shape[2:])
batch_P_prime = batch_P_prime.reshape(
[-1, image.shape[2], image.shape[3], 2])
is_fp16 = False
if batch_P_prime.dtype != paddle.float32:
data_type = batch_P_prime.dtype
image = image.cast(paddle.float32)
batch_P_prime = batch_P_prime.cast(paddle.float32)
is_fp16 = True
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
if is_fp16:
batch_I_r = batch_I_r.cast(data_type)

return batch_I_r
16 changes: 16 additions & 0 deletions ppocr/modeling/transforms/tps_spatial_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,28 @@

def grid_sample(input, grid, canvas=None):
    """Sample `input` at the locations given by `grid`, optionally
    blending the result with `canvas` outside the sampled region.

    Paddle's ``F.grid_sample`` does not accept fp16 tensors, so when the
    incoming `grid` is low precision both operands are promoted to
    float32 for the sampling op and the result is cast back to the
    caller's dtype afterwards.

    Args:
        input: feature map to sample from (paddle Tensor; assumed
            NCHW — TODO confirm against callers).
        grid: sampling grid; its dtype decides whether the fp16
            promotion path is taken.
        canvas: optional background blended in where the sampling mask
            is zero; presumably shares `input`'s original dtype —
            verify against callers.

    Returns:
        The sampled tensor (blended with `canvas` when given), in the
        same dtype as the incoming `grid`.
    """
    # Gradients must flow into `input` even if the caller froze it.
    input.stop_gradient = False

    # fp16 workaround: promote to float32 for F.grid_sample and record
    # the original dtype so outputs can be cast back.
    needs_cast = grid.dtype != paddle.float32
    if needs_cast:
        orig_dtype = grid.dtype
        input = input.cast(paddle.float32)
        grid = grid.cast(paddle.float32)

    output = F.grid_sample(input, grid)
    if needs_cast:
        output = output.cast(orig_dtype)

    if canvas is None:
        return output

    # Blend with the canvas where the sampled mask is zero. `grid` is
    # kept float32 here: the original code cast it back to fp16 and then
    # immediately re-cast it to float32 before this second grid_sample —
    # a redundant round-trip with no effect on the result.
    input_mask = paddle.ones(shape=input.shape)
    if needs_cast:
        input_mask = input_mask.cast(paddle.float32)
    output_mask = F.grid_sample(input_mask, grid)
    if needs_cast:
        output_mask = output_mask.cast(orig_dtype)
    padded_output = output * output_mask + canvas * (1 - output_mask)
    return padded_output

Expand Down
2 changes: 1 addition & 1 deletion tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def export_single_model(model,
shape=[None] + infer_shape, dtype="float32")
])

if arch_config["Backbone"]["name"] == "LCNetv3":
if arch_config["Backbone"]["name"] == "PPLCNetV3":
# for rep lcnetv3
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
Expand Down

0 comments on commit 4251664

Please sign in to comment.