Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#7143 from WenmuZhou/tttt
Browse files Browse the repository at this point in the history
fix bug in amp eval
  • Loading branch information
andyjiang1116 committed Aug 9, 2022
2 parents 6445362 + 3bed2e1 commit f5692c3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Global:
# evaluation is run every 835 iterations
eval_batch_step: [0, 4000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
Expand Down
4 changes: 2 additions & 2 deletions test_tipc/configs/layoutxlm_ser/train_infer_python.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
Expand All @@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o
norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
Expand Down
28 changes: 21 additions & 7 deletions tools/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def train(config,
post_process_class,
eval_class,
model_type,
extra_input=extra_input)
extra_input=extra_input,
scaler=scaler)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
Expand Down Expand Up @@ -462,7 +463,8 @@ def eval(model,
post_process_class,
eval_class,
model_type=None,
extra_input=False):
extra_input=False,
scaler=None):
model.eval()
with paddle.no_grad():
total_frame = 0.0
Expand All @@ -479,12 +481,24 @@ def eval(model,
break
images = batch[0]
start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)

# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
else:
preds = model(images)
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)

batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
Expand Down

0 comments on commit f5692c3

Please sign in to comment.