From 9606bec16a7d7b8a1abcfa113a69d40f837b4cc5 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Mon, 17 Oct 2022 07:41:36 +0000 Subject: [PATCH] fix visionlan default dict --- doc/doc_ch/algorithm_rec_visionlan.md | 2 +- doc/doc_en/algorithm_rec_visionlan_en.md | 2 +- ppocr/data/imaug/label_ops.py | 9 +++++++-- ppocr/postprocess/rec_postprocess.py | 10 ++++++++-- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/doc/doc_ch/algorithm_rec_visionlan.md b/doc/doc_ch/algorithm_rec_visionlan.md index df039491d4..84b5ef6821 100644 --- a/doc/doc_ch/algorithm_rec_visionlan.md +++ b/doc/doc_ch/algorithm_rec_visionlan.md @@ -139,7 +139,7 @@ Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493) ## 5. FAQ 1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN) 。 -2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练。 +2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练,预训练模型配套字典为'ppocr/utils/ic15_dict.txt'。 ## 引用 diff --git a/doc/doc_en/algorithm_rec_visionlan_en.md b/doc/doc_en/algorithm_rec_visionlan_en.md index 70c2ccc470..cf2293b3d0 100644 --- a/doc/doc_en/algorithm_rec_visionlan_en.md +++ b/doc/doc_en/algorithm_rec_visionlan_en.md @@ -120,7 +120,7 @@ Not supported ## 5. FAQ 1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN). -2. We use the pre-trained model provided by the VisionLAN authors for finetune training. +2. We use the pre-trained model provided by the VisionLAN authors for finetune training. The dictionary for the pre-trained model is 'ppocr/utils/ic15_dict.txt'. ## Citation diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 2a2ac2decd..511471c76b 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -107,6 +107,7 @@ def __init__(self, self.beg_str = "sos" self.end_str = "eos" self.lower = lower + self.use_default_dict = False if character_dict_path is None: logger = get_logger() @@ -116,8 +117,11 @@ def __init__(self, self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) self.lower = True + self.use_default_dict = True else: self.character_str = [] + if 'ppocr/utils/ic15_dict.txt' in character_dict_path: + self.use_default_dict = True with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: @@ -1400,8 +1404,9 @@ def __init__(self, **kwargs): super(VLLabelEncode, self).__init__( max_text_length, character_dict_path, use_space_char, lower) - self.character = self.character[10:] + self.character[ - 1:10] + [self.character[0]] + if self.use_default_dict: + self.character = self.character[10:] + self.character[ + 1:10] + [self.character[0]] self.dict = {} for i, char in enumerate(self.character): self.dict[char] = i diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 59b5254e48..98753ef7a8 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -26,10 +26,15 @@ def __init__(self, character_dict_path=None, use_space_char=False): self.end_str = "eos" self.reverse = False self.character_str = [] + self.use_default_dict = False + if character_dict_path is None: self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) + self.use_default_dict = True else: + if 'ppocr/utils/ic15_dict.txt' in character_dict_path: + self.use_default_dict = True with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: @@ -805,8 +810,9 @@ def __init__(self, character_dict_path=None, use_space_char=False, super(VLLabelDecode, self).__init__(character_dict_path, use_space_char) self.max_text_length = kwargs.get('max_text_length', 25) self.nclass = len(self.character) + 1 - self.character = self.character[10:] + self.character[ - 1:10] + [self.character[0]] + if self.use_default_dict: + self.character = self.character[10:] + self.character[ + 1:10] + [self.character[0]] def decode(self, text_index, text_prob=None, is_remove_duplicate=False): """ convert text-index into text-label. """