Skip to content

Commit

Permalink
Fix demo libtorch error (#420)
Browse files Browse the repository at this point in the history
* fix bug

* fix demo_libtorch error

This line of code ("auto outputs = this->Net.forward({input}).toTuple();") caused a crash.

Maybe it's ok in old version.But it's not Tuple now.
  • Loading branch information
jedi007 committed Aug 26, 2022
1 parent b7600c8 commit 6797732
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 40 deletions.
2 changes: 1 addition & 1 deletion demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def main():
)
mkdir(local_rank, save_folder)
save_path = (
os.path.join(save_folder, args.path.split("/")[-1])
os.path.join(save_folder, args.path.replace("\\","/").split("/")[-1])
if args.demo == "video"
else os.path.join(save_folder, "camera.mp4")
)
Expand Down
11 changes: 6 additions & 5 deletions demo_libtorch/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ void draw_bboxes(const cv::Mat& bgr, const std::vector<BoxInfo>& bboxes, object_
cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(255, 255, 255));
}

cv::namedWindow("image", cv::WINDOW_AUTOSIZE);
cv::imshow("image", image);
}

Expand All @@ -238,7 +239,7 @@ int image_demo(NanoDet& detector, const char* imagepath)
}
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(image, resized_img, cv::Size(320, 320), effect_roi);
resize_uniform(image, resized_img, cv::Size(416, 416), effect_roi);
auto results = detector.detect(resized_img, 0.4, 0.5);
draw_bboxes(image, results, effect_roi);
cv::waitKey(0);
Expand All @@ -257,7 +258,7 @@ int webcam_demo(NanoDet& detector, int cam_id)
cap >> image;
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(image, resized_img, cv::Size(320, 320), effect_roi);
resize_uniform(image, resized_img, cv::Size(416, 416), effect_roi);
auto results = detector.detect(resized_img, 0.4, 0.5);
draw_bboxes(image, results, effect_roi);
cv::waitKey(1);
Expand All @@ -275,7 +276,7 @@ int video_demo(NanoDet& detector, const char* path)
cap >> image;
object_rect effect_roi;
cv::Mat resized_img;
resize_uniform(image, resized_img, cv::Size(320, 320), effect_roi);
resize_uniform(image, resized_img, cv::Size(416, 416), effect_roi);
auto results = detector.detect(resized_img, 0.4, 0.5);
draw_bboxes(image, results, effect_roi);
cv::waitKey(1);
Expand All @@ -291,7 +292,7 @@ int benchmark(NanoDet& detector)
double time_min = DBL_MAX;
double time_max = -DBL_MAX;
double time_avg = 0;
cv::Mat image(320, 320, CV_8UC3, cv::Scalar(1, 1, 1));
cv::Mat image(416, 416, CV_8UC3, cv::Scalar(1, 1, 1));

for (int i = 0; i < warm_up + loop_num; i++)
{
Expand Down Expand Up @@ -320,7 +321,7 @@ int main(int argc, char** argv)
return -1;
}
std::cout<<"start init model"<<std::endl;
auto detector = NanoDet("../model/nanodet_m.pt");
auto detector = NanoDet("../model/nanodet.torchscript.pth");
std::cout<<"success"<<std::endl;

int mode = atoi(argv[1]);
Expand Down
65 changes: 34 additions & 31 deletions demo_libtorch/nanodet_libtorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,15 @@ torch::Tensor NanoDet::preprocess(cv::Mat& image)
std::vector<BoxInfo> NanoDet::detect(cv::Mat image, float score_threshold, float nms_threshold)
{
auto input = preprocess(image);
auto outputs = this->Net.forward({input}).toTuple();
auto outputs = this->Net.forward({input}).toTensor();

auto cls_preds = outputs->elements()[0].toTensorVector();
auto box_preds = outputs->elements()[1].toTensorVector();
torch::Tensor cls_preds = outputs.index({ "...",torch::indexing::Slice(0,this->num_class_) });
torch::Tensor box_preds = outputs.index({ "...",torch::indexing::Slice(this->num_class_ , torch::indexing::None) });

std::vector<std::vector<BoxInfo>> results;
results.resize(this->num_class_);

for (int i = 0; i < (int)strides_.size(); i++)
{
this->decode_infer(cls_preds[i], box_preds[i], i, score_threshold, results);
}
this->decode_infer(cls_preds, box_preds, score_threshold, results);

std::vector<BoxInfo> dets;
for (int i = 0; i < (int)results.size(); i++)
Expand All @@ -94,36 +91,42 @@ std::vector<BoxInfo> NanoDet::detect(cv::Mat image, float score_threshold, float
return dets;
}

void NanoDet::decode_infer(torch::Tensor& cls_pred, torch::Tensor& dis_pred, int stage_idx, float threshold, std::vector<std::vector<BoxInfo>>& results)
void NanoDet::decode_infer(torch::Tensor& cls_pred, torch::Tensor& dis_pred, float threshold, std::vector<std::vector<BoxInfo>>& results)
{
int stride = this->strides_[stage_idx];
int feature_h = this->input_size_ / stride;
int feature_w = this->input_size_ / stride;
// cv::Mat debug_heatmap = cv::Mat::zeros(feature_h, feature_w, CV_8UC3);
for (int idx = 0; idx < feature_h * feature_w; idx++)
int total_idx = 0;
for (int stage_idx = 0; stage_idx < (int)strides_.size(); stage_idx++)
{
int row = idx / feature_w;
int col = idx % feature_w;
float score = -0.0f;
int cur_label = 0;
for (int label = 0; label < this->num_class_; label++)
int stride = this->strides_[stage_idx];
int feature_h = ceil(double(this->input_size_) / stride);
int feature_w = ceil(double(this->input_size_) / stride);
// cv::Mat debug_heatmap = cv::Mat::zeros(feature_h, feature_w, CV_8UC3);

for (int idx = total_idx; idx < feature_h * feature_w + total_idx; idx++)
{
float cur_score = cls_pred[0][idx][label].item<float>();
if ( cur_score > score)
int row = (idx - total_idx) / feature_w;
int col = (idx - total_idx) % feature_w;
float score = -0.0f;
int cur_label = 0;
for (int label = 0; label < this->num_class_; label++)
{
score = cur_score;
cur_label = label;
float cur_score = cls_pred[0][idx][label].item<float>();
if (cur_score > score)
{
score = cur_score;
cur_label = label;
}
}
if (score > threshold)
{
//std::cout << "label:" << cur_label << " score:" << score << std::endl;
auto cur_dis = dis_pred[0][idx].contiguous();
const float* bbox_pred = cur_dis.data<float>();
results[cur_label].push_back(this->disPred2Bbox(bbox_pred, cur_label, score, col, row, stride));
// debug_heatmap.at<cv::Vec3b>(row, col)[0] = 255;
// cv::imshow("debug", debug_heatmap);
}
}
if (score > threshold)
{
//std::cout << "label:" << cur_label << " score:" << score << std::endl;
auto cur_dis = dis_pred[0][idx].contiguous();
const float* bbox_pred = cur_dis.data<float>();
results[cur_label].push_back(this->disPred2Bbox(bbox_pred, cur_label, score, col, row, stride));
// debug_heatmap.at<cv::Vec3b>(row, col)[0] = 255;
// cv::imshow("debug", debug_heatmap);
}
total_idx += feature_h * feature_w;
}
// cv::waitKey(0);
}
Expand Down
6 changes: 3 additions & 3 deletions demo_libtorch/nanodet_libtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class NanoDet

private:
torch::Tensor preprocess(cv::Mat& image);
void decode_infer(torch::Tensor& cls_pred, torch::Tensor& dis_pred, int stage_idx, float threshold, std::vector<std::vector<BoxInfo>>& results);
void decode_infer(torch::Tensor& cls_pred, torch::Tensor& dis_pred, float threshold, std::vector<std::vector<BoxInfo>>& results);
BoxInfo disPred2Bbox(const float*& dfl_det, int label, float score, int x, int y, int stride);
static void nms(std::vector<BoxInfo>& result, float nms_threshold);
std::vector<int> strides_{ 8, 16, 32 };
int input_size_ = 320;
std::vector<int> strides_{ 8, 16, 32, 64 };
int input_size_ = 416;
int num_class_ = 80;
int reg_max_ = 7;

Expand Down

0 comments on commit 6797732

Please sign in to comment.