Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#7840 from LDOUBLEV/dygraph
Browse files Browse the repository at this point in the history
add polygon params
  • Loading branch information
LDOUBLEV committed Oct 12, 2022
2 parents 34174d4 + 3628ac1 commit 9df7730
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 18 deletions.
1 change: 1 addition & 0 deletions configs/det/det_r50_db++_icdar15.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ PostProcess:
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
det_box_type: 'quad' # 'quad' or 'poly'
Metric:
name: DetMetric
main_indicator: hmean
Expand Down
1 change: 1 addition & 0 deletions configs/det/det_r50_db++_td_tr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ PostProcess:
box_thresh: 0.5
max_candidates: 1000
unclip_ratio: 1.5
det_box_type: 'quad' # 'quad' or 'poly'
Metric:
name: DetMetric
main_indicator: hmean
Expand Down
14 changes: 8 additions & 6 deletions ppocr/postprocess/db_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def __init__(self,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
use_polygon=False,
box_type='quad',
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
self.use_polygon = use_polygon
self.box_type = box_type
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
Expand Down Expand Up @@ -233,12 +233,14 @@ def __call__(self, outs_dict, shape_list):
self.dilation_kernel)
else:
mask = segmentation[batch_index]
if self.use_polygon is True:
if self.box_type == 'poly':
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
else:
elif self.box_type == 'quad':
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
else:
raise ValueError("box_type can only be one of ['quad', 'poly']")

boxes_batch.append({'points': boxes})
return boxes_batch
Expand All @@ -254,7 +256,7 @@ def __init__(self,
unclip_ratio=1.5,
use_dilation=False,
score_mode="fast",
use_polygon=False,
box_type='quad',
**kwargs):
self.model_name = model_name
self.key = key
Expand All @@ -265,7 +267,7 @@ def __init__(self,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
score_mode=score_mode,
use_polygon=use_polygon)
box_type=box_type)

def __call__(self, predicts, shape_list):
results = {}
Expand Down
22 changes: 13 additions & 9 deletions tools/infer/predict_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, args):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
Expand All @@ -75,6 +76,7 @@ def __init__(self, args):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
pre_process_list[1] = {
'NormalizeImage': {
'std': [1.0, 1.0, 1.0],
Expand All @@ -98,23 +100,23 @@ def __init__(self, args):
postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
self.det_sast_polygon = args.det_sast_polygon
if self.det_sast_polygon:

if args.det_box_type == 'poly':
postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
else:
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3

elif self.det_algorithm == "PSE":
postprocess_params['name'] = 'PSEPostProcess'
postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["box_type"] = args.det_box_type
postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
elif self.det_algorithm == "FCE":
pre_process_list[0] = {
'DetResizeForTest': {
Expand All @@ -126,7 +128,7 @@ def __init__(self, args):
postprocess_params["alpha"] = args.alpha
postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_fce_box_type
postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "CT":
pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
postprocess_params['name'] = 'CTPostProcess'
Expand Down Expand Up @@ -190,6 +192,8 @@ def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
Expand All @@ -204,6 +208,8 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
Expand Down Expand Up @@ -262,12 +268,10 @@ def __call__(self, img):
else:
raise NotImplementedError

#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
self.det_algorithm in ["PSE", "FCE", "CT"] and
self.postprocess_op.box_type == 'poly'):

if self.args.det_box_type == 'poly':
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
Expand Down
5 changes: 2 additions & 3 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def init_args():
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
parser.add_argument("--det_box_type", type=str, default='quad')

# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
Expand All @@ -58,6 +59,7 @@ def init_args():
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")

# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
Expand All @@ -66,21 +68,18 @@ def init_args():
# SAST parmas
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

# PSE parmas
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='quad')
parser.add_argument("--det_pse_scale", type=int, default=1)

# FCE parmas
parser.add_argument("--scales", type=list, default=[8, 16, 32])
parser.add_argument("--alpha", type=float, default=1.0)
parser.add_argument("--beta", type=float, default=1.0)
parser.add_argument("--fourier_degree", type=int, default=5)
parser.add_argument("--det_fce_box_type", type=str, default='poly')

# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
Expand Down

0 comments on commit 9df7730

Please sign in to comment.