Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add satrn #8433

Merged
merged 4 commits into from
Feb 8, 2023
Merged

add satrn #8433

merged 4 commits into from
Feb 8, 2023

Conversation

zhiminzhang0830
Copy link
Contributor

复现论文:On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention
参考代码:https://github.com/open-mmlab/mmocr/blob/1.x/configs/textrecog/satrn/README.md

@paddle-bot
Copy link

paddle-bot bot commented Nov 24, 2022

Thanks for your contribution!

@zhiminzhang0830
Copy link
Contributor Author

zhiminzhang0830 commented Nov 24, 2022

数据集:
训练集:https://aistudio.baidu.com/aistudio/datasetdetail/166485
验证集:https://aistudio.baidu.com/aistudio/datasetdetail/182867
实验结果:
IIIK-3000:94.53,SVT:91.04,IC13:94.68,IC15:78.24,SVTP:83.72,CUTE80:86.11,Avg:88.05

模型训练:
python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_satrn.yml

模型验证:
python tools/eval.py -c {your config file}
-o Global.pretrained_model={your model file}
Eval.dataset.data_dir={your dataset path}/IIIT5k_3000

模型测试:
python3 tools/infer_rec.py -c {your config file}
-o Global.pretrained_model={your model file}
Global.infer_img="doc/imgs_words_en/"

存在问题:
1.使用python tools/export_model.py导出模型时,速度非常慢,大概需要7分钟才能导出模型
2.使用导出的模型做推理的时候报错,报错信息如下:
[libprotobuf ERROR /paddle/build/third_party/protobuf/src/extern_protobuf/src/google/protobuf/io/coded_stream.cc:208] A protocol message was rejected because it was too big (more than 67108864 bytes). To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf ERROR /paddle/build/third_party/protobuf/src/extern_protobuf/src/google/protobuf/io/coded_stream.cc:208] A protocol message was rejected because it was too big (more than 67108864 bytes). To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
Traceback (most recent call last):
File "tools/infer/predict_rec.py", line 690, in
main(utility.parse_args())
File "tools/infer/predict_rec.py", line 652, in main
text_recognizer = TextRecognizer(args)
File "tools/infer/predict_rec.py", line 127, in init
utility.create_predictor(args, 'rec', logger)
File "/data/code/PaddleOCR_satrn/tools/infer/utility.py", line 277, in create_predictor
predictor = inference.create_predictor(config)
ValueError: (InvalidArgument) Failed to parse program_desc from binary string.
[Hint: Expected desc_.ParseFromString(binary_str) == true, but received desc_.ParseFromString(binary_str):0 != true:1.] (at /paddle/paddle/fluid/framework/program_desc.cc:103)

@zhiminzhang0830
Copy link
Contributor Author

@Topdu
Copy link
Collaborator

Topdu commented Jan 31, 2023

satrn_head.py的526行修改为:for step in range(0, paddle.to_tensor(self.max_seq_len)):
这样可以解决导出inference model 慢的问题呢,
推理时修改predict_rec.py 449行为: elif self.rec_algorithm in ["SVTR", "SATRN"]:
推理的结果是:
Predicts of ./doc/imgs_words_en/word_19.png:('slowuknuknuknuknuknuknuknuknuknuknukniuknuknuknuknuknuknuknuknukn', 0.5304282307624817)
看结果分析,后处理似乎没有找到eos。
推理命令:
python tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_19.png' --rec_model_dir='./inference/satrn/' --rec_algorithm='SATRN' --rec_image_shape='3,32,100' --rec_char_dict_path='./ppocr/utils/dict90.txt'

@Topdu
Copy link
Collaborator

Topdu commented Jan 31, 2023

satrn 和 nrtr是同类型的识别模型,如果可以的话尽量复用nrtr的代码,例如shallow cnn可以写到MTB中,attention和encoder layer和decoder layer 如果差别不大的话也可以复用nrtr的代码

init_target_seq[:, 0] = self.start_idx

outputs = []
for step in range(0, self.max_seq_len):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for step in range(0, paddle.to_tensor(self.max_seq_len)):

@zhiminzhang0830
Copy link
Contributor Author

导出模型:
python3 tools/export_model.py -c configs/rec/rec_satrn.yml
-o Global.pretrained_model=inference/satrn/rec_satrn/best_accuracy.pdparams
Global.save_inference_dir=./inference/satrn/
模型推理:
python tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_19.png' --rec_model_dir='./inference/satrn/' --rec_algorithm='SATRN' --rec_image_shape='3,32,100' --rec_char_dict_path='./ppocr/utils/dict90.txt' --use_space_char='False'

epoch_num: 5
log_smooth_window: 20
print_batch_step: 50
save_model_dir: ../work_dir/ppocr/satrn_branch/rec_satrn/
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_model_dir最好与其他方法修改一致:./output/rec/rec_satrn

Train:
dataset:
name: LMDBDataSet
data_dir: /data/Dataset/OCR_Rec/visual_data/rfl_dataset2/training
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

./train_data/data_lmdb_release/training/

Eval:
dataset:
name: LMDBDataSet
data_dir: /data/Dataset/OCR_Rec/visual_data/rfl_dataset2/evaluation_academic
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

./train_data/data_lmdb_release/evaluation/

@@ -465,6 +465,21 @@ def __call__(self, data):
return data


class SATRNRecResizeImg(object):
def __init__(self, image_shape, padding=True, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果没用到SATRNRecResizeImg的话可以删除

Copy link
Collaborator

@tink2123 tink2123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

if mask is not None:
attn = masked_fill(attn, mask == 0, -1e9)
# attn = attn.masked_fill(mask == 0, float('-inf'))
# attn += mask
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo:不必要的注释可以删除

@tink2123
Copy link
Collaborator

tink2123 commented Feb 8, 2023

需要补充文档并接入TIPC

@tink2123 tink2123 merged commit 30201ef into PaddlePaddle:dygraph Feb 8, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants