Skip to content

Commit

Permalink
fix visionlan default dict
Browse files Browse the repository at this point in the history
  • Loading branch information
andyjiang1116 committed Oct 17, 2022
1 parent 7cbb1f4 commit 9606bec
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion doc/doc_ch/algorithm_rec_visionlan.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493)
## 5. FAQ

1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN)
2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练。
2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练,预训练模型配套字典为'ppocr/utils/ic15_dict.txt'

## 引用

Expand Down
2 changes: 1 addition & 1 deletion doc/doc_en/algorithm_rec_visionlan_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Not supported
## 5. FAQ

1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN).
2. We use the pre-trained model provided by the VisionLAN authors for finetune training.
2. We use the pre-trained model provided by the VisionLAN authors for finetune training. The dictionary for the pre-trained model is 'ppocr/utils/ic15_dict.txt'.

## Citation

Expand Down
9 changes: 7 additions & 2 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self,
self.beg_str = "sos"
self.end_str = "eos"
self.lower = lower
self.use_default_dict = False

if character_dict_path is None:
logger = get_logger()
Expand All @@ -116,8 +117,11 @@ def __init__(self,
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
self.lower = True
self.use_default_dict = True
else:
self.character_str = []
if 'ppocr/utils/ic15_dict.txt' in character_dict_path:
self.use_default_dict = True
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
Expand Down Expand Up @@ -1400,8 +1404,9 @@ def __init__(self,
**kwargs):
super(VLLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char, lower)
self.character = self.character[10:] + self.character[
1:10] + [self.character[0]]
if self.use_default_dict:
self.character = self.character[10:] + self.character[
1:10] + [self.character[0]]
self.dict = {}
for i, char in enumerate(self.character):
self.dict[char] = i
Expand Down
10 changes: 8 additions & 2 deletions ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ def __init__(self, character_dict_path=None, use_space_char=False):
self.end_str = "eos"
self.reverse = False
self.character_str = []
self.use_default_dict = False

if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
self.use_default_dict = True
else:
if 'ppocr/utils/ic15_dict.txt' in character_dict_path:
self.use_default_dict = True
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
Expand Down Expand Up @@ -805,8 +810,9 @@ def __init__(self, character_dict_path=None, use_space_char=False,
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
self.nclass = len(self.character) + 1
self.character = self.character[10:] + self.character[
1:10] + [self.character[0]]
if self.use_default_dict:
self.character = self.character[10:] + self.character[
1:10] + [self.character[0]]

def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
Expand Down

0 comments on commit 9606bec

Please sign in to comment.