improve amp training (PaddlePaddle#10119)
zhangting2020 committed Jun 8, 2023
1 parent 062e2c5 commit 6949448
Showing 3 changed files with 18 additions and 8 deletions.
2 changes: 1 addition & 1 deletion configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
@@ -16,7 +16,7 @@ Global:
   save_res_path: ./output/det_db/predicts_db.txt
   use_amp: False
   amp_level: O2
-  amp_custom_black_list: ['exp']
+  amp_dtype: bfloat16

 Architecture:
   name: DistillationModel
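With use_amp enabled, the Global keys above select the AMP level and, after this commit, the autocast dtype (float16 or bfloat16). A minimal sketch of how those keys are consumed, mirroring the tools/train.py hunk further down; loading the YAML with PyYAML here is purely illustrative, not PaddleOCR's own config loader:

import yaml

# Load the config shown above and read its Global AMP keys the same way
# tools/train.py does in the diff below.
with open("configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml") as f:
    config = yaml.safe_load(f)

use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", "O2")
amp_dtype = config["Global"].get("amp_dtype", "float16")  # or "bfloat16"
amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])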
17 changes: 12 additions & 5 deletions tools/program.py
@@ -189,7 +189,8 @@ def train(config,
           scaler=None,
           amp_level='O2',
           amp_custom_black_list=[],
-          amp_custom_white_list=[]):
+          amp_custom_white_list=[],
+          amp_dtype='float16'):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                     False)
     calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
@@ -279,7 +280,8 @@ def train(config,
             with paddle.amp.auto_cast(
                     level=amp_level,
                     custom_black_list=amp_custom_black_list,
-                    custom_white_list=amp_custom_white_list):
+                    custom_white_list=amp_custom_white_list,
+                    dtype=amp_dtype):
                 if model_type == 'table' or extra_input:
                     preds = model(images, data=batch[1:])
                 elif model_type in ["kie"]:
@@ -393,7 +395,9 @@ def train(config,
                     extra_input=extra_input,
                     scaler=scaler,
                     amp_level=amp_level,
-                    amp_custom_black_list=amp_custom_black_list)
+                    amp_custom_black_list=amp_custom_black_list,
+                    amp_custom_white_list=amp_custom_white_list,
+                    amp_dtype=amp_dtype)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                 logger.info(cur_metric_str)
@@ -486,7 +490,9 @@ def eval(model,
          extra_input=False,
          scaler=None,
          amp_level='O2',
-         amp_custom_black_list=[]):
+         amp_custom_black_list=[],
+         amp_custom_white_list=[],
+         amp_dtype='float16'):
    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
@@ -509,7 +515,8 @@
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
-                        custom_black_list=amp_custom_black_list):
+                        custom_black_list=amp_custom_black_list,
+                        dtype=amp_dtype):
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    elif model_type in ["kie"]:
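Both the train and eval loops above wrap the forward pass in paddle.amp.auto_cast and now forward the configured dtype alongside the level and the custom op lists. A minimal standalone sketch of that pattern; the Linear model and random input are placeholders, not PaddleOCR's detector:

import paddle

# Placeholder model and batch standing in for PaddleOCR's model/images.
model = paddle.nn.Linear(32, 2)
images = paddle.rand([8, 32])

amp_level = "O2"
amp_dtype = "bfloat16"  # or "float16"
amp_custom_black_list = []
amp_custom_white_list = []

# Forward pass under autocast, matching the call pattern in tools/program.py.
with paddle.amp.auto_cast(
        level=amp_level,
        custom_black_list=amp_custom_black_list,
        custom_white_list=amp_custom_white_list,
        dtype=amp_dtype):
    preds = model(images)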
7 changes: 5 additions & 2 deletions tools/train.py
@@ -160,6 +160,7 @@ def main(config, device, logger, vdl_writer):

     use_amp = config["Global"].get("use_amp", False)
     amp_level = config["Global"].get("amp_level", 'O2')
+    amp_dtype = config["Global"].get("amp_dtype", 'float16')
     amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
     amp_custom_white_list = config['Global'].get('amp_custom_white_list', [])
     if use_amp:
@@ -181,7 +182,8 @@ def main(config, device, logger, vdl_writer):
             models=model,
             optimizers=optimizer,
             level=amp_level,
-            master_weight=True)
+            master_weight=True,
+            dtype=amp_dtype)
     else:
         scaler = None

@@ -195,7 +197,8 @@ def main(config, device, logger, vdl_writer):
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
                   eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
-                  amp_level, amp_custom_black_list, amp_custom_white_list)
+                  amp_level, amp_custom_black_list, amp_custom_white_list,
+                  amp_dtype)


 def test_reader(config, device, logger):
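Putting the pieces together, main() reads amp_dtype from Global, decorates the model and optimizer for the chosen level and dtype, and passes everything through to program.train. A minimal end-to-end sketch with a toy model; the GradScaler configuration and the loss-scaling step are standard Paddle AMP usage assumed here, not taken from this diff:

import paddle

# Toy model/optimizer standing in for PaddleOCR's architecture and optimizer.
model = paddle.nn.Linear(32, 2)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

amp_level = "O2"
amp_dtype = "bfloat16"  # would come from Global.amp_dtype in train.py

# Decorate for AMP, as in the tools/train.py hunk above.
model, optimizer = paddle.amp.decorate(
    models=model,
    optimizers=optimizer,
    level=amp_level,
    master_weight=True,
    dtype=amp_dtype)

# Assumed GradScaler settings (not part of this diff).
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# One scaled training step under autocast with the configured dtype.
images = paddle.rand([8, 32])
with paddle.amp.auto_cast(level=amp_level, dtype=amp_dtype):
    loss = model(images).mean()

scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()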
