Skip to content

Commit

Permalink
add pad for small image in det
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jun 8, 2021
1 parent 48eba02 commit dec76eb
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 11 deletions.
12 changes: 10 additions & 2 deletions ppocr/data/imaug/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __call__(self, data):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
img.astype('float32') * self.scale - self.mean) / self.std
return data


Expand Down Expand Up @@ -122,6 +122,8 @@ def __init__(self, **kwargs):
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
self.pad = kwargs.get('pad', False)
self.pad_size = kwargs.get('pad_size', 480)
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
Expand Down Expand Up @@ -163,7 +165,7 @@ def resize_image_type0(self, img):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, _ = img.shape
h, w, c = img.shape

# limit the max side
if self.limit_type == 'max':
Expand All @@ -172,6 +174,8 @@ def resize_image_type0(self, img):
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
elif self.pad:
ratio = float(self.pad_size) / max(h, w)
else:
ratio = 1.
else:
Expand All @@ -197,6 +201,10 @@ def resize_image_type0(self, img):
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
if self.limit_type == 'max' and self.pad:
padding_im = np.zeros((self.pad_size, self.pad_size, c), dtype=np.float32)
padding_im[:resize_h, :resize_w, :] = img
img = padding_im
return img, [ratio_h, ratio_w]

def resize_image_type2(self, img):
Expand Down
11 changes: 5 additions & 6 deletions ppocr/postprocess/db_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ def __init__(self,
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])

def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
def boxes_from_bitmap(self, pred, _bitmap, shape):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''

dest_height, dest_width, ratio_h, ratio_w = shape
bitmap = _bitmap
height, width = bitmap.shape

Expand Down Expand Up @@ -89,9 +89,9 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
box = np.array(box)

box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
np.round(box[:, 0] / ratio_w), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
np.round(box[:, 1] / ratio_h), 0, dest_height)
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
Expand Down Expand Up @@ -175,15 +175,14 @@ def __call__(self, outs_dict, shape_list):

boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
else:
mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
shape_list[batch_index])

boxes_batch.append({'points': boxes})
return boxes_batch
5 changes: 3 additions & 2 deletions ppstructure/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@

class OCRSystem(object):
def __init__(self, args):
args.det_pad = True
args.det_pad_size = 640
self.text_system = TextSystem(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp:https://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu,thread_num=args.cpu_threads)
enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score

Expand All @@ -67,7 +69,6 @@ def __call__(self, img):
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list


def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
Expand Down
4 changes: 3 additions & 1 deletion tools/infer/predict_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def __init__(self, args):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
'limit_type': args.det_limit_type,
'pad':args.det_pad,
'pad_size':args.det_pad_size
}
}, {
'NormalizeImage': {
Expand Down
2 changes: 2 additions & 0 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def init_args():
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
parser.add_argument("--det_pad", type=str2bool, default=False)
parser.add_argument("--det_pad_size", type=int, default=640)

# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
Expand Down

0 comments on commit dec76eb

Please sign in to comment.