Skip to content

Commit

Permalink
fix inference in tps
Browse files Browse the repository at this point in the history
  • Loading branch information
tink2123 committed Jun 3, 2020
1 parent b722eb5 commit be3a164
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 7 deletions.
4 changes: 3 additions & 1 deletion configs/rec/rec_mv3_tps_bilstm_attn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ Global:
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: attention
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
checkpoints:
Expand Down
1 change: 1 addition & 0 deletions configs/rec/rec_mv3_tps_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Global:
max_text_length: 25
character_type: en
loss_type: ctc
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
checkpoints:
Expand Down
5 changes: 4 additions & 1 deletion ppocr/data/rec/dataset_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(self, params):
self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length']
self.mode = params['mode']
if "tps" in params:
self.tps = True
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
self.drop_last = params['drop_last']
Expand Down Expand Up @@ -109,7 +111,8 @@ def sample_iter_reader():
norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops)
char_ops=self.char_ops,
tps=self.tps)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
Expand Down
9 changes: 7 additions & 2 deletions ppocr/data/rec/img_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,16 @@ def process_image(img,
label=None,
char_ops=None,
loss_type=None,
max_text_length=None):
max_text_length=None,
tps=None):
if char_ops.character_type == "en":
norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
if tps:
image_shape = [3, 32, 320]
norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :]
if label is not None:
char_num = char_ops.get_char_num()
Expand Down
14 changes: 13 additions & 1 deletion ppocr/modeling/architectures/rec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, params):
global_params = params['Global']
char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num
self.char_type = global_params['character_type']
if "TPS" in params:
tps_params = deepcopy(params["TPS"])
tps_params.update(global_params)
Expand Down Expand Up @@ -60,8 +61,8 @@ def __init__(self, params):
def create_feed(self, mode):
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "attention":
label_in = fluid.data(
name='label_in',
Expand All @@ -86,6 +87,17 @@ def create_feed(self, mode):
use_double_buffer=True,
iterable=False)
else:
if self.char_type == "ch":
image_shape[-1] = -1
if self.tps != None:
logger.info(
"WARNRNG!!!\n"
"TPS does not support variable shape in chinese!"
"We set default shape=[3,32,320], it may affect the inference effect"
)
image_shape[-1] = 320
image = fluid.data(
name='image', shape=image_shape, dtype='float32')
labels = None
loader = None
return image, labels, loader
Expand Down
4 changes: 2 additions & 2 deletions tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __call__(self, img_list):
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
#todo: why index has 2 offset
# attention indices have an offset of 2: the beg and end tokens
preds = preds - 2
preds_text = self.char_ops.decode(preds)
rec_res.append([preds_text, score])
Expand All @@ -138,7 +138,7 @@ def __call__(self, img_list):
except:
logger.info(
"ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n"
"Please set --rec_image_shape=input_shape and --rec_char_type='ch' ")
"Please set --rec_image_shape=input_shape and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
Expand Down

0 comments on commit be3a164

Please sign in to comment.