Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
LDOUBLEV committed Feb 1, 2021
2 parents f896032 + 2a0c3d4 commit 56cbbdf
Show file tree
Hide file tree
Showing 34 changed files with 1,675 additions and 101 deletions.
38 changes: 23 additions & 15 deletions PPOCRLabel/PPOCRLabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ def format_shape(s):

for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
if trans_dic["label"] is "" and mode == 'Auto':
if trans_dic["label"] == "" and mode == 'Auto':
continue
shapes.append(trans_dic)

Expand Down Expand Up @@ -1450,7 +1450,7 @@ def importDirImages(self, dirpath, isDelete = False):
item = QListWidgetItem(closeicon, filename)
self.fileListWidget.addItem(item)

print('dirPath in importDirImages is', dirpath)
print('DirPath in importDirImages is', dirpath)
self.iconlist.clear()
self.additems5(dirpath)
self.changeFileFolder = True
Expand All @@ -1459,7 +1459,6 @@ def importDirImages(self, dirpath, isDelete = False):
self.reRecogButton.setEnabled(True)
self.actions.AutoRec.setEnabled(True)
self.actions.reRec.setEnabled(True)
self.actions.saveLabel.setEnabled(True)


def openPrevImg(self, _value=False):
Expand Down Expand Up @@ -1764,7 +1763,7 @@ def reRecognition(self):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
if result[0][0] is not '':
if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
self.result_dic.append(result)
Expand Down Expand Up @@ -1795,7 +1794,7 @@ def singleRerecognition(self):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
if result[0][0] is not '':
if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
if result[1][0] == shape.label:
Expand Down Expand Up @@ -1862,6 +1861,8 @@ def loadFilestate(self, saveDir):
for each in states:
file, state = each.split('\t')
self.fileStatedict[file] = 1
self.actions.saveLabel.setEnabled(True)
self.actions.saveRec.setEnabled(True)


def saveFilestate(self):
Expand Down Expand Up @@ -1919,22 +1920,29 @@ def saveRecResult(self):

rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt'
crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/'
ques_img = []
if not os.path.exists(crop_img_dir):
os.mkdir(crop_img_dir)

with open(rec_gt_dir, 'w', encoding='utf-8') as f:
for key in self.fileStatedict:
idx = self.getImglabelidx(key)
for i, label in enumerate(self.PPlabel[idx]):
if label['difficult']: continue
try:
img = cv2.imread(key)
img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
cv2.imwrite(crop_img_dir+img_name, img_crop)
f.write('crop_img/'+ img_name + '\t')
f.write(label['transcription'] + '\n')

QMessageBox.information(self, "Information", "Cropped images has been saved in "+str(crop_img_dir))
for i, label in enumerate(self.PPlabel[idx]):
if label['difficult']: continue
img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
cv2.imwrite(crop_img_dir+img_name, img_crop)
f.write('crop_img/'+ img_name + '\t')
f.write(label['transcription'] + '\n')
except Exception as e:
ques_img.append(key)
print("Can not read image ",e)
if ques_img:
QMessageBox.information(self, "Information", "The following images can not be saved, "
"please check the image path and labels.\n" + "".join(str(i)+'\n' for i in ques_img))
QMessageBox.information(self, "Information", "Cropped images have been saved in "+str(crop_img_dir))

def speedChoose(self):
if self.labelDialogOption.isChecked():
Expand Down Expand Up @@ -1991,7 +1999,7 @@ def main():
resource_file = './libs/resources.py'
if not os.path.exists(resource_file):
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
"directory resources.py "
import libs.resources
sys.exit(main())
6 changes: 3 additions & 3 deletions configs/rec/rec_mv3_none_bilstm_ctc.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Global:
use_gpu: true
use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
Expand Down Expand Up @@ -59,7 +59,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -78,7 +78,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_mv3_none_none_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -77,7 +77,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_mv3_tps_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -82,7 +82,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_r34_vd_none_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -77,7 +77,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_r34_vd_none_none_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -75,7 +75,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -81,7 +81,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
107 changes: 107 additions & 0 deletions configs/rec/rec_r50_fpn_srn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
Global:
use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/rec/srn_new
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 5000]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
num_heads: 8
infer_mode: False
use_space_char: False


Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 10.0
lr:
learning_rate: 0.0001

Architecture:
model_type: rec
algorithm: SRN
in_channels: 1
Transform:
Backbone:
name: ResNetFPN
Head:
name: SRNHead
max_text_length: 25
num_heads: 8
num_encoder_TUs: 2
num_decoder_TUs: 4
hidden_dims: 512

Loss:
name: SRNLoss

PostProcess:
name: SRNLabelDecode

Metric:
name: RecMetric
main_indicator: acc

Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/srn_train_data_duiqi
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2'] # dataloader will return list in this order
loader:
shuffle: False
batch_size_per_card: 64
drop_last: False
num_workers: 4

Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 32
num_workers: 4
3 changes: 2 additions & 1 deletion doc/doc_ch/algorithm_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
- [x] STAR-Net([paper](http:https://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]

参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:

Expand All @@ -53,5 +53,6 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |

PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)
21 changes: 17 additions & 4 deletions doc/doc_ch/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ inference 模型(`paddle.jit.save`保存的模型)
- [三、文本识别模型推理](#文本识别模型推理)
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
- [3. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- [4. 多语言模型的推理](#多语言模型的推理)
- [3. 基于SRN损失的识别模型推理](#基于SRN损失的识别模型推理)
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- [5. 多语言模型的推理](#多语言模型的推理)

- [四、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理)
Expand Down Expand Up @@ -295,16 +296,28 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
<a name="基于SRN损失的识别模型推理"></a>
### 3. 基于SRN损失的识别模型推理
基于SRN损失的识别模型,需要额外设置识别算法参数 --rec_algorithm="SRN"。
同时需要保证预测shape与训练时一致,如: --rec_image_shape="1, 64, 256"

### 3. 自定义文本识别字典的推理
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
--rec_model_dir="./inference/srn/" \
--rec_image_shape="1, 64, 256" \
--rec_char_type="en" \
--rec_algorithm="SRN"
```

### 4. 自定义文本识别字典的推理
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch`

```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
```

<a name="多语言模型的推理"></a>
### 4. 多语言模型的推理
### 5. 多语言模型的推理
如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:

Expand Down
2 changes: 2 additions & 0 deletions doc/doc_ch/recognition.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
* 数据下载

若您本地没有数据集,可以在官网下载 [icdar2015](http:https://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。

<a name="自定义数据集"></a>
* 使用自己数据集
Expand Down Expand Up @@ -200,6 +201,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |

训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:

Expand Down
Loading

0 comments on commit 56cbbdf

Please sign in to comment.