Skip to content

Commit

Permalink
improve system prediction and remove some hard code (#4643)
Browse files Browse the repository at this point in the history
* fix center yaml

* rm init_center param

* fix typo

* improve pred system
  • Loading branch information
littletomatodonkey committed Nov 17, 2021
1 parent 1c2c269 commit 7dccfe5
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ Loss:
weight: 0.05
num_classes: 6625
feat_dim: 96
init_center: false
center_file_path: "./train_center.pkl"
center_file_path:
# you can also try to add ace loss on your own dataset
# - ACELoss:
# weight: 0.1
Expand Down
4 changes: 2 additions & 2 deletions doc/doc_ch/models_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训

|模型名称|模型简介|配置文件|推理模型大小|下载地址|
| --- | --- | --- | --- | --- |
|ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_det|【最新】原始超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)|
|ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_det|【最新】原始超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)|
|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)| 2.6M |[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)|
|ch_ppocr_mobile_v2.0_det|原始超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)|
|ch_ppocr_server_v2.0_det|通用模型,支持中英文、多语种文本检测,比超轻量模型更大,但效果更好|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)|
Expand Down
4 changes: 2 additions & 2 deletions doc/doc_en/models_list_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ Relationship of the above models is as follows.

|model name|description|config|model size|download|
| --- | --- | --- | --- | --- |
|ch_PP-OCRv2_det_slim|[New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_det|[New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)|
|ch_PP-OCRv2_det_slim|[New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_det|[New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)|
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|2.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)|
|ch_ppocr_mobile_v2.0_det|Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)|
|ch_ppocr_server_v2.0_det|General model, which is larger than the lightweight model, but achieved better performance|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)|
Expand Down
10 changes: 3 additions & 7 deletions ppocr/losses/center_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,17 @@ class CenterLoss(nn.Layer):
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""

def __init__(self,
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None):
def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
super().__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype("float64")

if init_center:
if center_file_path is not None:
assert os.path.exists(
center_file_path
), f"center path({center_file_path}) must exist when init_center is set as True."
), f"center path({center_file_path}) must exist when it is not None."
with open(center_file_path, 'rb') as f:
char_dict = pickle.load(f)
for key in char_dict.keys():
Expand Down
36 changes: 22 additions & 14 deletions tools/infer/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,19 @@ def __init__(self, args):
if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args)

def print_draw_crop_rec_res(self, img_crop_list, rec_res):
self.args = args
self.crop_image_res_index = 0

def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
os.makedirs(output_dir, exist_ok=True)
bbox_num = len(img_crop_list)
for bno in range(bbox_num):
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
logger.info(bno, rec_res[bno])
cv2.imwrite(
os.path.join(output_dir,
f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
img_crop_list[bno])
logger.debug(f"{bno}, {rec_res[bno]}")
self.crop_image_res_index += bbox_num

def __call__(self, img, cls=True):
ori_im = img.copy()
Expand All @@ -80,7 +88,9 @@ def __call__(self, img, cls=True):
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
if self.args.save_crop_res:
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt
Expand Down Expand Up @@ -135,17 +145,17 @@ def main(args):
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
logger.debug("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
total_time += elapse

logger.info(
logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
logger.info("{}, {:.3f}".format(text, score))
logger.debug("{}, {:.3f}".format(text, score))

if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
Expand All @@ -160,19 +170,17 @@ def main(args):
scores,
drop_score=drop_score,
font_path=font_path)
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
draw_img_save_dir = args.draw_img_save_dir
os.makedirs(draw_img_save_dir, exist_ok=True)
if flag:
image_file = image_file[:-3] + "png"
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
os.path.join(draw_img_save_dir, os.path.basename(image_file)),
draw_img[:, :, ::-1])
logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(image_file))))

logger.info("The predict total time is {}".format(time.time() - _st))
logger.info("\nThe predict total time is {}".format(total_time))
if args.benchmark:
text_sys.text_detector.autolog.report()
text_sys.text_recognizer.autolog.report()
Expand Down
8 changes: 7 additions & 1 deletion tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ def init_args():
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=True)
parser.add_argument("--warmup", type=str2bool, default=False)

#
parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results")
parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output")

# multi-process
parser.add_argument("--use_mp", type=str2bool, default=False)
Expand Down

0 comments on commit 7dccfe5

Please sign in to comment.