Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#7835 from qipengh/add_mlu_dbnet
Browse files Browse the repository at this point in the history
[MLU] Adapt the MLU device for running the DBNet network
  • Loading branch information
MissPenguin committed Oct 10, 2022
2 parents 823a839 + 7851977 commit a706908
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
1 change: 1 addition & 0 deletions configs/det/det_mv3_db.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Global:
use_gpu: true
use_xpu: false
use_mlu: false
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 10
Expand Down
10 changes: 8 additions & 2 deletions tools/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def merge_config(config, opts):
return config


def check_device(use_gpu, use_xpu=False, use_npu=False):
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
Expand All @@ -137,6 +137,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False):
if use_npu and not paddle.device.is_compiled_with_npu():
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
if use_mlu and not paddle.device.is_compiled_with_mlu():
print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
sys.exit(1)
except Exception as e:
pass

Expand Down Expand Up @@ -618,6 +621,7 @@ def preprocess(is_train=False):
use_gpu = config['Global'].get('use_gpu', False)
use_xpu = config['Global'].get('use_xpu', False)
use_npu = config['Global'].get('use_npu', False)
use_mlu = config['Global'].get('use_mlu', False)

alg = config['Architecture']['algorithm']
assert alg in [
Expand All @@ -632,10 +636,12 @@ def preprocess(is_train=False):
device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
elif use_npu:
device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
elif use_mlu:
device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
else:
device = 'gpu:{}'.format(dist.ParallelEnv()
.dev_id) if use_gpu else 'cpu'
check_device(use_gpu, use_xpu, use_npu)
check_device(use_gpu, use_xpu, use_npu, use_mlu)

device = paddle.set_device(device)

Expand Down
9 changes: 5 additions & 4 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ def main(config, device, logger, vdl_writer):
amp_level = config["Global"].get("amp_level", 'O2')
amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8,
}
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
})
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
Expand Down

0 comments on commit a706908

Please sign in to comment.