diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml new file mode 100644 index 0000000000..635c392d70 --- /dev/null +++ b/configs/rec/rec_mtb_nrtr.yml @@ -0,0 +1,102 @@ +Global: + use_gpu: True + epoch_num: 21 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/nrtr/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: EN_symbol + max_text_length: 25 + infer_mode: False + use_space_char: True + save_res_path: ./output/rec/predicts_nrtr.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.99 + clip_norm: 5.0 + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0. + +Architecture: + model_type: rec + algorithm: NRTR + in_channels: 1 + Transform: + Backbone: + name: MTB + cnn_num: 2 + Head: + name: Transformer + d_model: 512 + num_encoder_layers: 6 + beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. + + +Loss: + name: NRTRLoss + smoothing: True + +PostProcess: + name: NRTRLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - NRTRLabelEncode: # Class handling label + - NRTRRecResizeImg: + image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 512 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation/ + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - NRTRLabelEncode: # Class handling label + - NRTRRecResizeImg: + image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 1 + use_shared_memory: False diff --git a/deploy/cpp_infer/CMakeLists.txt b/deploy/cpp_infer/CMakeLists.txt index efb183c5b4..29a506846d 100644 --- a/deploy/cpp_infer/CMakeLists.txt +++ b/deploy/cpp_infer/CMakeLists.txt @@ -206,6 +206,10 @@ endif() set(DEPS ${DEPS} ${OpenCV_LIBS}) +include(ExternalProject) +include(external-cmake/auto-log.cmake) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/autolog/src/extern_Autolog/auto_log) + AUX_SOURCE_DIRECTORY(./src SRCS) add_executable(${DEMO_NAME} ${SRCS}) diff --git a/deploy/cpp_infer/external-cmake/auto-log.cmake b/deploy/cpp_infer/external-cmake/auto-log.cmake new file mode 100644 index 0000000000..dfa56188e8 --- /dev/null +++ b/deploy/cpp_infer/external-cmake/auto-log.cmake @@ -0,0 +1,14 @@ +find_package(Git REQUIRED) +message("${CMAKE_BUILD_TYPE}") + +set(AUTOLOG_REPOSITORY https://github.com/LDOUBLEV/AutoLog.git) +SET(AUTOLOG_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/install/Autolog) + +ExternalProject_Add( + extern_Autolog + PREFIX autolog + GIT_REPOSITORY ${AUTOLOG_REPOSITORY} + GIT_TAG main + DOWNLOAD_NO_EXTRACT True + INSTALL_COMMAND cmake -E echo "Skipping install step." +) diff --git a/deploy/cpp_infer/include/ocr_cls.h b/deploy/cpp_infer/include/ocr_cls.h index a43c800534..742e1f8bb0 100644 --- a/deploy/cpp_infer/include/ocr_cls.h +++ b/deploy/cpp_infer/include/ocr_cls.h @@ -42,7 +42,7 @@ class Classifier { const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, const bool &use_mkldnn, const double &cls_thresh, - const bool &use_tensorrt, const bool &use_fp16) { + const bool &use_tensorrt, const std::string &precision) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -51,7 +51,7 @@ class Classifier { this->cls_thresh = cls_thresh; this->use_tensorrt_ = use_tensorrt; - this->use_fp16_ = use_fp16; + this->precision_ = precision; LoadModel(model_dir); } @@ -75,7 +75,7 @@ class Classifier { std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; bool is_scale_ = true; bool use_tensorrt_ = false; - bool use_fp16_ = false; + std::string precision_ = "fp32"; // pre-process ClsResizeImg resize_op_; Normalize normalize_op_; diff --git a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h index 18318c9c4e..e5a31ed8e5 100644 --- a/deploy/cpp_infer/include/ocr_det.h +++ b/deploy/cpp_infer/include/ocr_det.h @@ -46,7 +46,7 @@ class DBDetector { const double &det_db_box_thresh, const double &det_db_unclip_ratio, const bool &use_polygon_score, const bool &visualize, - const bool &use_tensorrt, const bool &use_fp16) { + const bool &use_tensorrt, const std::string &precision) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -62,7 +62,7 @@ class DBDetector { this->visualize_ = visualize; this->use_tensorrt_ = use_tensorrt; - this->use_fp16_ = use_fp16; + this->precision_ = precision; LoadModel(model_dir); } @@ -71,7 +71,7 @@ class DBDetector { void LoadModel(const std::string &model_dir); // Run predictor - void Run(cv::Mat &img, std::vector>> &boxes); + void Run(cv::Mat &img, std::vector>> &boxes, std::vector *times); private: std::shared_ptr predictor_; @@ -91,7 +91,7 @@ class DBDetector { bool visualize_ = true; bool use_tensorrt_ = false; - bool use_fp16_ = false; + std::string precision_ = "fp32"; std::vector mean_ = {0.485f, 0.456f, 0.406f}; std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index 25f55ae26a..d585112b05 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -44,14 +44,14 @@ class CRNNRecognizer { const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, const bool &use_mkldnn, const string &label_path, - const bool &use_tensorrt, const bool &use_fp16) { + const bool &use_tensorrt, const std::string &precision) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; this->use_mkldnn_ = use_mkldnn; this->use_tensorrt_ = use_tensorrt; - this->use_fp16_ = use_fp16; + this->precision_ = precision; this->label_list_ = Utility::ReadDict(label_path); this->label_list_.insert(this->label_list_.begin(), @@ -64,7 +64,7 @@ class CRNNRecognizer { // Load Paddle inference model void LoadModel(const std::string &model_dir); - void Run(cv::Mat &img); + void Run(cv::Mat &img, std::vector *times); private: std::shared_ptr predictor_; @@ -81,7 +81,7 @@ class CRNNRecognizer { std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; bool is_scale_ = true; bool use_tensorrt_ = false; - bool use_fp16_ = false; + std::string precision_ = "fp32"; // pre-process CrnnResizeImg resize_op_; Normalize normalize_op_; @@ -90,9 +90,6 @@ class CRNNRecognizer { // post-process PostProcessor post_processor_; - cv::Mat GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box); - }; // class CrnnRecognizer } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h index 6e8173e007..678187d3fa 100644 --- a/deploy/cpp_infer/include/utility.h +++ b/deploy/cpp_infer/include/utility.h @@ -47,6 +47,9 @@ class Utility { static void GetAllFiles(const char *dir_name, std::vector &all_inputs); + + static cv::Mat GetRotateCropImage(const cv::Mat &srcimage, + std::vector> box); }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 830f032d46..382be79706 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -38,10 +39,12 @@ DEFINE_bool(use_gpu, false, "Infering with GPU or CPU."); DEFINE_int32(gpu_id, 0, "Device id of GPU to execute."); DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU."); -DEFINE_int32(cpu_math_library_num_threads, 10, "Num of threads with CPU."); -DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU."); +DEFINE_int32(cpu_threads, 10, "Num of threads with CPU."); +DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU."); DEFINE_bool(use_tensorrt, false, "Whether use tensorrt."); -DEFINE_bool(use_fp16, false, "Whether use fp16 when use tensorrt."); +DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8"); +DEFINE_bool(benchmark, true, "Whether use benchmark."); +DEFINE_string(save_log_path, "./log_output/", "Save benchmark log path."); // detection related DEFINE_string(image_dir, "", "Dir of input image."); DEFINE_string(det_model_dir, "", "Path of det inference model."); @@ -57,6 +60,7 @@ DEFINE_string(cls_model_dir, "", "Path of cls inference model."); DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh."); // recognition related DEFINE_string(rec_model_dir, "", "Path of rec inference model."); +DEFINE_int32(rec_batch_num, 1, "rec_batch_num."); DEFINE_string(char_list_file, "../../ppocr/utils/ppocr_keys_v1.txt", "Path of dictionary."); @@ -76,88 +80,15 @@ static bool PathExists(const std::string& path){ } -cv::Mat GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box) { - cv::Mat image; - srcimage.copyTo(image); - std::vector> points = box; - - int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; - int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; - int left = int(*std::min_element(x_collect, x_collect + 4)); - int right = int(*std::max_element(x_collect, x_collect + 4)); - int top = int(*std::min_element(y_collect, y_collect + 4)); - int bottom = int(*std::max_element(y_collect, y_collect + 4)); - - cv::Mat img_crop; - image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); - - for (int i = 0; i < points.size(); i++) { - points[i][0] -= left; - points[i][1] -= top; - } - - int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) + - pow(points[0][1] - points[1][1], 2))); - int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) + - pow(points[0][1] - points[3][1], 2))); - - cv::Point2f pts_std[4]; - pts_std[0] = cv::Point2f(0., 0.); - pts_std[1] = cv::Point2f(img_crop_width, 0.); - pts_std[2] = cv::Point2f(img_crop_width, img_crop_height); - pts_std[3] = cv::Point2f(0.f, img_crop_height); - - cv::Point2f pointsf[4]; - pointsf[0] = cv::Point2f(points[0][0], points[0][1]); - pointsf[1] = cv::Point2f(points[1][0], points[1][1]); - pointsf[2] = cv::Point2f(points[2][0], points[2][1]); - pointsf[3] = cv::Point2f(points[3][0], points[3][1]); - - cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); - - cv::Mat dst_img; - cv::warpPerspective(img_crop, dst_img, M, - cv::Size(img_crop_width, img_crop_height), - cv::BORDER_REPLICATE); - - if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) { - cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth()); - cv::transpose(dst_img, srcCopy); - cv::flip(srcCopy, srcCopy, 0); - return srcCopy; - } else { - return dst_img; - } -} - - -int main_det(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) { - std::cout << "Usage[det]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - exit(1); - } - if (!PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +int main_det(std::vector cv_all_img_names) { + std::vector time_info = {0, 0, 0}; DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, - FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, - FLAGS_use_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh, + FLAGS_gpu_mem, FLAGS_cpu_threads, + FLAGS_enable_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio, FLAGS_use_polygon_score, FLAGS_visualize, - FLAGS_use_tensorrt, FLAGS_use_fp16); - - auto start = std::chrono::system_clock::now(); - + FLAGS_use_tensorrt, FLAGS_precision); + for (int i = 0; i < cv_all_img_names.size(); ++i) { LOG(INFO) << "The predict img: " << cv_all_img_names[i]; @@ -167,46 +98,38 @@ int main_det(int argc, char **argv) { exit(1); } std::vector>> boxes; + std::vector det_times; - det.Run(srcimg, boxes); - - auto end = std::chrono::system_clock::now(); - auto duration = - std::chrono::duration_cast(end - start); - std::cout << "Cost " - << double(duration.count()) * - std::chrono::microseconds::period::num / - std::chrono::microseconds::period::den - << "s" << std::endl; + det.Run(srcimg, boxes, &det_times); + + time_info[0] += det_times[0]; + time_info[1] += det_times[1]; + time_info[2] += det_times[2]; } + if (FLAGS_benchmark) { + AutoLogger autolog("ocr_det", + FLAGS_use_gpu, + FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, + FLAGS_cpu_threads, + 1, + "dynamic", + FLAGS_precision, + time_info, + cv_all_img_names.size()); + autolog.report(); + } return 0; } -int main_rec(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { - std::cout << "Usage[rec]: ./ppocr --rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - exit(1); - } - if (!PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +int main_rec(std::vector cv_all_img_names) { + std::vector time_info = {0, 0, 0}; CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, - FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, - FLAGS_use_mkldnn, FLAGS_char_list_file, - FLAGS_use_tensorrt, FLAGS_use_fp16); - - auto start = std::chrono::system_clock::now(); + FLAGS_gpu_mem, FLAGS_cpu_threads, + FLAGS_enable_mkldnn, FLAGS_char_list_file, + FLAGS_use_tensorrt, FLAGS_precision); for (int i = 0; i < cv_all_img_names.size(); ++i) { LOG(INFO) << "The predict img: " << cv_all_img_names[i]; @@ -217,65 +140,38 @@ int main_rec(int argc, char **argv) { exit(1); } - rec.Run(srcimg); + std::vector rec_times; + rec.Run(srcimg, &rec_times); - auto end = std::chrono::system_clock::now(); - auto duration = - std::chrono::duration_cast(end - start); - std::cout << "Cost " - << double(duration.count()) * - std::chrono::microseconds::period::num / - std::chrono::microseconds::period::den - << "s" << std::endl; + time_info[0] += rec_times[0]; + time_info[1] += rec_times[1]; + time_info[2] += rec_times[2]; } return 0; } -int main_system(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) || - (FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) { - std::cout << "Usage[system without angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " - << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - std::cout << "Usage[system with angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " - << "--use_angle_cls=true " - << "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ " - << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - exit(1); - } - if (!PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +int main_system(std::vector cv_all_img_names) { DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, - FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, - FLAGS_use_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh, + FLAGS_gpu_mem, FLAGS_cpu_threads, + FLAGS_enable_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio, FLAGS_use_polygon_score, FLAGS_visualize, - FLAGS_use_tensorrt, FLAGS_use_fp16); + FLAGS_use_tensorrt, FLAGS_precision); Classifier *cls = nullptr; if (FLAGS_use_angle_cls) { cls = new Classifier(FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, - FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, - FLAGS_use_mkldnn, FLAGS_cls_thresh, - FLAGS_use_tensorrt, FLAGS_use_fp16); + FLAGS_gpu_mem, FLAGS_cpu_threads, + FLAGS_enable_mkldnn, FLAGS_cls_thresh, + FLAGS_use_tensorrt, FLAGS_precision); } CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, - FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, - FLAGS_use_mkldnn, FLAGS_char_list_file, - FLAGS_use_tensorrt, FLAGS_use_fp16); + FLAGS_gpu_mem, FLAGS_cpu_threads, + FLAGS_enable_mkldnn, FLAGS_char_list_file, + FLAGS_use_tensorrt, FLAGS_precision); auto start = std::chrono::system_clock::now(); @@ -288,17 +184,19 @@ int main_system(int argc, char **argv) { exit(1); } std::vector>> boxes; - - det.Run(srcimg, boxes); + std::vector det_times; + std::vector rec_times; + + det.Run(srcimg, boxes, &det_times); cv::Mat crop_img; for (int j = 0; j < boxes.size(); j++) { - crop_img = GetRotateCropImage(srcimg, boxes[j]); + crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]); if (cls != nullptr) { crop_img = cls->Run(crop_img); } - rec.Run(crop_img); + rec.Run(crop_img, &rec_times); } auto end = std::chrono::system_clock::now(); @@ -315,22 +213,70 @@ int main_system(int argc, char **argv) { } +void check_params(char* mode) { + if (strcmp(mode, "det")==0) { + if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[det]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } + if (strcmp(mode, "rec")==0) { + if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[rec]: ./ppocr --rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } + if (strcmp(mode, "system")==0) { + if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) || + (FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) { + std::cout << "Usage[system without angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + std::cout << "Usage[system with angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--use_angle_cls=true " + << "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ " + << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } + if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") { + cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; + exit(1); + } +} + + int main(int argc, char **argv) { - if (strcmp(argv[1], "det")!=0 && strcmp(argv[1], "rec")!=0 && strcmp(argv[1], "system")!=0) { - std::cout << "Please choose one mode of [det, rec, system] !" << std::endl; - return -1; - } - std::cout << "mode: " << argv[1] << endl; - - if (strcmp(argv[1], "det")==0) { - return main_det(argc, argv); - } - if (strcmp(argv[1], "rec")==0) { - return main_rec(argc, argv); - } - if (strcmp(argv[1], "system")==0) { - return main_system(argc, argv); - } + if (argc<=1 || (strcmp(argv[1], "det")!=0 && strcmp(argv[1], "rec")!=0 && strcmp(argv[1], "system")!=0)) { + std::cout << "Please choose one mode of [det, rec, system] !" << std::endl; + return -1; + } + std::cout << "mode: " << argv[1] << endl; + + // Parsing command-line + google::ParseCommandLineFlags(&argc, &argv, true); + check_params(argv[1]); + + if (!PathExists(FLAGS_image_dir)) { + std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; + exit(1); + } -// return 0; + std::vector cv_all_img_names; + cv::glob(FLAGS_image_dir, cv_all_img_names); + std::cout << "total images num: " << cv_all_img_names.size() << endl; + + if (strcmp(argv[1], "det")==0) { + return main_det(cv_all_img_names); + } + if (strcmp(argv[1], "rec")==0) { + return main_rec(cv_all_img_names); + } + if (strcmp(argv[1], "system")==0) { + return main_system(cv_all_img_names); + } + } diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp index 9199e082e5..3b04b6f824 100644 --- a/deploy/cpp_infer/src/ocr_cls.cpp +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -77,10 +77,16 @@ void Classifier::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } config.EnableTensorRtEngine( 1 << 20, 10, 3, - this->use_fp16_ ? paddle_infer::Config::Precision::kHalf - : paddle_infer::Config::Precision::kFloat32, + precision, false, false); } } else { diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 58dc4dce81..a69f5ca1bd 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -26,10 +26,16 @@ void DBDetector::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } config.EnableTensorRtEngine( 1 << 20, 10, 3, - this->use_fp16_ ? paddle_infer::Config::Precision::kHalf - : paddle_infer::Config::Precision::kFloat32, + precision, false, false); std::map> min_input_shape = { {"x", {1, 3, 50, 50}}, @@ -91,13 +97,16 @@ void DBDetector::LoadModel(const std::string &model_dir) { } void DBDetector::Run(cv::Mat &img, - std::vector>> &boxes) { + std::vector>> &boxes, + std::vector *times) { float ratio_h{}; float ratio_w{}; cv::Mat srcimg; cv::Mat resize_img; img.copyTo(srcimg); + + auto preprocess_start = std::chrono::steady_clock::now(); this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w, this->use_tensorrt_); @@ -106,14 +115,17 @@ void DBDetector::Run(cv::Mat &img, std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); this->permute_op_.Run(&resize_img, input.data()); - + auto preprocess_end = std::chrono::steady_clock::now(); + // Inference. auto input_names = this->predictor_->GetInputNames(); auto input_t = this->predictor_->GetInputHandle(input_names[0]); input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + auto inference_start = std::chrono::steady_clock::now(); input_t->CopyFromCpu(input.data()); + this->predictor_->Run(); - + std::vector out_data; auto output_names = this->predictor_->GetOutputNames(); auto output_t = this->predictor_->GetOutputHandle(output_names[0]); @@ -123,7 +135,9 @@ void DBDetector::Run(cv::Mat &img, out_data.resize(out_num); output_t->CopyToCpu(out_data.data()); - + auto inference_end = std::chrono::steady_clock::now(); + + auto postprocess_start = std::chrono::steady_clock::now(); int n2 = output_shape[2]; int n3 = output_shape[3]; int n = n2 * n3; @@ -151,7 +165,15 @@ void DBDetector::Run(cv::Mat &img, this->det_db_unclip_ratio_, this->use_polygon_score_); boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg); + auto postprocess_end = std::chrono::steady_clock::now(); std::cout << "Detected boxes num: " << boxes.size() << endl; + + std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; + times->push_back(double(preprocess_diff.count() * 1000)); + std::chrono::duration inference_diff = inference_end - inference_start; + times->push_back(double(inference_diff.count() * 1000)); + std::chrono::duration postprocess_diff = postprocess_end - postprocess_start; + times->push_back(double(postprocess_diff.count() * 1000)); //// visualization if (this->visualize_) { diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index c4a784f82c..b64dcea5ae 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -16,13 +16,13 @@ namespace PaddleOCR { -void CRNNRecognizer::Run(cv::Mat &img) { +void CRNNRecognizer::Run(cv::Mat &img, std::vector *times) { cv::Mat srcimg; img.copyTo(srcimg); cv::Mat resize_img; float wh_ratio = float(srcimg.cols) / float(srcimg.rows); - + auto preprocess_start = std::chrono::steady_clock::now(); this->resize_op_.Run(srcimg, resize_img, wh_ratio, this->use_tensorrt_); this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, @@ -31,11 +31,13 @@ void CRNNRecognizer::Run(cv::Mat &img) { std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); this->permute_op_.Run(&resize_img, input.data()); + auto preprocess_end = std::chrono::steady_clock::now(); // Inference. auto input_names = this->predictor_->GetInputNames(); auto input_t = this->predictor_->GetInputHandle(input_names[0]); input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + auto inference_start = std::chrono::steady_clock::now(); input_t->CopyFromCpu(input.data()); this->predictor_->Run(); @@ -49,8 +51,10 @@ void CRNNRecognizer::Run(cv::Mat &img) { predict_batch.resize(out_num); output_t->CopyToCpu(predict_batch.data()); + auto inference_end = std::chrono::steady_clock::now(); // ctc decode + auto postprocess_start = std::chrono::steady_clock::now(); std::vector str_res; int argmax_idx; int last_index = 0; @@ -73,11 +77,19 @@ void CRNNRecognizer::Run(cv::Mat &img) { } last_index = argmax_idx; } + auto postprocess_end = std::chrono::steady_clock::now(); score /= count; for (int i = 0; i < str_res.size(); i++) { std::cout << str_res[i]; } std::cout << "\tscore: " << score << std::endl; + + std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; + times->push_back(double(preprocess_diff.count() * 1000)); + std::chrono::duration inference_diff = inference_end - inference_start; + times->push_back(double(inference_diff.count() * 1000)); + std::chrono::duration postprocess_diff = postprocess_end - postprocess_start; + times->push_back(double(postprocess_diff.count() * 1000)); } void CRNNRecognizer::LoadModel(const std::string &model_dir) { @@ -89,10 +101,16 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } config.EnableTensorRtEngine( 1 << 20, 10, 3, - this->use_fp16_ ? paddle_infer::Config::Precision::kHalf - : paddle_infer::Config::Precision::kFloat32, + precision, false, false); std::map> min_input_shape = { {"x", {1, 3, 32, 10}}}; @@ -126,59 +144,4 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { this->predictor_ = CreatePredictor(config); } -cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box) { - cv::Mat image; - srcimage.copyTo(image); - std::vector> points = box; - - int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; - int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; - int left = int(*std::min_element(x_collect, x_collect + 4)); - int right = int(*std::max_element(x_collect, x_collect + 4)); - int top = int(*std::min_element(y_collect, y_collect + 4)); - int bottom = int(*std::max_element(y_collect, y_collect + 4)); - - cv::Mat img_crop; - image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); - - for (int i = 0; i < points.size(); i++) { - points[i][0] -= left; - points[i][1] -= top; - } - - int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) + - pow(points[0][1] - points[1][1], 2))); - int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) + - pow(points[0][1] - points[3][1], 2))); - - cv::Point2f pts_std[4]; - pts_std[0] = cv::Point2f(0., 0.); - pts_std[1] = cv::Point2f(img_crop_width, 0.); - pts_std[2] = cv::Point2f(img_crop_width, img_crop_height); - pts_std[3] = cv::Point2f(0.f, img_crop_height); - - cv::Point2f pointsf[4]; - pointsf[0] = cv::Point2f(points[0][0], points[0][1]); - pointsf[1] = cv::Point2f(points[1][0], points[1][1]); - pointsf[2] = cv::Point2f(points[2][0], points[2][1]); - pointsf[3] = cv::Point2f(points[3][0], points[3][1]); - - cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); - - cv::Mat dst_img; - cv::warpPerspective(img_crop, dst_img, M, - cv::Size(img_crop_width, img_crop_height), - cv::BORDER_REPLICATE); - - if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) { - cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth()); - cv::transpose(dst_img, srcCopy); - cv::flip(srcCopy, srcCopy, 0); - return srcCopy; - } else { - return dst_img; - } -} - } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index 2cd84f7e8d..dba445b747 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -92,4 +92,59 @@ void Utility::GetAllFiles(const char *dir_name, } } +cv::Mat Utility::GetRotateCropImage(const cv::Mat &srcimage, + std::vector> box) { + cv::Mat image; + srcimage.copyTo(image); + std::vector> points = box; + + int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; + int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + + cv::Mat img_crop; + image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); + + for (int i = 0; i < points.size(); i++) { + points[i][0] -= left; + points[i][1] -= top; + } + + int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) + + pow(points[0][1] - points[1][1], 2))); + int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) + + pow(points[0][1] - points[3][1], 2))); + + cv::Point2f pts_std[4]; + pts_std[0] = cv::Point2f(0., 0.); + pts_std[1] = cv::Point2f(img_crop_width, 0.); + pts_std[2] = cv::Point2f(img_crop_width, img_crop_height); + pts_std[3] = cv::Point2f(0.f, img_crop_height); + + cv::Point2f pointsf[4]; + pointsf[0] = cv::Point2f(points[0][0], points[0][1]); + pointsf[1] = cv::Point2f(points[1][0], points[1][1]); + pointsf[2] = cv::Point2f(points[2][0], points[2][1]); + pointsf[3] = cv::Point2f(points[3][0], points[3][1]); + + cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); + + cv::Mat dst_img; + cv::warpPerspective(img_crop, dst_img, M, + cv::Size(img_crop_width, img_crop_height), + cv::BORDER_REPLICATE); + + if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) { + cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth()); + cv::transpose(dst_img, srcCopy); + cv::flip(srcCopy, srcCopy, 0); + return srcCopy; + } else { + return dst_img; + } +} + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py index bd2b964972..f80ddd9fb4 100644 --- a/deploy/slim/prune/sensitivity_anal.py +++ b/deploy/slim/prune/sensitivity_anal.py @@ -75,7 +75,7 @@ def main(config, device, logger, vdl_writer): model = build_model(config['Architecture']) flops = paddle.flops(model, [1, 3, 640, 640]) - logger.info(f"FLOPs before pruning: {flops}") + logger.info("FLOPs before pruning: {}".format(flops)) from paddleslim.dygraph import FPGMFilterPruner model.train() @@ -106,8 +106,8 @@ def main(config, device, logger, vdl_writer): def eval_fn(): metric = program.eval(model, valid_dataloader, post_process_class, - eval_class) - logger.info(f"metric['hmean']: {metric['hmean']}") + eval_class, False) + logger.info("metric['hmean']: {}".format(metric['hmean'])) return metric['hmean'] params_sensitive = pruner.sensitive( @@ -123,16 +123,17 @@ def eval_fn(): # calculate pruned params's ratio params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02) for key in params_sensitive.keys(): - logger.info(f"{key}, {params_sensitive[key]}") + logger.info("{}, {}".format(key, params_sensitive[key])) + + #params_sensitive = {} + #for param in model.parameters(): + # if 'transpose' not in param.name and 'linear' not in param.name: + # params_sensitive[param.name] = 0.1 plan = pruner.prune_vars(params_sensitive, [0]) - for param in model.parameters(): - if ("weights" in param.name and "conv" in param.name) or ( - "w_0" in param.name and "conv2d" in param.name): - logger.info(f"{param.name}: {param.shape}") flops = paddle.flops(model, [1, 3, 640, 640]) - logger.info(f"FLOPs after pruning: {flops}") + logger.info("FLOPs after pruning: {}".format(flops)) # start train diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index b6a365b35d..d465539166 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) - [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -59,6 +60,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | -|SAR|Resnet31| 87.1% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | +|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | +|SAR|Resnet31| 87.2% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index bcad6b8654..4773874966 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -187,11 +187,11 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs #### 2.1 数据增强 -PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入扰动,请在配置文件中设置 `distort: true`。 +PaddleOCR提供了多种数据增强方式,默认配置文件中已经添加了数据增广。 -默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gasuss noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse)。 +默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gasuss noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse)、TIA数据增广。 -训练过程中每种扰动方式以50%的概率被选择,具体代码实现请参考:[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) +训练过程中每种扰动方式以40%的概率被选择,具体代码实现请参考:[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py) *由于OpenCV的兼容性问题,扰动操作暂时只支持Linux* @@ -217,6 +217,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | +| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | | rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index f201589abb..0572815323 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) - [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -61,6 +62,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| -|SAR|Resnet31| 87.1% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | +|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | +|SAR|Resnet31| 87.2% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index 65030b4a16..6d2e0947e4 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -179,11 +179,11 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs #### 2.1 Data Augmentation -PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, please set `distort: true` in the configuration file. +PaddleOCR provides a variety of data augmentation methods. All the augmentation methods are enabled by default. -The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse. +The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse, TIA augmentation. -Each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to: [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) +Each disturbance method is selected with a 40% probability during the training process. For specific code implementation, please refer to: [rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py) #### 2.2 Training @@ -209,6 +209,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | +| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | | rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder | diff --git a/doc/joinus.PNG b/doc/joinus.PNG index 1228ce0a4d..3be14c603e 100644 Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ diff --git a/doc/table/1.png b/doc/table/1.png index 47df618ab1..faff6e3178 100644 Binary files a/doc/table/1.png and b/doc/table/1.png differ diff --git a/doc/table/table.jpg b/doc/table/table.jpg index 3daa619e52..95fdf84d92 100644 Binary files a/doc/table/table.jpg and b/doc/table/table.jpg differ diff --git a/paddleocr.py b/paddleocr.py index c52737f55b..45c1a40dbb 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -127,7 +127,7 @@ } SUPPORT_DET_MODEL = ['DB'] -VERSION = '2.2' +VERSION = '2.2.0.1' SUPPORT_REC_MODEL = ['CRNN'] BASE_DIR = os.path.expanduser("~/.paddleocr/") diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index e860c5a698..0bb3d50648 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -49,14 +49,12 @@ def term_mp(sig_num, frame): os.killpg(pgid, signal.SIGKILL) -signal.signal(signal.SIGINT, term_mp) -signal.signal(signal.SIGTERM, term_mp) - - def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'] + support_dict = [ + 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet' + ] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) @@ -96,4 +94,8 @@ def build_dataloader(config, mode, device, logger, seed=None): return_list=True, use_shared_memory=use_shared_memory) + # support exit using ctrl+c + signal.signal(signal.SIGINT, term_mp) + signal.signal(signal.SIGTERM, term_mp) + return data_loader diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 6f0492e16a..8bfc175083 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SARRecResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .operators import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 56da029b32..643ec70503 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -161,6 +161,34 @@ def encode(self, text): return text_list +class NRTRLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='EN_symbol', + use_space_char=False, + **kwargs): + + super(NRTRLabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + data['length'] = np.array(len(text)) + text.insert(0, 2) + text.append(3) + text = text + [0] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + def add_special_char(self, dict_character): + dict_character = ['blank','','',''] + dict_character + return dict_character + class CTCLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 2535b4420c..aa3acd1d12 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -57,6 +57,38 @@ def __call__(self, data): return data +class NRTRDecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + + img = cv2.imdecode(img, 1) + + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) + if self.channel_first: + img = img.transpose((2, 0, 1)) + data['image'] = img + return data + class NormalizeImage(object): """ normalize image such as substract mean, divide std """ diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index d968f43744..51f5855ac3 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -16,7 +16,7 @@ import cv2 import numpy as np import random - +from PIL import Image from .text_image_aug import tia_perspective, tia_stretch, tia_distort @@ -43,6 +43,25 @@ def __call__(self, data): return data +class NRTRRecResizeImg(object): + def __init__(self, image_shape, resize_type, **kwargs): + self.image_shape = image_shape + self.resize_type = resize_type + + def __call__(self, data): + img = data['image'] + if self.resize_type == 'PIL': + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize(self.image_shape, Image.ANTIALIAS) + img = np.array(img) + if self.resize_type == 'OpenCV': + img = cv2.resize(img, self.image_shape) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + data['image'] = norm_img.astype(np.float32) / 128. - 1. + return data + + class RecResizeImg(object): def __init__(self, image_shape, diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index d731185b3e..0484542f0b 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -25,8 +25,8 @@ from .rec_ctc_loss import CTCLoss from .rec_att_loss import AttentionLoss from .rec_srn_loss import SRNLoss +from .rec_nrtr_loss import NRTRLoss from .rec_sar_loss import SARLoss - # cls loss from .cls_loss import ClsLoss @@ -45,8 +45,9 @@ def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'SARLoss' + 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss' ] + config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format( diff --git a/ppocr/losses/cls_loss.py b/ppocr/losses/cls_loss.py index ecca5d2e17..abc5e5b72c 100755 --- a/ppocr/losses/cls_loss.py +++ b/ppocr/losses/cls_loss.py @@ -25,6 +25,6 @@ def __init__(self, **kwargs): self.loss_func = nn.CrossEntropyLoss(reduction='mean') def forward(self, predicts, batch): - label = batch[1] + label = batch[1].astype("int64") loss = self.loss_func(input=predicts, label=label) return {'loss': loss} diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py new file mode 100644 index 0000000000..41714dd2a3 --- /dev/null +++ b/ppocr/losses/rec_nrtr_loss.py @@ -0,0 +1,30 @@ +import paddle +from paddle import nn +import paddle.nn.functional as F + + +class NRTRLoss(nn.Layer): + def __init__(self, smoothing=True, **kwargs): + super(NRTRLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0) + self.smoothing = smoothing + + def forward(self, pred, batch): + pred = pred.reshape([-1, pred.shape[2]]) + max_len = batch[2].max() + tgt = batch[1][:, 1:2 + max_len] + tgt = tgt.reshape([-1]) + if self.smoothing: + eps = 0.1 + n_class = pred.shape[1] + one_hot = F.one_hot(tgt, pred.shape[1]) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, axis=1) + non_pad_mask = paddle.not_equal( + tgt, paddle.zeros( + tgt.shape, dtype='int64')) + loss = -(one_hot * log_prb).sum(axis=1) + loss = loss.masked_select(non_pad_mask).mean() + else: + loss = self.loss_func(pred, tgt) + return {'loss': loss} diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 66c084d771..3e82fe756b 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -57,3 +57,4 @@ def reset(self): self.correct_num = 0 self.all_num = 0 self.norm_edit_dis = 0 + diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index dbd18070b3..c498d9862a 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from paddle import nn from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index ce1d6bcf28..50438893c7 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -26,9 +26,10 @@ def build_backbone(config, model_type): from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN from .rec_mv1_enhance import MobileNetV1Enhance + from .rec_nrtr_mtb import MTB from .rec_resnet_31 import ResNet31 support_dict = [ - "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN", "ResNet31" + 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_nrtr_mtb.py b/ppocr/modeling/backbones/rec_nrtr_mtb.py new file mode 100644 index 0000000000..04b5c9bb5f --- /dev/null +++ b/ppocr/modeling/backbones/rec_nrtr_mtb.py @@ -0,0 +1,46 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import nn + + +class MTB(nn.Layer): + def __init__(self, cnn_num, in_channels): + super(MTB, self).__init__() + self.block = nn.Sequential() + self.out_channels = in_channels + self.cnn_num = cnn_num + if self.cnn_num == 2: + for i in range(self.cnn_num): + self.block.add_sublayer( + 'conv_{}'.format(i), + nn.Conv2D( + in_channels=in_channels + if i == 0 else 32 * (2**(i - 1)), + out_channels=32 * (2**i), + kernel_size=3, + stride=2, + padding=1)) + self.block.add_sublayer('relu_{}'.format(i), nn.ReLU()) + self.block.add_sublayer('bn_{}'.format(i), + nn.BatchNorm2D(32 * (2**i))) + + def forward(self, images): + x = self.block(images) + if self.cnn_num == 2: + # (b, w, h, c) + x = x.transpose([0, 3, 2, 1]) + x_shape = x.shape + x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]]) + return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 8414f2adfd..80311f92a6 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -26,13 +26,15 @@ def build_head(config): from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead + from .rec_nrtr_head import Transformer from .rec_sar_head import SARHead # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead', 'TableAttentionHead', 'SARHead'] + 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead' + ] #table head from .table_att_head import TableAttentionHead diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py new file mode 100755 index 0000000000..651d4f577d --- /dev/null +++ b/ppocr/modeling/heads/multiheadAttention.py @@ -0,0 +1,178 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import Linear +from paddle.nn.initializer import XavierUniform as xavier_uniform_ +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + + +class MultiheadAttention(nn.Layer): + """Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + """ + + def __init__(self, + embed_dim, + num_heads, + dropout=0., + bias=True, + add_bias_kv=False, + add_zero_attn=False): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) + self._reset_parameters() + self.conv1 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + + def _reset_parameters(self): + xavier_uniform_(self.out_proj.weight) + + def forward(self, + query, + key, + value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + assert key.shape == value.shape + + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + + q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + + src_len = k.shape[1] + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1])) + assert list(attn_output_weights. + shape) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.reshape( + [bsz, self.num_heads, tgt_len, src_len]) + key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') + y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') + y = paddle.where(key == 0., key, y) + attn_output_weights += y + attn_output_weights = attn_output_weights.reshape( + [bsz * self.num_heads, tgt_len, src_len]) + + attn_output_weights = F.softmax( + attn_output_weights.astype('float32'), + axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 + else attn_output_weights.dtype) + attn_output_weights = F.dropout( + attn_output_weights, p=self.dropout, training=self.training) + + attn_output = paddle.bmm(attn_output_weights, v) + assert list(attn_output. + shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose([1, 0, 2]).reshape( + [tgt_len, bsz, embed_dim]) + attn_output = self.out_proj(attn_output) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.reshape( + [bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.sum( + axis=1) / self.num_heads + else: + attn_output_weights = None + return attn_output, attn_output_weights + + def _in_proj_q(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv1(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_k(self, key): + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv2(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_v(self, value): + value = value.transpose([1, 2, 0]) #(1, 2, 0) + value = paddle.unsqueeze(value, axis=2) + res = self.conv3(value) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res diff --git a/ppocr/modeling/heads/rec_nrtr_head.py b/ppocr/modeling/heads/rec_nrtr_head.py new file mode 100644 index 0000000000..05dba677b4 --- /dev/null +++ b/ppocr/modeling/heads/rec_nrtr_head.py @@ -0,0 +1,844 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import copy +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import LayerList +from paddle.nn.initializer import XavierNormal as xavier_uniform_ +from paddle.nn import Dropout, Linear, LayerNorm, Conv2D +import numpy as np +from ppocr.modeling.heads.multiheadAttention import MultiheadAttention +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + + +class Transformer(nn.Layer): + """A transformer model. User is able to modify the attributes as needed. The architechture + is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, + Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and + Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information + Processing Systems, pages 6000-6010. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). + + """ + + def __init__(self, + d_model=512, + nhead=8, + num_encoder_layers=6, + beam_size=0, + num_decoder_layers=6, + dim_feedforward=1024, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1, + custom_encoder=None, + custom_decoder=None, + in_channels=0, + out_channels=0, + dst_vocab_size=99, + scale_embedding=True): + super(Transformer, self).__init__() + self.embedding = Embeddings( + d_model=d_model, + vocab=dst_vocab_size, + padding_idx=0, + scale_embedding=scale_embedding) + self.positional_encoding = PositionalEncoding( + dropout=residual_dropout_rate, + dim=d_model, ) + if custom_encoder is not None: + self.encoder = custom_encoder + else: + if num_encoder_layers > 0: + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, attention_dropout_rate, + residual_dropout_rate) + self.encoder = TransformerEncoder(encoder_layer, + num_encoder_layers) + else: + self.encoder = None + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, attention_dropout_rate, + residual_dropout_rate) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers) + + self._reset_parameters() + self.beam_size = beam_size + self.d_model = d_model + self.nhead = nhead + self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False) + w0 = np.random.normal(0.0, d_model**-0.5, + (d_model, dst_vocab_size)).astype(np.float32) + self.tgt_word_prj.weight.set_value(w0) + self.apply(self._init_weights) + + def _init_weights(self, m): + + if isinstance(m, nn.Conv2D): + xavier_normal_(m.weight) + if m.bias is not None: + zeros_(m.bias) + + def forward_train(self, src, tgt): + tgt = tgt[:, :-1] + + tgt_key_padding_mask = self.generate_padding_mask(tgt) + tgt = self.embedding(tgt).transpose([1, 0, 2]) + tgt = self.positional_encoding(tgt) + tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0]) + + if self.encoder is not None: + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + output = output.transpose([1, 0, 2]) + logit = self.tgt_word_prj(output) + return logit + + def forward(self, src, targets=None): + """Take in and process masked source/target sequences. + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + Shape: + - src: :math:`(S, N, E)`. + - tgt: :math:`(T, N, E)`. + Examples: + >>> output = transformer_model(src, tgt) + """ + + if self.training: + max_len = targets[1].max() + tgt = targets[0][:, :2 + max_len] + return self.forward_train(src, tgt) + else: + if self.beam_size > 0: + return self.forward_beam(src) + else: + return self.forward_test(src) + + def forward_test(self, src): + bs = src.shape[0] + if self.encoder is not None: + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) + for len_dec_seq in range(1, 25): + src_enc = memory.clone() + tgt_key_padding_mask = self.generate_padding_mask(dec_seq) + dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2]) + dec_seq_embed = self.positional_encoding(dec_seq_embed) + tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[ + 0]) + output = self.decoder( + dec_seq_embed, + src_enc, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + dec_output = output.transpose([1, 0, 2]) + + dec_output = dec_output[:, + -1, :] # Pick the last step: (bh * bm) * d_h + word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) + word_prob = word_prob.reshape([1, bs, -1]) + preds_idx = word_prob.argmax(axis=2) + + if paddle.equal_all( + preds_idx[-1], + paddle.full( + preds_idx[-1].shape, 3, dtype='int64')): + break + + preds_prob = word_prob.max(axis=2) + dec_seq = paddle.concat( + [dec_seq, preds_idx.reshape([-1, 1])], axis=1) + + return dec_seq + + def forward_beam(self, images): + ''' Translation work in one batch ''' + + def get_inst_idx_to_tensor_position_map(inst_idx_list): + ''' Indicate the position of an instance in a tensor. ''' + return { + inst_idx: tensor_position + for tensor_position, inst_idx in enumerate(inst_idx_list) + } + + def collect_active_part(beamed_tensor, curr_active_inst_idx, + n_prev_active_inst, n_bm): + ''' Collect tensor parts associated to active instances. ''' + + _, *d_hs = beamed_tensor.shape + n_curr_active_inst = len(curr_active_inst_idx) + new_shape = (n_curr_active_inst * n_bm, *d_hs) + + beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1]) + beamed_tensor = beamed_tensor.index_select( + paddle.to_tensor(curr_active_inst_idx), axis=0) + beamed_tensor = beamed_tensor.reshape([*new_shape]) + + return beamed_tensor + + def collate_active_info(src_enc, inst_idx_to_position_map, + active_inst_idx_list): + # Sentences which are still active are collected, + # so the decoder will not run on completed sentences. + + n_prev_active_inst = len(inst_idx_to_position_map) + active_inst_idx = [ + inst_idx_to_position_map[k] for k in active_inst_idx_list + ] + active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64') + active_src_enc = collect_active_part( + src_enc.transpose([1, 0, 2]), active_inst_idx, + n_prev_active_inst, n_bm).transpose([1, 0, 2]) + active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) + return active_src_enc, active_inst_idx_to_position_map + + def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output, + inst_idx_to_position_map, n_bm, + memory_key_padding_mask): + ''' Decode and update beam status, and then return active beam idx ''' + + def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): + dec_partial_seq = [ + b.get_current_state() for b in inst_dec_beams if not b.done + ] + dec_partial_seq = paddle.stack(dec_partial_seq) + + dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq]) + return dec_partial_seq + + def prepare_beam_memory_key_padding_mask( + inst_dec_beams, memory_key_padding_mask, n_bm): + keep = [] + for idx in (memory_key_padding_mask): + if not inst_dec_beams[idx].done: + keep.append(idx) + memory_key_padding_mask = memory_key_padding_mask[ + paddle.to_tensor(keep)] + len_s = memory_key_padding_mask.shape[-1] + n_inst = memory_key_padding_mask.shape[0] + memory_key_padding_mask = paddle.concat( + [memory_key_padding_mask for i in range(n_bm)], axis=1) + memory_key_padding_mask = memory_key_padding_mask.reshape( + [n_inst * n_bm, len_s]) #repeat(1, n_bm) + return memory_key_padding_mask + + def predict_word(dec_seq, enc_output, n_active_inst, n_bm, + memory_key_padding_mask): + tgt_key_padding_mask = self.generate_padding_mask(dec_seq) + dec_seq = self.embedding(dec_seq).transpose([1, 0, 2]) + dec_seq = self.positional_encoding(dec_seq) + tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[ + 0]) + dec_output = self.decoder( + dec_seq, + enc_output, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ).transpose([1, 0, 2]) + dec_output = dec_output[:, + -1, :] # Pick the last step: (bh * bm) * d_h + word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) + word_prob = word_prob.reshape([n_active_inst, n_bm, -1]) + return word_prob + + def collect_active_inst_idx_list(inst_beams, word_prob, + inst_idx_to_position_map): + active_inst_idx_list = [] + for inst_idx, inst_position in inst_idx_to_position_map.items(): + is_inst_complete = inst_beams[inst_idx].advance(word_prob[ + inst_position]) + if not is_inst_complete: + active_inst_idx_list += [inst_idx] + + return active_inst_idx_list + + n_active_inst = len(inst_idx_to_position_map) + dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) + memory_key_padding_mask = None + word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, + memory_key_padding_mask) + # Update the beam with predicted word prob information and collect incomplete instances + active_inst_idx_list = collect_active_inst_idx_list( + inst_dec_beams, word_prob, inst_idx_to_position_map) + return active_inst_idx_list + + def collect_hypothesis_and_scores(inst_dec_beams, n_best): + all_hyp, all_scores = [], [] + for inst_idx in range(len(inst_dec_beams)): + scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() + all_scores += [scores[:n_best]] + hyps = [ + inst_dec_beams[inst_idx].get_hypothesis(i) + for i in tail_idxs[:n_best] + ] + all_hyp += [hyps] + return all_hyp, all_scores + + with paddle.no_grad(): + #-- Encode + + if self.encoder is not None: + src = self.positional_encoding(images.transpose([1, 0, 2])) + src_enc = self.encoder(src).transpose([1, 0, 2]) + else: + src_enc = images.squeeze(2).transpose([0, 2, 1]) + + #-- Repeat data for beam search + n_bm = self.beam_size + n_inst, len_s, d_h = src_enc.shape + src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) + src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( + [1, 0, 2]) + #-- Prepare beams + inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)] + + #-- Bookkeeping for active or not + active_inst_idx_list = list(range(n_inst)) + inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) + #-- Decode + for len_dec_seq in range(1, 25): + src_enc_copy = src_enc.clone() + active_inst_idx_list = beam_decode_step( + inst_dec_beams, len_dec_seq, src_enc_copy, + inst_idx_to_position_map, n_bm, None) + if not active_inst_idx_list: + break # all instances have finished their path to + src_enc, inst_idx_to_position_map = collate_active_info( + src_enc_copy, inst_idx_to_position_map, + active_inst_idx_list) + batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, + 1) + result_hyp = [] + for bs_hyp in batch_hyp: + bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0])) + result_hyp.append(bs_hyp_pad) + return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64) + + def generate_square_subsequent_mask(self, sz): + """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + mask = paddle.zeros([sz, sz], dtype='float32') + mask_inf = paddle.triu( + paddle.full( + shape=[sz, sz], dtype='float32', fill_value='-inf'), + diagonal=1) + mask = mask + mask_inf + return mask + + def generate_padding_mask(self, x): + padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype)) + return padding_mask + + def _reset_parameters(self): + """Initiate parameters in the transformer model.""" + + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(nn.Layer): + """TransformerEncoder is a stack of N encoder layers + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + """ + + def __init__(self, encoder_layer, num_layers): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, src): + """Pass the input through the endocder layers in turn. + Args: + src: the sequnce to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + """ + output = src + + for i in range(self.num_layers): + output = self.layers[i](output, + src_mask=None, + src_key_padding_mask=None) + + return output + + +class TransformerDecoder(nn.Layer): + """TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + """ + + def __init__(self, decoder_layer, num_layers): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + """Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequnce from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + """ + output = tgt + for i in range(self.num_layers): + output = self.layers[i]( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + + return output + + +class TransformerEncoderLayer(nn.Layer): + """TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + + self.conv1 = Conv2D( + in_channels=d_model, + out_channels=dim_feedforward, + kernel_size=(1, 1)) + self.conv2 = Conv2D( + in_channels=dim_feedforward, + out_channels=d_model, + kernel_size=(1, 1)) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(residual_dropout_rate) + self.dropout2 = Dropout(residual_dropout_rate) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """Pass the input through the endocder layer. + Args: + src: the sequnce to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + """ + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + + src = src.transpose([1, 2, 0]) + src = paddle.unsqueeze(src, 2) + src2 = self.conv2(F.relu(self.conv1(src))) + src2 = paddle.squeeze(src2, 2) + src2 = src2.transpose([2, 0, 1]) + src = paddle.squeeze(src, 2) + src = src.transpose([2, 0, 1]) + + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Layer): + """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1): + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + self.multihead_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + + self.conv1 = Conv2D( + in_channels=d_model, + out_channels=dim_feedforward, + kernel_size=(1, 1)) + self.conv2 = Conv2D( + in_channels=dim_feedforward, + out_channels=d_model, + kernel_size=(1, 1)) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout1 = Dropout(residual_dropout_rate) + self.dropout2 = Dropout(residual_dropout_rate) + self.dropout3 = Dropout(residual_dropout_rate) + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequnce from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + """ + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # default + tgt = tgt.transpose([1, 2, 0]) + tgt = paddle.unsqueeze(tgt, 2) + tgt2 = self.conv2(F.relu(self.conv1(tgt))) + tgt2 = paddle.squeeze(tgt2, 2) + tgt2 = tgt2.transpose([2, 0, 1]) + tgt = paddle.squeeze(tgt, 2) + tgt = tgt.transpose([2, 0, 1]) + + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +def _get_clones(module, N): + return LayerList([copy.deepcopy(module) for i in range(N)]) + + +class PositionalEncoding(nn.Layer): + """Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0) + pe = pe.transpose([1, 0, 2]) + self.register_buffer('pe', pe) + + def forward(self, x): + """Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +class PositionalEncoding_2d(nn.Layer): + """Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding_2d, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0).transpose([1, 0, 2]) + self.register_buffer('pe', pe) + + self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear1 = nn.Linear(dim, dim) + self.linear1.weight.data.fill_(1.) + self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear2 = nn.Linear(dim, dim) + self.linear2.weight.data.fill_(1.) + + def forward(self, x): + """Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + w_pe = self.pe[:x.shape[-1], :] + w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0) + w_pe = w_pe * w1 + w_pe = w_pe.transpose([1, 2, 0]) + w_pe = w_pe.unsqueeze(2) + + h_pe = self.pe[:x.shape[-2], :] + w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0) + h_pe = h_pe * w2 + h_pe = h_pe.transpose([1, 2, 0]) + h_pe = h_pe.unsqueeze(3) + + x = x + w_pe + h_pe + x = x.reshape( + [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose( + [2, 0, 1]) + + return self.dropout(x) + + +class Embeddings(nn.Layer): + def __init__(self, d_model, vocab, padding_idx, scale_embedding): + super(Embeddings, self).__init__() + self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx) + w0 = np.random.normal(0.0, d_model**-0.5, + (vocab, d_model)).astype(np.float32) + self.embedding.weight.set_value(w0) + self.d_model = d_model + self.scale_embedding = scale_embedding + + def forward(self, x): + if self.scale_embedding: + x = self.embedding(x) + return x * math.sqrt(self.d_model) + return self.embedding(x) + + +class Beam(): + ''' Beam search ''' + + def __init__(self, size, device=False): + + self.size = size + self._done = False + # The score for each translation on the beam. + self.scores = paddle.zeros((size, ), dtype=paddle.float32) + self.all_scores = [] + # The backpointers at each time-step. + self.prev_ks = [] + # The outputs at each time-step. + self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)] + self.next_ys[0][0] = 2 + + def get_current_state(self): + "Get the outputs for the current timestep." + return self.get_tentative_hypothesis() + + def get_current_origin(self): + "Get the backpointers for the current timestep." + return self.prev_ks[-1] + + @property + def done(self): + return self._done + + def advance(self, word_prob): + "Update beam status and check if finished or not." + num_words = word_prob.shape[1] + + # Sum the previous scores. + if len(self.prev_ks) > 0: + beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) + else: + beam_lk = word_prob[0] + + flat_beam_lk = beam_lk.reshape([-1]) + best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, + True) # 1st sort + self.all_scores.append(self.scores) + self.scores = best_scores + # bestScoresId is flattened as a (beam x word) array, + # so we need to calculate which word and beam each score came from + prev_k = best_scores_id // num_words + self.prev_ks.append(prev_k) + self.next_ys.append(best_scores_id - prev_k * num_words) + # End condition is when top-of-beam is EOS. + if self.next_ys[-1][0] == 3: + self._done = True + self.all_scores.append(self.scores) + + return self._done + + def sort_scores(self): + "Sort the scores." + return self.scores, paddle.to_tensor( + [i for i in range(self.scores.shape[0])], dtype='int32') + + def get_the_best_score_and_idx(self): + "Get the score of the best in the beam." + scores, ids = self.sort_scores() + return scores[1], ids[1] + + def get_tentative_hypothesis(self): + "Get the decoded sequence for the current timestep." + if len(self.next_ys) == 1: + dec_seq = self.next_ys[0].unsqueeze(1) + else: + _, keys = self.sort_scores() + hyps = [self.get_hypothesis(k) for k in keys] + hyps = [[2] + h for h in hyps] + dec_seq = paddle.to_tensor(hyps, dtype='int64') + return dec_seq + + def get_hypothesis(self, k): + """ Walk back to construct the full hypothesis. """ + hyp = [] + for j in range(len(self.prev_ks) - 1, -1, -1): + hyp.append(self.next_ys[j + 1][k]) + k = self.prev_ks[j][k] + return list(map(lambda x: x.item(), hyp[::-1])) diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 86f9ede45c..77081abeb6 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -24,18 +24,17 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess -from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \ TableLabelDecode, SARLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess - def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', - 'DistillationDBPostProcess', 'SARLabelDecode' + 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 9e9ddd8f95..83d7b21540 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -157,6 +157,69 @@ def __call__(self, preds, label=None, *args, **kwargs): return output +class NRTRLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='EN_symbol', + use_space_char=True, + **kwargs): + super(NRTRLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + if preds.dtype == paddle.int64: + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + if preds[0][0]==2: + preds_idx = preds[:,1:] + else: + preds_idx = preds + + text = self.decode(preds_idx) + if label is None: + return text + label = self.decode(label[:,1:]) + else: + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:,1:]) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank','','',''] + dict_character + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. """ + result_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == 3: # end + break + try: + char_list.append(self.character[int(text_index[batch_idx][idx])]) + except: + continue + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text.lower(), np.mean(conf_list))) + return result_list + + + class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ @@ -194,8 +257,7 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ batch_idx][idx]: continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) + char_list.append(self.character[int(text_index[batch_idx][idx])]) if text_prob is not None: conf_list.append(text_prob[batch_idx][idx]) else: diff --git a/ppstructure/README.md b/ppstructure/README.md index 8e1642cc75..a00d12d8ed 100644 --- a/ppstructure/README.md +++ b/ppstructure/README.md @@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/ # CPU python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple -# For more,refer[Installation](https://www.paddlepaddle.org.cn/install/quick)。 ``` +For more,refer [Installation](https://www.paddlepaddle.org.cn/install/quick) . - **(2) Install Layout-Parser** ```bash -pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl +pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl ``` ### 2.2 Install PaddleOCR(including PP-OCR and PP-Structure) @@ -180,10 +180,10 @@ OCR and table recognition model |model name|description|model size|download| | --- | --- | --- | --- | -|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) | -|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) | -|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) | -|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) | -|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) | +|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) | +|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) | +|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) | +|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) | +|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | If you need to use other models, you can download the model in [model_list](../doc/doc_en/models_list_en.md) or use your own trained model to configure it to the three fields of `det_model_dir`, `rec_model_dir`, `table_model_dir` . diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md index c8acac5900..821a6c3e36 100644 --- a/ppstructure/README_ch.md +++ b/ppstructure/README_ch.md @@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/ # CPU安装 python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple -# 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 ``` +更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 - **(2) 安装 Layout-Parser** ```bash -pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl +pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl ``` ### 2.2 安装PaddleOCR(包含PP-OCR和PP-Structure) @@ -179,10 +179,10 @@ OCR和表格识别模型 |模型名称|模型简介|推理模型大小|下载地址| | --- | --- | --- | --- | -|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) | -|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) | -|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) | -|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) | -|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) | +|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) | +|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) | +|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) | +|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) | +|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | 如需要使用其他模型,可以在 [model_list](../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。 diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md index a8d10b79e5..67c4d8e26d 100644 --- a/ppstructure/table/README.md +++ b/ppstructure/table/README.md @@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar cd .. # run -python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table +python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table ``` Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`. diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md index 2ded403c37..e580debaeb 100644 --- a/ppstructure/table/README_ch.md +++ b/ppstructure/table/README_ch.md @@ -43,7 +43,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar cd .. # 执行预测 -python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table +python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table ``` 运行完成后,每张图片的excel表格会保存到output字段指定的目录下 diff --git a/requirements.txt b/requirements.txt index 351d409092..2c7baa8516 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,7 @@ tqdm numpy visualdl python-Levenshtein -opencv-contrib-python==4.4.0.46 \ No newline at end of file +opencv-contrib-python==4.4.0.46 +lxml +premailer +openpyxl \ No newline at end of file diff --git a/tests/ocr_det_params.txt b/tests/ocr_det_params.txt index 6aff66c6aa..73b12cec9c 100644 --- a/tests/ocr_det_params.txt +++ b/tests/ocr_det_params.txt @@ -4,7 +4,7 @@ python:python3.7 gpu_list:0|0,1 Global.use_gpu:True|True Global.auto_cast:null -Global.epoch_num:lite_train_infer=2|whole_train_infer=300 +Global.epoch_num:lite_train_infer=1|whole_train_infer=300 Global.save_model_dir:./output/ Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4 Global.pretrained_model:null @@ -15,7 +15,7 @@ null:null trainer:norm_train|pact_train norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained pact_train:deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o -fpgm_train:null +fpgm_train:deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy distill_train:null null:null null:null @@ -29,7 +29,7 @@ Global.save_inference_dir:./output/ Global.pretrained_model: norm_export:tools/export_model.py -c configs/det/det_mv3_db.yml -o quant_export:deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o -fpgm_export:deploy/slim/prune/export_prune_model.py +fpgm_export:deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o distill_export:null export1:null export2:null @@ -49,4 +49,19 @@ inference:tools/infer/predict_det.py --save_log_path:null --benchmark:True null:null +===========================cpp_infer_params=========================== +use_opencv:True +infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/ +infer_quant:False +inference:./deploy/cpp_infer/build/ppocr det +--use_gpu:True|False +--enable_mkldnn:True|False +--cpu_threads:1|6 +--rec_batch_num:1 +--use_tensorrt:False|True +--precision:fp32|fp16 +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ +--save_log_path:null +--benchmark:True diff --git a/tests/ocr_det_server_params.txt b/tests/ocr_det_server_params.txt new file mode 100644 index 0000000000..0835cfffa6 --- /dev/null +++ b/tests/ocr_det_server_params.txt @@ -0,0 +1,52 @@ +===========================train_params=========================== +model_name:ocr_server_det +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_infer=2|whole_train_infer=300 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ +null:null +## +trainer:norm_train|pact_train +norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o Global.pretrained_model="" +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c configs/det/det_mv3_db.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.pretrained_model: +norm_export:tools/export_model.py -c configs/det/det_r50_vd_db.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +infer_model:./inference/ch_ppocr_server_v2.0_det_infer/ +infer_export:null +infer_quant:False +inference:tools/infer/predict_det.py +--use_gpu:True|False +--enable_mkldnn:True|False +--cpu_threads:1|6 +--rec_batch_num:1 +--use_tensorrt:False|True +--precision:fp32|fp16|int8 +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ +--save_log_path:null +--benchmark:True +null:null + diff --git a/tests/prepare.sh b/tests/prepare.sh index d27a051cb0..5da74d949f 100644 --- a/tests/prepare.sh +++ b/tests/prepare.sh @@ -1,6 +1,6 @@ #!/bin/bash FILENAME=$1 -# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer'] +# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer'] MODE=$2 dataline=$(cat ${FILENAME}) @@ -34,11 +34,14 @@ MODE=$2 if [ ${MODE} = "lite_train_infer" ];then # pretrain lite train data wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar + cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../ rm -rf ./train_data/icdar2015 rm -rf ./train_data/ic15_data wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar # todo change to bcebos - + wget -nc -P ./deploy/slim/prune https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/sen.pickle + cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar ln -s ./icdar2015_lite ./icdar2015 cd ../ @@ -58,13 +61,17 @@ elif [ ${MODE} = "whole_infer" ];then cd ./train_data/ && tar xf icdar2015_infer.tar && tar xf ic15_data.tar ln -s ./icdar2015_infer ./icdar2015 cd ../ -else +elif [ ${MODE} = "infer" ] || [ ${MODE} = "cpp_infer" ];then if [ ${model_name} = "ocr_det" ]; then eval_model_name="ch_ppocr_mobile_v2.0_det_infer" rm -rf ./train_data/icdar2015 - wget -nc -P ./train_data https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../ + elif [ ${model_name} = "ocr_server_det" ]; then + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar + cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../ else rm -rf ./train_data/ic15_data eval_model_name="ch_ppocr_mobile_v2.0_rec_infer" @@ -74,3 +81,72 @@ else fi fi +if [ ${MODE} = "cpp_infer" ];then + cd deploy/cpp_infer + use_opencv=$(func_parser_value "${lines[52]}") + if [ ${use_opencv} = "True" ]; then + echo "################### build opencv ###################" + rm -rf 3.4.7.tar.gz opencv-3.4.7/ + wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz + tar -xf 3.4.7.tar.gz + + cd opencv-3.4.7/ + install_path=$(pwd)/opencv-3.4.7/opencv3 + + rm -rf build + mkdir build + cd build + + cmake .. \ + -DCMAKE_INSTALL_PREFIX=${install_path} \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=OFF \ + -DWITH_IPP=OFF \ + -DBUILD_IPP_IW=OFF \ + -DWITH_LAPACK=OFF \ + -DWITH_EIGEN=OFF \ + -DCMAKE_INSTALL_LIBDIR=lib64 \ + -DWITH_ZLIB=ON \ + -DBUILD_ZLIB=ON \ + -DWITH_JPEG=ON \ + -DBUILD_JPEG=ON \ + -DWITH_PNG=ON \ + -DBUILD_PNG=ON \ + -DWITH_TIFF=ON \ + -DBUILD_TIFF=ON + + make -j + make install + cd ../ + echo "################### build opencv finished ###################" + fi + + + echo "################### build PaddleOCR demo ####################" + if [ ${use_opencv} = "True" ]; then + OPENCV_DIR=$(pwd)/opencv-3.4.7/opencv3/ + else + OPENCV_DIR='' + fi + LIB_DIR=$(pwd)/Paddle/build/paddle_inference_install_dir/ + CUDA_LIB_DIR=$(dirname `find /usr -name libcudart.so`) + CUDNN_LIB_DIR=$(dirname `find /usr -name libcudnn.so`) + + BUILD_DIR=build + rm -rf ${BUILD_DIR} + mkdir ${BUILD_DIR} + cd ${BUILD_DIR} + cmake .. \ + -DPADDLE_LIB=${LIB_DIR} \ + -DWITH_MKL=ON \ + -DWITH_GPU=OFF \ + -DWITH_STATIC_LIB=OFF \ + -DWITH_TENSORRT=OFF \ + -DOPENCV_DIR=${OPENCV_DIR} \ + -DCUDNN_LIB=${CUDNN_LIB_DIR} \ + -DCUDA_LIB=${CUDA_LIB_DIR} \ + -DTENSORRT_DIR=${TENSORRT_DIR} \ + + make -j + echo "################### build PaddleOCR demo finished ###################" +fi \ No newline at end of file diff --git a/tests/readme.md b/tests/readme.md new file mode 100644 index 0000000000..1c5e0faee9 --- /dev/null +++ b/tests/readme.md @@ -0,0 +1,58 @@ + +# 介绍 + +test.sh和params.txt文件配合使用,完成OCR轻量检测和识别模型从训练到预测的流程测试。 + +# 安装依赖 +- 安装PaddlePaddle >= 2.0 +- 安装PaddleOCR依赖 + ``` + pip3 install -r ../requirements.txt + ``` +- 安装autolog + ``` + git clone https://github.com/LDOUBLEV/AutoLog + cd AutoLog + pip3 install -r requirements.txt + python3 setup.py bdist_wheel + pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl + cd ../ + ``` + +# 目录介绍 + +```bash +tests/ +├── ocr_det_params.txt # 测试OCR检测模型的参数配置文件 +├── ocr_rec_params.txt # 测试OCR识别模型的参数配置文件 +└── prepare.sh # 完成test.sh运行所需要的数据和模型下载 +└── test.sh # 根据 +``` + +# 使用方法 +test.sh包含四种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是: +- 模式1 lite_train_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度; +``` +bash test/prepare.sh ./tests/ocr_det_params.txt 'lite_train_infer' +bash tests/test.sh ./tests/ocr_det_params.txt 'lite_train_infer' +``` +- 模式2 whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理; +``` +bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_infer' +bash tests/test.sh ./tests/ocr_det_params.txt 'whole_infer' +``` + +- 模式3 infer 不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度; +``` +bash tests/prepare.sh ./tests/ocr_det_params.txt 'infer' +用法1: +bash tests/test.sh ./tests/ocr_det_params.txt 'infer' +用法2: 指定GPU卡预测,第三个传入参数为GPU卡号 +bash tests/test.sh ./tests/ocr_det_params.txt 'infer' '1' +``` + +模式4: whole_train_infer , CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度 +``` +bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_train_infer' +bash tests/test.sh ./tests/ocr_det_params.txt 'whole_train_infer' +``` diff --git a/tests/test.sh b/tests/test.sh index 9888e0faab..484d557353 100644 --- a/tests/test.sh +++ b/tests/test.sh @@ -1,6 +1,6 @@ #!/bin/bash FILENAME=$1 -# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer'] +# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer'] MODE=$2 dataline=$(cat ${FILENAME}) @@ -145,6 +145,33 @@ benchmark_value=$(func_parser_value "${lines[49]}") infer_key1=$(func_parser_key "${lines[50]}") infer_value1=$(func_parser_value "${lines[50]}") +if [ ${MODE} = "cpp_infer" ]; then + # parser cpp inference model + cpp_infer_model_dir_list=$(func_parser_value "${lines[53]}") + cpp_infer_is_quant=$(func_parser_value "${lines[54]}") + # parser cpp inference + inference_cmd=$(func_parser_value "${lines[55]}") + cpp_use_gpu_key=$(func_parser_key "${lines[56]}") + cpp_use_gpu_list=$(func_parser_value "${lines[56]}") + cpp_use_mkldnn_key=$(func_parser_key "${lines[57]}") + cpp_use_mkldnn_list=$(func_parser_value "${lines[57]}") + cpp_cpu_threads_key=$(func_parser_key "${lines[58]}") + cpp_cpu_threads_list=$(func_parser_value "${lines[58]}") + cpp_batch_size_key=$(func_parser_key "${lines[59]}") + cpp_batch_size_list=$(func_parser_value "${lines[59]}") + cpp_use_trt_key=$(func_parser_key "${lines[60]}") + cpp_use_trt_list=$(func_parser_value "${lines[60]}") + cpp_precision_key=$(func_parser_key "${lines[61]}") + cpp_precision_list=$(func_parser_value "${lines[61]}") + cpp_infer_model_key=$(func_parser_key "${lines[62]}") + cpp_image_dir_key=$(func_parser_key "${lines[63]}") + cpp_infer_img_dir=$(func_parser_value "${lines[63]}") + cpp_save_log_key=$(func_parser_key "${lines[64]}") + cpp_benchmark_key=$(func_parser_key "${lines[65]}") + cpp_benchmark_value=$(func_parser_value "${lines[65]}") +fi + + LOG_PATH="./tests/output" mkdir -p ${LOG_PATH} status_log="${LOG_PATH}/results.log" @@ -218,6 +245,71 @@ function func_inference(){ done } +function func_cpp_inference(){ + IFS='|' + _script=$1 + _model_dir=$2 + _log_path=$3 + _img_dir=$4 + _flag_quant=$5 + # inference + for use_gpu in ${cpp_use_gpu_list[*]}; do + if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then + for use_mkldnn in ${cpp_use_mkldnn_list[*]}; do + if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then + continue + fi + for threads in ${cpp_cpu_threads_list[*]}; do + for batch_size in ${cpp_batch_size_list[*]}; do + _save_log_path="${_log_path}/cpp_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log" + set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}") + set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}") + set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}") + set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}") + set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}") + command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${cpp_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} > ${_save_log_path} 2>&1 " + eval $command + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${command}" "${status_log}" + done + done + done + elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then + for use_trt in ${cpp_use_trt_list[*]}; do + for precision in ${cpp_precision_list[*]}; do + if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then + continue + fi + if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then + continue + fi + if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then + continue + fi + for batch_size in ${cpp_batch_size_list[*]}; do + _save_log_path="${_log_path}/cpp_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log" + set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}") + set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}") + set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}") + set_tensorrt=$(func_set_params "${cpp_use_trt_key}" "${use_trt}") + set_precision=$(func_set_params "${cpp_precision_key}" "${precision}") + set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}") + command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} > ${_save_log_path} 2>&1 " + eval $command + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${command}" "${status_log}" + + done + done + done + else + echo "Does not support hardware other than CPU and GPU Currently!" + fi + done +} + if [ ${MODE} = "infer" ]; then GPUID=$3 if [ ${#GPUID} -le 0 ];then @@ -252,6 +344,25 @@ if [ ${MODE} = "infer" ]; then Count=$(($Count + 1)) done +elif [ ${MODE} = "cpp_infer" ]; then + GPUID=$3 + if [ ${#GPUID} -le 0 ];then + env=" " + else + env="export CUDA_VISIBLE_DEVICES=${GPUID}" + fi + # set CUDA_VISIBLE_DEVICES + eval $env + export Count=0 + IFS="|" + infer_quant_flag=(${cpp_infer_is_quant}) + for infer_model in ${cpp_infer_model_dir_list[*]}; do + #run inference + is_quant=${infer_quant_flag[Count]} + func_cpp_inference "${inference_cmd}" "${infer_model}" "${LOG_PATH}" "${cpp_infer_img_dir}" ${is_quant} + Count=$(($Count + 1)) + done + else IFS="|" export Count=0 diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 3de00d83a8..5c75e0c480 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -101,6 +101,7 @@ def __init__(self, args): if args.benchmark: import auto_log pid = os.getpid() + gpu_id = utility.get_infer_gpuid() self.autolog = auto_log.AutoLogger( model_name="det", model_precision=args.precision, @@ -110,7 +111,7 @@ def __init__(self, args): inference_config=self.config, pids=pid, process_name=None, - gpu_ids=0, + gpu_ids=gpu_id if args.use_gpu else None, time_keys=[ 'preprocess_time', 'inference_time', 'postprocess_time' ], diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py index cd6c2005a7..5029d60593 100755 --- a/tools/infer/predict_e2e.py +++ b/tools/infer/predict_e2e.py @@ -74,7 +74,7 @@ def __init__(self, args): self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor( + self.predictor, self.input_tensor, self.output_tensors, _ = utility.create_predictor( args, 'e2e', logger) # paddle.jit.load(args.det_model_dir) # self.predictor.eval() diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index bb4a317064..7401a16ee6 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -68,6 +68,7 @@ def __init__(self, args): if args.benchmark: import auto_log pid = os.getpid() + gpu_id = utility.get_infer_gpuid() self.autolog = auto_log.AutoLogger( model_name="rec", model_precision=args.precision, @@ -77,7 +78,7 @@ def __init__(self, args): inference_config=self.config, pids=pid, process_name=None, - gpu_ids=0 if args.use_gpu else None, + gpu_ids=gpu_id if args.use_gpu else None, time_keys=[ 'preprocess_time', 'inference_time', 'postprocess_time' ], @@ -87,8 +88,8 @@ def __init__(self, args): def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape assert imgC == img.shape[2] - if self.character_type == "ch": - imgW = int((32 * max_wh_ratio)) + max_wh_ratio = max(max_wh_ratio, imgW / imgH) + imgW = int((32 * max_wh_ratio)) h, w = img.shape[:2] ratio = w / float(h) if math.ceil(imgH * ratio) > imgW: @@ -277,7 +278,7 @@ def main(args): if args.warmup: img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8) for i in range(2): - res = text_recognizer([img]) + res = text_recognizer([img] * int(args.rec_batch_num)) for image_file in image_file_list: img, flag = check_and_read_gif(image_file) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 1c82280099..707328f284 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -159,6 +159,11 @@ def create_predictor(args, mode, logger): precision = inference.PrecisionType.Float32 if args.use_gpu: + gpu_id = get_infer_gpuid() + if gpu_id is None: + raise ValueError( + "Not found GPU in current device. Please check your device or set args.use_gpu as False" + ) config.enable_use_gpu(args.gpu_mem, 0) if args.use_tensorrt: config.enable_tensorrt_engine( @@ -280,6 +285,20 @@ def create_predictor(args, mode, logger): return predictor, input_tensor, output_tensors, config +def get_infer_gpuid(): + cmd = "nvidia-smi" + res = os.popen(cmd).readlines() + if len(res) == 0: + return None + cmd = "env | grep CUDA_VISIBLE_DEVICES" + env_cuda = os.popen(cmd).readlines() + if len(env_cuda) == 0: + return 0 + else: + gpu_id = env_cuda[0].strip().split("=")[1] + return int(gpu_id[0]) + + def draw_e2e_res(dt_boxes, strs, img_path): src_im = cv2.imread(img_path) for box, str in zip(dt_boxes, strs): diff --git a/tools/program.py b/tools/program.py index cb6f8a8b75..2d3114f66d 100755 --- a/tools/program.py +++ b/tools/program.py @@ -186,10 +186,11 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" + use_nrtr = config['Architecture']['algorithm'] == "NRTR" use_sar = config['Architecture']['algorithm'] == 'SAR' - try: + try: model_type = config['Architecture']['model_type'] - except: + except: model_type = None if 'start_epoch' in best_model_dict: @@ -214,7 +215,7 @@ def train(config, images = batch[0] if use_srn: model_average = True - if use_srn or model_type == 'table' or use_sar: + if use_srn or model_type == 'table' or use_nrtr or use_sar: preds = model(images, data=batch[1:]) else: preds = model(images) @@ -401,7 +402,11 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', +<<<<<<< HEAD 'CLS', 'PGNet', 'Distillation', 'TableAttn', 'SAR' +======= + 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn' +>>>>>>> 63ed5fcab30801626ecf55a89f5dc9faf79a16d2 ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'