Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/MissPenguin/PaddleOCR in…
Browse files Browse the repository at this point in the history
…to develop
  • Loading branch information
dyning committed Jun 23, 2020
2 parents 160bb06 + b0171a7 commit 2401626
Show file tree
Hide file tree
Showing 33 changed files with 17,080 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ppocr/data/det/db_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ class DBProcessTest(object):
def __init__(self, params):
super(DBProcessTest, self).__init__()
self.resize_type = 0
if 'det_image_shape' in params:
self.image_shape = params['det_image_shape']
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
Expand Down
35 changes: 29 additions & 6 deletions ppocr/data/det/east_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,17 +455,23 @@ def __call__(self, label_infor):
class EASTProcessTest(object):
def __init__(self, params):
super(EASTProcessTest, self).__init__()
self.resize_type = 0
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
self.max_side_len = params['max_side_len']
else:
self.max_side_len = 2400

def resize_image(self, im):
def resize_image_type0(self, im):
"""
resize image to a size multiple of 32 which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
max_side_len = self.max_side_len
h, w, _ = im.shape
Expand Down Expand Up @@ -495,13 +501,30 @@ def resize_image(self, im):
resize_w = 32
else:
resize_w = (resize_w // 32 - 1) * 32
im = cv2.resize(im, (int(resize_w), int(resize_h)))
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
except:
print(im.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)

def resize_image_type1(self, im):
resize_h, resize_w = self.image_shape
ori_h, ori_w = im.shape[:2] # (h, w, c)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
return im, (ratio_h, ratio_w)

def __call__(self, im):
im, (ratio_h, ratio_w) = self.resize_image(im)
if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im)
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im[:, :, ::-1].astype(np.float32)
Expand Down
3 changes: 2 additions & 1 deletion ppocr/modeling/heads/det_east_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import paddle.fluid as fluid
from ..common_functions import conv_bn_layer, deconv_bn_layer
from collections import OrderedDict


class EASTHead(object):
Expand Down Expand Up @@ -110,7 +111,7 @@ def detector_header(self, f_common):
def __call__(self, inputs):
f_common = self.unet_fusion(inputs)
f_score, f_geo = self.detector_header(f_common)
predicts = {}
predicts = OrderedDict()
predicts['f_score'] = f_score
predicts['f_geo'] = f_geo
return predicts
10 changes: 9 additions & 1 deletion ppocr/postprocess/east_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from .locality_aware_nms import nms_locality
import cv2

import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import lanms


class EASTPostPocess(object):
"""
Expand Down Expand Up @@ -66,7 +73,8 @@ def detect(self,
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = text_box_restored.reshape((-1, 8))
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
# boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
if boxes.shape[0] == 0:
return []
# Here we filter some low score boxes by the average score map,
Expand Down
1 change: 1 addition & 0 deletions ppocr/postprocess/lanms/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
adaptor.so
140 changes: 140 additions & 0 deletions ppocr/postprocess/lanms/.ycm_extra_conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#!/usr/bin/env python
#
# Copyright (C) 2014 Google Inc.
#
# This file is part of YouCompleteMe.
#
# YouCompleteMe is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# YouCompleteMe is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with YouCompleteMe. If not, see <http:https://www.gnu.org/licenses/>.

import os
import sys
import glob
import ycm_core

# These are the compilation flags that will be used in case there's no
# compilation database set (by default, one is not set).
# CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR.
sys.path.append(os.path.dirname(__file__))


BASE_DIR = os.path.dirname(os.path.realpath(__file__))

from plumbum.cmd import python_config


flags = [
'-Wall',
'-Wextra',
'-Wnon-virtual-dtor',
'-Winvalid-pch',
'-Wno-unused-local-typedefs',
'-std=c++11',
'-x', 'c++',
'-Iinclude',
] + python_config('--cflags').split()


# Set this to the absolute path to the folder (NOT the file!) containing the
# compile_commands.json file to use that instead of 'flags'. See here for
# more details: http:https://clang.llvm.org/docs/JSONCompilationDatabase.html
#
# Most projects will NOT need to set this to anything; you can just change the
# 'flags' list of compilation flags.
compilation_database_folder = ''

if os.path.exists( compilation_database_folder ):
database = ycm_core.CompilationDatabase( compilation_database_folder )
else:
database = None

SOURCE_EXTENSIONS = [ '.cpp', '.cxx', '.cc', '.c', '.m', '.mm' ]

def DirectoryOfThisScript():
return os.path.dirname( os.path.abspath( __file__ ) )


def MakeRelativePathsInFlagsAbsolute( flags, working_directory ):
if not working_directory:
return list( flags )
new_flags = []
make_next_absolute = False
path_flags = [ '-isystem', '-I', '-iquote', '--sysroot=' ]
for flag in flags:
new_flag = flag

if make_next_absolute:
make_next_absolute = False
if not flag.startswith( '/' ):
new_flag = os.path.join( working_directory, flag )

for path_flag in path_flags:
if flag == path_flag:
make_next_absolute = True
break

if flag.startswith( path_flag ):
path = flag[ len( path_flag ): ]
new_flag = path_flag + os.path.join( working_directory, path )
break

if new_flag:
new_flags.append( new_flag )
return new_flags


def IsHeaderFile( filename ):
extension = os.path.splitext( filename )[ 1 ]
return extension in [ '.h', '.hxx', '.hpp', '.hh' ]


def GetCompilationInfoForFile( filename ):
# The compilation_commands.json file generated by CMake does not have entries
# for header files. So we do our best by asking the db for flags for a
# corresponding source file, if any. If one exists, the flags for that file
# should be good enough.
if IsHeaderFile( filename ):
basename = os.path.splitext( filename )[ 0 ]
for extension in SOURCE_EXTENSIONS:
replacement_file = basename + extension
if os.path.exists( replacement_file ):
compilation_info = database.GetCompilationInfoForFile(
replacement_file )
if compilation_info.compiler_flags_:
return compilation_info
return None
return database.GetCompilationInfoForFile( filename )


# This is the entry point; this function is called by ycmd to produce flags for
# a file.
def FlagsForFile( filename, **kwargs ):
if database:
# Bear in mind that compilation_info.compiler_flags_ does NOT return a
# python list, but a "list-like" StringVec object
compilation_info = GetCompilationInfoForFile( filename )
if not compilation_info:
return None

final_flags = MakeRelativePathsInFlagsAbsolute(
compilation_info.compiler_flags_,
compilation_info.compiler_working_dir_ )
else:
relative_to = DirectoryOfThisScript()
final_flags = MakeRelativePathsInFlagsAbsolute( flags, relative_to )

return {
'flags': final_flags,
'do_cache': True
}

13 changes: 13 additions & 0 deletions ppocr/postprocess/lanms/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags)
LDFLAGS = $(shell python3-config --ldflags)

DEPS = lanms.h $(shell find include -xtype f)
CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp

LIB_SO = adaptor.so

$(LIB_SO): $(CXX_SOURCES) $(DEPS)
$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC

clean:
rm -rf $(LIB_SO)
20 changes: 20 additions & 0 deletions ppocr/postprocess/lanms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import subprocess
import os
import numpy as np

BASE_DIR = os.path.dirname(os.path.realpath(__file__))

if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value
raise RuntimeError('Cannot compile lanms: {}'.format(BASE_DIR))


def merge_quadrangle_n9(polys, thres=0.3, precision=10000):
from .adaptor import merge_quadrangle_n9 as nms_impl
if len(polys) == 0:
return np.array([], dtype='float32')
p = polys.copy()
p[:,:8] *= precision
ret = np.array(nms_impl(p, thres), dtype='float32')
ret[:,:8] /= precision
return ret

10 changes: 10 additions & 0 deletions ppocr/postprocess/lanms/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import numpy as np


from . import merge_quadrangle_n9

if __name__ == '__main__':
# unit square with confidence 1
q = np.array([0, 0, 0, 1, 1, 1, 1, 0, 1], dtype='float32')

print(merge_quadrangle_n9(np.array([q, q + 0.1, q + 2])))
61 changes: 61 additions & 0 deletions ppocr/postprocess/lanms/adaptor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"

#include "lanms.h"

namespace py = pybind11;


namespace lanms_adaptor {

std::vector<std::vector<float>> polys2floats(const std::vector<lanms::Polygon> &polys) {
std::vector<std::vector<float>> ret;
for (size_t i = 0; i < polys.size(); i ++) {
auto &p = polys[i];
auto &poly = p.poly;
ret.emplace_back(std::vector<float>{
float(poly[0].X), float(poly[0].Y),
float(poly[1].X), float(poly[1].Y),
float(poly[2].X), float(poly[2].Y),
float(poly[3].X), float(poly[3].Y),
float(p.score),
});
}

return ret;
}


/**
*
* \param quad_n9 an n-by-9 numpy array, where first 8 numbers denote the
* quadrangle, and the last one is the score
* \param iou_threshold two quadrangles with iou score above this threshold
* will be merged
*
* \return an n-by-9 numpy array, the merged quadrangles
*/
std::vector<std::vector<float>> merge_quadrangle_n9(
py::array_t<float, py::array::c_style | py::array::forcecast> quad_n9,
float iou_threshold) {
auto pbuf = quad_n9.request();
if (pbuf.ndim != 2 || pbuf.shape[1] != 9)
throw std::runtime_error("quadrangles must have a shape of (n, 9)");
auto n = pbuf.shape[0];
auto ptr = static_cast<float *>(pbuf.ptr);
return polys2floats(lanms::merge_quadrangle_n9(ptr, n, iou_threshold));
}

}

PYBIND11_PLUGIN(adaptor) {
py::module m("adaptor", "NMS");

m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9,
"merge quadrangels");

return m.ptr();
}

Loading

0 comments on commit 2401626

Please sign in to comment.