Skip to content

Commit

Permalink
add save_dir to args
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jun 3, 2021
1 parent 0bf30fe commit 2046605
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions ppstructure/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
def parse_args():
parser = utility.init_args()

# params for output
parser.add_argument("--table_output", type=str, default='output/table')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_max_text_length", type=int, default=100)
Expand Down Expand Up @@ -65,9 +67,9 @@ def __call__(self, img):
layout_res = self.table_layout(copy.deepcopy(img))
for region in layout_res:
x1, y1, x2, y2 = region['bbox']
roi_img = ori_im[y1:y2, x1:x2,:]
roi_img = ori_im[y1:y2, x1:x2, :]
if region['label'] == 'table':
res = self.table_system(roi_img)
res = self.text_system(roi_img)
else:
res = self.text_system(roi_img)
region['res'] = res
Expand All @@ -77,15 +79,15 @@ def __call__(self, img):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
excel_save_folder = 'output/table'
os.makedirs(excel_save_folder, exist_ok=True)
save_folder = args.table_output
os.makedirs(save_folder, exist_ok=True)

text_sys = OCRSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
imgname = os.path.basename(image_file).split('.')[0]
img_name = os.path.basename(image_file).split('.')[0]
# excel_path = os.path.join(excel_save_folder, + '.xlsx')
if not flag:
img = cv2.imread(image_file)
Expand All @@ -95,11 +97,17 @@ def main(args):
starttime = time.time()
res = text_sys(img)

excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['label'] == 'table':
# x1, y1, x2, y2 = region['bbox']
excel_path = os.path.join(excel_save_folder, '{}_{}.xlsx'.format(imgname,region['bbox']))
to_excel(region['res'],excel_path)
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
else:
with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
logger.info(res)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
Expand Down

0 comments on commit 2046605

Please sign in to comment.