Skip to content

Commit

Permalink
matmul use fp32 compute_type (PaddlePaddle#8733)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhang Ting committed Dec 29, 2022
1 parent bdfa1d2 commit 4f735db
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,10 @@ def main(config, device, logger, vdl_writer):
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_gemm_use_half_precision_compute_type': 0,
})
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
"use_dynamic_loss_scaling", False)
Expand Down

0 comments on commit 4f735db

Please sign in to comment.