Skip to content

Commit

Permalink
[Refactor] Refactor onnx export and all demo post process. (#360)
Browse files Browse the repository at this point in the history
* [Refactor] Refactor onnx export and all demo post process.

* ncnn demo

* onnx output tmp

* refactor ncnn cpp demo

* refactor cpp demo

* support concat output and onnxsim

* refactor mnn demo

* add multi-backend demo

* add multi-backend demo

* refactor android demo

* fix detach
  • Loading branch information
RangiLyu authored Dec 22, 2021
1 parent 161bda1 commit 2ca268a
Show file tree
Hide file tree
Showing 19 changed files with 397 additions and 1,808 deletions.
1 change: 1 addition & 0 deletions demo/demo_multi_backend_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Working in progress
62 changes: 42 additions & 20 deletions demo_android_ncnn/app/src/main/cpp/NanoDet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,32 @@ int activation_function_softmax(const _Tp* src, _Tp* dst, int length)
return 0;
}

static void generate_grid_center_priors(const int input_height, const int input_width, std::vector<int>& strides, std::vector<CenterPrior>& center_priors)
{
for (int i = 0; i < (int)strides.size(); i++)
{
int stride = strides[i];
int feat_w = ceil((float)input_width / stride);
int feat_h = ceil((float)input_height / stride);
for (int y = 0; y < feat_h; y++)
{
for (int x = 0; x < feat_w; x++)
{
CenterPrior ct;
ct.x = x;
ct.y = y;
ct.stride = stride;
center_priors.push_back(ct);
}
}
}
}

NanoDet::NanoDet(AAssetManager *mgr, const char *param, const char *bin, bool useGPU) {
this->Net = new ncnn::Net();
// opt 需要在加载前设置
hasGPU = ncnn::get_gpu_count() > 0;
this->Net->opt.use_vulkan_compute = false; //hasGPU && useGPU; // gpu
this->Net->opt.use_fp16_arithmetic = true; // fp16运算加速
this->Net->opt.use_fp16_arithmetic = true;
this->Net->opt.use_fp16_packed = true;
this->Net->opt.use_fp16_storage = true;
this->Net->load_param(mgr, param);
Expand Down Expand Up @@ -82,19 +102,19 @@ std::vector<BoxInfo> NanoDet::detect(JNIEnv *env, jobject image, float score_thr
ex.set_num_threads(4);
hasGPU = ncnn::get_gpu_count() > 0;
//ex.set_vulkan_compute(hasGPU);
ex.input("input.1", input);
ex.input("data", input);
std::vector<std::vector<BoxInfo>> results;
results.resize(this->num_class);

for (const auto& head_info : this->heads_info)
{
ncnn::Mat dis_pred;
ncnn::Mat cls_pred;
ex.extract(head_info.dis_layer.c_str(), dis_pred);
ex.extract(head_info.cls_layer.c_str(), cls_pred);
ncnn::Mat out;
ex.extract("output", out);
// printf("%d %d %d \n", out.w, out.h, out.c);

this->decode_infer(cls_pred, dis_pred, head_info.stride, score_threshold, results, width_ratio, height_ratio);
}
// generate center priors in format of (x, y, stride)
std::vector<CenterPrior> center_priors;
generate_grid_center_priors(this->input_size[0], this->input_size[1], this->strides, center_priors);

this->decode_infer(out, center_priors, score_threshold, results, width_ratio, height_ratio);

std::vector<BoxInfo> dets;
for (int i = 0; i < (int)results.size(); i++)
Expand All @@ -110,17 +130,19 @@ std::vector<BoxInfo> NanoDet::detect(JNIEnv *env, jobject image, float score_thr
}


void NanoDet::decode_infer(ncnn::Mat& cls_pred, ncnn::Mat& dis_pred, int stride, float threshold, std::vector<std::vector<BoxInfo>>& results, float width_ratio, float height_ratio)
void NanoDet::decode_infer(ncnn::Mat& feats, std::vector<CenterPrior>& center_priors, float threshold, std::vector<std::vector<BoxInfo>>& results, float width_ratio, float height_ratio)
{
int feature_h = this->input_size / stride;
int feature_w = this->input_size / stride;
const int num_points = center_priors.size();
//printf("num_points:%d\n", num_points);

//cv::Mat debug_heatmap = cv::Mat(feature_h, feature_w, CV_8UC3);
for (int idx = 0; idx < feature_h * feature_w; idx++)
for (int idx = 0; idx < num_points; idx++)
{
const float* scores = cls_pred.row(idx);
int row = idx / feature_w;
int col = idx % feature_w;
const int ct_x = center_priors[idx].x;
const int ct_y = center_priors[idx].y;
const int stride = center_priors[idx].stride;

const float* scores = feats.row(idx);
float score = 0;
int cur_label = 0;
for (int label = 0; label < this->num_class; label++)
Expand All @@ -134,8 +156,8 @@ void NanoDet::decode_infer(ncnn::Mat& cls_pred, ncnn::Mat& dis_pred, int stride,
if (score > threshold)
{
//std::cout << "label:" << cur_label << " score:" << score << std::endl;
const float* bbox_pred = dis_pred.row(idx);
results[cur_label].push_back(this->disPred2Bbox(bbox_pred, cur_label, score, col, row, stride, width_ratio, height_ratio));
const float* bbox_pred = feats.row(idx) + this->num_class;
results[cur_label].push_back(this->disPred2Bbox(bbox_pred, cur_label, score, ct_x, ct_y, stride, width_ratio, height_ratio));
//debug_heatmap.at<cv::Vec3b>(row, col)[0] = 255;
//cv::imshow("debug", debug_heatmap);
}
Expand Down
26 changes: 15 additions & 11 deletions demo_android_ncnn/app/src/main/cpp/NanoDet.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,20 @@
#include "net.h"
#include "YoloV5.h"

typedef struct HeadInfo
typedef struct HeadInfo_
{
std::string cls_layer;
std::string dis_layer;
int stride;
} HeadInfo;

typedef struct CenterPrior_
{
int x;
int y;
int stride;
} CenterPrior;


class NanoDet{
public:
Expand All @@ -35,21 +42,18 @@ class NanoDet{
"hair drier", "toothbrush"};
private:
void preprocess(JNIEnv *env, jobject image, ncnn::Mat& in);
void decode_infer(ncnn::Mat& cls_pred, ncnn::Mat& dis_pred, int stride, float threshold, std::vector<std::vector<BoxInfo>>& results, float width_ratio, float height_ratio);
void decode_infer(ncnn::Mat& feats, std::vector<CenterPrior>& center_priors, float threshold, std::vector<std::vector<BoxInfo>>& results, float width_ratio, float height_ratio);
BoxInfo disPred2Bbox(const float*& dfl_det, int label, float score, int x, int y, int stride, float width_ratio, float height_ratio);

static void nms(std::vector<BoxInfo>& result, float nms_threshold);

ncnn::Net *Net;
int input_size = 320;
int num_class = 80;
int reg_max = 7;
std::vector<HeadInfo> heads_info{
// cls_pred|dis_pred|stride
{"cls_pred_stride_8", "dis_pred_stride_8", 8},
{"cls_pred_stride_16", "dis_pred_stride_16", 16},
{"cls_pred_stride_32", "dis_pred_stride_32", 32},
};
// modify these parameters to the same with your config if you want to use your own model
int input_size[2] = {416, 416}; // input height and width
int num_class = 80; // number of classes. 80 for COCO
int reg_max = 7; // `reg_max` set in the training config. Default: 7.
std::vector<int> strides = { 8, 16, 32, 64 }; // strides of the multi-level feature.


public:
static NanoDet *detector;
Expand Down
20 changes: 15 additions & 5 deletions demo_mnn/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ int image_demo(NanoDet &detector, const char* imagepath)
std::vector<cv::String> filenames;
cv::glob(imagepath, filenames, false);

int height = detector.input_size[0];
int width = detector.input_size[1];

for (auto img_name : filenames)
{
cv::Mat image = cv::imread(img_name);
Expand All @@ -247,7 +250,7 @@ int image_demo(NanoDet &detector, const char* imagepath)
}
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(image, resized_img, cv::Size(320, 320), effect_roi);
resize_uniform(image, resized_img, cv::Size(width, height), effect_roi);
std::vector<BoxInfo> results;
detector.detect(resized_img, results);

Expand All @@ -267,13 +270,15 @@ int webcam_demo(NanoDet& detector, int cam_id)
{
cv::Mat image;
cv::VideoCapture cap(cam_id);
int height = detector.input_size[0];
int width = detector.input_size[1];

while (true)
{
cap >> image;
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(image, resized_img, cv::Size(320, 320), effect_roi);
resize_uniform(image, resized_img, cv::Size(width, height), effect_roi);
std::vector<BoxInfo> results;
detector.detect(resized_img, results);
draw_bboxes(image, results, effect_roi);
Expand All @@ -286,13 +291,15 @@ int video_demo(NanoDet& detector, const char* path)
{
cv::Mat image;
cv::VideoCapture cap(path);
int height = detector.input_size[0];
int width = detector.input_size[1];

while (true)
{
cap >> image;
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(image, resized_img, cv::Size(320, 320), effect_roi);
resize_uniform(image, resized_img, cv::Size(width, height), effect_roi);
std::vector<BoxInfo> results;
detector.detect(resized_img, results);
draw_bboxes(image, results, effect_roi);
Expand All @@ -309,7 +316,10 @@ int benchmark(NanoDet& detector)
double time_min = DBL_MAX;
double time_max = -DBL_MAX;
double time_avg = 0;
cv::Mat image(320, 320, CV_8UC3, cv::Scalar(1, 1, 1));

int height = detector.input_size[0];
int width = detector.input_size[1];
cv::Mat image(height, width, CV_8UC3, cv::Scalar(1, 1, 1));
for (int i = 0; i < warm_up + loop_num; i++)
{
auto start = std::chrono::steady_clock::now();
Expand Down Expand Up @@ -340,7 +350,7 @@ int main(int argc, char** argv)
return -1;
}
// NanoDet detector = NanoDet("../model/nanodet-160.mnn", 160, 160, 4, 0.4, 0.3);
NanoDet detector = NanoDet("../model/nanodet-320.mnn", 320, 320, 4, 0.45, 0.3);
NanoDet detector = NanoDet("nanodet.mnn", 416, 416, 4, 0.45, 0.3);
int mode = atoi(argv[1]);
switch (mode)
{
Expand Down
93 changes: 55 additions & 38 deletions demo_mnn/nanodet_mnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,34 @@

using namespace std;

static void generate_grid_center_priors(const int input_height, const int input_width, std::vector<int>& strides, std::vector<CenterPrior>& center_priors)
{
for (int i = 0; i < (int)strides.size(); i++)
{
int stride = strides[i];
int feat_w = ceil((float)input_width / stride);
int feat_h = ceil((float)input_height / stride);
for (int y = 0; y < feat_h; y++)
{
for (int x = 0; x < feat_w; x++)
{
CenterPrior ct;
ct.x = x;
ct.y = y;
ct.stride = stride;
center_priors.push_back(ct);
}
}
}
}


NanoDet::NanoDet(const std::string &mnn_path,
int input_width, int input_length, int num_thread_,
float score_threshold_, float nms_threshold_)
{
num_thread = num_thread_;
in_w = input_width;
in_h = input_length;

score_threshold = score_threshold_;
nms_threshold = nms_threshold_;

Expand Down Expand Up @@ -41,14 +62,14 @@ int NanoDet::detect(cv::Mat &raw_image, std::vector<BoxInfo> &result_list)
image_h = raw_image.rows;
image_w = raw_image.cols;
cv::Mat image;
cv::resize(raw_image, image, cv::Size(in_w, in_h));
cv::resize(raw_image, image, cv::Size(input_size[1], input_size[0]));

NanoDet_interpreter->resizeTensor(input_tensor, {1, 3, in_h, in_w});
NanoDet_interpreter->resizeTensor(input_tensor, {1, 3, input_size[0], input_size[1]});
NanoDet_interpreter->resizeSession(NanoDet_session);
std::shared_ptr<MNN::CV::ImageProcess> pretreat(
MNN::CV::ImageProcess::create(MNN::CV::BGR, MNN::CV::BGR, mean_vals, 3,
norm_vals, 3));
pretreat->convert(image.data, in_w, in_h, image.step[0], input_tensor);
pretreat->convert(image.data, input_size[1], input_size[0], image.step[0], input_tensor);

auto start = chrono::steady_clock::now();

Expand All @@ -60,19 +81,15 @@ int NanoDet::detect(cv::Mat &raw_image, std::vector<BoxInfo> &result_list)
std::vector<std::vector<BoxInfo>> results;
results.resize(num_class);

for (const auto &head_info : heads_info)
{
MNN::Tensor *tensor_scores = NanoDet_interpreter->getSessionOutput(NanoDet_session, head_info.cls_layer.c_str());
MNN::Tensor *tensor_boxes = NanoDet_interpreter->getSessionOutput(NanoDet_session, head_info.dis_layer.c_str());

MNN::Tensor tensor_scores_host(tensor_scores, tensor_scores->getDimensionType());
tensor_scores->copyToHostTensor(&tensor_scores_host);
MNN::Tensor *tensor_preds = NanoDet_interpreter->getSessionOutput(NanoDet_session, output_name.c_str());

MNN::Tensor tensor_boxes_host(tensor_boxes, tensor_boxes->getDimensionType());
tensor_boxes->copyToHostTensor(&tensor_boxes_host);
MNN::Tensor tensor_preds_host(tensor_preds, tensor_preds->getDimensionType());
tensor_preds->copyToHostTensor(&tensor_preds_host);
// generate center priors in format of (x, y, stride)
std::vector<CenterPrior> center_priors;
generate_grid_center_priors(this->input_size[0], this->input_size[1], this->strides, center_priors);

decode_infer(&tensor_scores_host, &tensor_boxes_host, head_info.stride, score_threshold, results);
}
decode_infer(&tensor_preds_host, center_priors, score_threshold, results);

auto end = chrono::steady_clock::now();
chrono::duration<double> elapsed = end - start;
Expand All @@ -85,10 +102,10 @@ int NanoDet::detect(cv::Mat &raw_image, std::vector<BoxInfo> &result_list)

for (auto box : results[i])
{
box.x1 = box.x1 / in_w * image_w;
box.x2 = box.x2 / in_w * image_w;
box.y1 = box.y1 / in_h * image_h;
box.y2 = box.y2 / in_h * image_h;
box.x1 = box.x1 / input_size[1] * image_w;
box.x2 = box.x2 / input_size[1] * image_w;
box.y1 = box.y1 / input_size[0] * image_h;
box.y2 = box.y2 / input_size[0] * image_h;
result_list.push_back(box);
}
}
Expand All @@ -97,19 +114,22 @@ int NanoDet::detect(cv::Mat &raw_image, std::vector<BoxInfo> &result_list)
return 0;
}

void NanoDet::decode_infer(MNN::Tensor *cls_pred, MNN::Tensor *dis_pred, int stride, float threshold, std::vector<std::vector<BoxInfo>> &results)
void NanoDet::decode_infer(MNN::Tensor *pred, std::vector<CenterPrior>& center_priors, float threshold, std::vector<std::vector<BoxInfo>> &results)
{
int feature_h = in_h / stride;
int feature_w = in_w / stride;
const int num_points = center_priors.size();
const int num_channels = num_class + (reg_max + 1) * 4;
//printf("num_points:%d\n", num_points);

//cv::Mat debug_heatmap = cv::Mat(feature_h, feature_w, CV_8UC3);
for (int idx = 0; idx < feature_h * feature_w; idx++)
for (int idx = 0; idx < num_points; idx++)
{
// scores is a tensor with shape [feature_h * feature_w, num_class]
const float *scores = cls_pred->host<float>() + (idx * num_class);
const int ct_x = center_priors[idx].x;
const int ct_y = center_priors[idx].y;
const int stride = center_priors[idx].stride;

// preds is a tensor with shape [num_points, num_channels]
const float *scores = pred->host<float>() + (idx * num_channels);

int row = idx / feature_w;
int col = idx % feature_w;
float score = 0;
int cur_label = 0;
for (int label = 0; label < num_class; label++)
Expand All @@ -122,20 +142,17 @@ void NanoDet::decode_infer(MNN::Tensor *cls_pred, MNN::Tensor *dis_pred, int str
}
if (score > threshold)
{
//std::cout << "label:" << cur_label << " score:" << score << std::endl;
// bbox is a tensor with shape [feature_h * feature_w, 4_points * 8_distribution_bite]
const float *bbox_pred = dis_pred->host<float>() + (idx * 4 * (reg_max + 1));
results[cur_label].push_back(disPred2Bbox(bbox_pred, cur_label, score, col, row, stride));
//debug_heatmap.at<cv::Vec3b>(row, col)[0] = 255;
//cv::imshow("debug", debug_heatmap);
const float *bbox_pred = pred->host<float>() + idx * num_channels + num_class;
results[cur_label].push_back(disPred2Bbox(bbox_pred, cur_label, score, ct_x, ct_y, stride));

}
}
}

BoxInfo NanoDet::disPred2Bbox(const float *&dfl_det, int label, float score, int x, int y, int stride)
{
float ct_x = (x + 0.5) * stride;
float ct_y = (y + 0.5) * stride;
float ct_x = x * stride;
float ct_y = y * stride;
std::vector<float> dis_pred;
dis_pred.resize(4);
for (int i = 0; i < 4; i++)
Expand All @@ -154,8 +171,8 @@ BoxInfo NanoDet::disPred2Bbox(const float *&dfl_det, int label, float score, int
}
float xmin = (std::max)(ct_x - dis_pred[0], .0f);
float ymin = (std::max)(ct_y - dis_pred[1], .0f);
float xmax = (std::min)(ct_x + dis_pred[2], (float)in_w);
float ymax = (std::min)(ct_y + dis_pred[3], (float)in_h);
float xmax = (std::min)(ct_x + dis_pred[2], (float)input_size[1]);
float ymax = (std::min)(ct_y + dis_pred[3], (float)input_size[0]);

//std::cout << xmin << "," << ymin << "," << xmax << "," << xmax << "," << std::endl;
return BoxInfo{xmin, ymin, xmax, ymax, score, label};
Expand Down
Loading

0 comments on commit 2ca268a

Please sign in to comment.