Skip to content

Commit

Permalink
fix gap between table structure train model and inference model (PaddlePaddle#4565)
Browse files Browse the repository at this point in the history

* add indent in pipeline_rpc_client.py

* fix gap in table structure train model and inference model
  • Loading branch information
WenmuZhou committed Nov 10, 2021
1 parent a896002 commit b6a2141
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 27 deletions.
17 changes: 9 additions & 8 deletions configs/table/table_mv3.yml
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
Global:
use_gpu: true
epoch_num: 50
epoch_num: 400
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 5
save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
cal_metric_during_train: True
pretrained_model:
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
infer_img: doc/table/table.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 500
max_elem_length: 800
max_cell_num: 500
infer_mode: False
process_total_num: 0
process_cut_num: 0


Optimizer:
name: Adam
beta1: 0.9
Expand All @@ -41,13 +40,15 @@ Architecture:
Backbone:
name: MobileNetV3
scale: 1.0
model_name: small
disable_se: True
model_name: large
Head:
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500

Loss:
name: TableAttentionLoss
Expand Down
6 changes: 3 additions & 3 deletions deploy/pdserving/pipeline_rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ def cv2_to_base64(image):
image_data = file.read()
image = cv2_to_base64(image_data)

for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["res"])
print(ret)
for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["res"])
print(ret)
40 changes: 24 additions & 16 deletions ppocr/modeling/heads/table_att_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,40 @@


class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
def __init__(self,
in_channels,
hidden_size,
loc_type,
in_max_len=488,
max_text_length=100,
max_elem_length=800,
max_cell_num=500,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = 500
self.max_cell_num = 500
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num

self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.loc_type = loc_type
self.in_max_len = in_max_len

if self.loc_type == 1:
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)

def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
Expand All @@ -60,16 +68,16 @@ def forward(self, inputs, targets=None):
if len(fea.shape) == 3:
pass
else:
last_shape = int(np.prod(fea.shape[2:])) # gry added
last_shape = int(np.prod(fea.shape[2:])) # gry added
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]

hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_elem_length+1):
for i in range(self.max_elem_length + 1):
elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
Expand All @@ -96,7 +104,7 @@ def forward(self, inputs, targets=None):
alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0
while i < max_elem_length+1:
while i < max_elem_length + 1:
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
Expand All @@ -105,7 +113,7 @@ def forward(self, inputs, targets=None):
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
i += 1

output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs)
Expand All @@ -119,9 +127,9 @@ def forward(self, inputs, targets=None):
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}



class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
Expand Down

0 comments on commit b6a2141

Please sign in to comment.