Skip to content

Commit

Permalink
Replace plugin system with nodelets
Browse files Browse the repository at this point in the history
  • Loading branch information
akio committed Jul 2, 2018
1 parent ba48a29 commit f6acb3b
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 250 deletions.
9 changes: 4 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,13 @@ include_directories(include ${catkin_INCLUDE_DIRS} ${menoh_INCLUDE_DIRS})
# add_dependencies(menoh_ros ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})

## Declare a C++ executable
add_library(menoh_ros src/menoh_ros/nodelet.cpp)
add_library(menoh_ros
src/menoh_ros/nodelet.cpp
src/menoh_ros/io_nodelets.cpp
)
target_link_libraries(menoh_ros ${catkin_LIBRARIES} ${menoh_LIBRARIES})
#add_dependencies(menoh_nodelet libmenoh)

add_executable(menoh_node src/menoh_ros/node.cpp)
target_link_libraries(menoh_node ${catkin_LIBRARIES} ${Boost_LIBRARIES} menoh_ros)


## Add cmake target dependencies of the executable
## same as for the library above
# add_dependencies(menoh_ros_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
Expand Down
50 changes: 50 additions & 0 deletions include/menoh_ros/io_nodelets.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef MENOH_ROS_IO_NODELETS_H_
#define MENOH_ROS_IO_NODELETS_H_

#include <mutex>

#include "nodelet/nodelet.h"
#include "ros/ros.h"
#include "std_msgs/Float32MultiArray.h"
#include "sensor_msgs/Image.h"

namespace menoh_ros {


class ImageInputNodelet : public nodelet::Nodelet {
public:
ImageInputNodelet() = default;

~ImageInputNodelet() override = default;

void onInit() override;

void imageCallback(const sensor_msgs::Image::ConstPtr& msg);

private:
ros::Publisher pub_;
ros::Subscriber sub_;
double scale_{};
int32_t input_size_{};
};

class CategoryOutputNodelet : public nodelet::Nodelet {
public:
CategoryOutputNodelet() = default;

~CategoryOutputNodelet() override = default;

void onInit() override;

void resultCallback(const std_msgs::Float32MultiArray::ConstPtr& msg);

private:
ros::Subscriber sub_;
ros::Publisher pub_;
std::string category_names_path_;
};


} // namespace menoh_ros

#endif /*(MENOH_ROS_IO_NODELETS_H_*/
72 changes: 1 addition & 71 deletions include/menoh_ros/nodelet.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,70 +15,6 @@

namespace menoh_ros {

class InputPluginBase {
public:
InputPluginBase() = default;
virtual ~InputPluginBase() = default;

InputPluginBase(const InputPluginBase&) = delete;
InputPluginBase& operator=(const InputPluginBase&) = delete;

virtual void initialize(ros::NodeHandle& nh,
std::vector<int32_t>& dst_dims) = 0;

virtual bool execute(menoh::variable var) = 0;
};

//
class OutputPluginBase {
public:
OutputPluginBase() = default;
virtual ~OutputPluginBase() = default;

OutputPluginBase(const OutputPluginBase&) = delete;
OutputPluginBase& operator=(const OutputPluginBase&) = delete;

virtual void initialize(ros::NodeHandle& nh) = 0;

virtual void execute(menoh::variable var) = 0;
};



class VGG16InputPlugin : public InputPluginBase {
public:
VGG16InputPlugin() = default;

~VGG16InputPlugin() override = default;

void initialize(ros::NodeHandle& nh, std::vector<int32_t>& dst_dims) override;

bool execute(menoh::variable var) override;

void inputCallback(const sensor_msgs::ImageConstPtr& msg);

private:
ros::Subscriber sub_;
sensor_msgs::ImageConstPtr latest_image_;
std::mutex image_mutex_;
double scale_;
int32_t input_size_;
};

class VGG16OutputPlugin : public OutputPluginBase {
public:
VGG16OutputPlugin() = default;

~VGG16OutputPlugin() override = default;

void initialize(ros::NodeHandle& nh) override;

void execute(menoh::variable var) override;

private:
ros::Publisher pub_;
std::string synset_words_path_;
};

class MenohNodelet : public nodelet::Nodelet {
public:
Expand All @@ -87,23 +23,17 @@ class MenohNodelet : public nodelet::Nodelet {
~MenohNodelet() override = default;

void onInit() override;
private:
void timerCallback(const ros::TimerEvent& event);

private:
void inputCallback(const std_msgs::Float32MultiArray::ConstPtr& msg);

std::unique_ptr<menoh::model> model_;

std::string backend_name_;

std::unique_ptr<InputPluginBase> input_plugin_;
std::unique_ptr<OutputPluginBase> output_plugin_;

std::string input_variable_name_;
std::string output_variable_name_;

ros::Timer timer_;

ros::Subscriber input_sub_;
ros::Publisher output_pub_;
};
Expand Down
25 changes: 16 additions & 9 deletions launch/vgg16.launch
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@
<arg if="$(arg debug)" name="launch_prefix" default="xterm -e gdb --args" />
<arg unless="$(arg debug)" name="launch_prefix" default="" />

<node name="menoh_nodelet_manager" type="nodelet" pkg="nodelet" args="manager" output="screen"/>

<node name="menoh" type="nodelet" pkg="nodelet" launch-prefix="$(arg launch_prefix)" output="screen"
args="standalone menoh_ros/MenohNodelet">
args="load menoh_ros/MenohNodelet menoh_nodelet_manager">
<param name="model" value="$(find menoh_ros)/data/VGG16.onnx" />
<param name="input_variable_name" type="string" value="140326425860192" />
<param name="output_variable_name" type="string" value="140326200803680" />
<rosparam param="input_dims">[1, 3, 224, 224]</rosparam>
</node>

<node name="image_input" type="nodelet" pkg="nodelet" output="screen"
args="load menoh_ros/ImageInputNodelet menoh_nodelet_manager">
<param name="input_size" type="int" value="224" />
<param name="sysnet_words_path" type="string" value="$(find menoh_ros)/data/synset_words.txt" />
<rosparam param="input_map">
input:
var: "140326425860192"
size: 224
output:
var: "140326200803680"
</rosparam>
<remap from="/image_input/output" to="/menoh/input" />
</node>

<node name="category_output" type="nodelet" pkg="nodelet" output="screen"
args="load menoh_ros/CategoryOutputNodelet menoh_nodelet_manager">
<param name="category_names_path" type="string" value="$(find menoh_ros)/data/synset_words.txt" />
<remap from="/category_output/input" to="/menoh/output" />
</node>

</launch>
10 changes: 10 additions & 0 deletions nodelet_plugins.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,14 @@
<description>
</description>
</class>
<class name="menoh_ros/ImageInputNodelet" type="menoh_ros::ImageInputNodelet" base_class_type="nodelet::Nodelet">
<description>
</description>
</class>

<class name="menoh_ros/CategoryOutputNodelet" type="menoh_ros::CategoryOutputNodelet" base_class_type="nodelet::Nodelet">
<description>
</description>
</class>
</library>

142 changes: 142 additions & 0 deletions src/menoh_ros/io_nodelets.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#include "menoh_ros/io_nodelets.h"

#include "pluginlib/class_list_macros.h"

#include "opencv2/opencv.hpp"
#include "cv_bridge/cv_bridge.h"
#include "std_msgs/String.h"

namespace menoh_ros {

cv::Mat crop_and_resize(cv::Mat mat, cv::Size const& size) {
auto short_edge = std::min(mat.size().width, mat.size().height);
cv::Rect roi;
roi.x = (mat.size().width - short_edge) / 2;
roi.y = (mat.size().height - short_edge) / 2;
roi.width = roi.height = short_edge;
cv::Mat cropped = mat(roi);
cv::Mat resized;
cv::resize(cropped, resized, size);
return resized;
}

std::vector<float> reorder_to_nchw(cv::Mat const& mat) {
assert(mat.channels() == 3);
std::vector<float> data(mat.channels() * mat.rows * mat.cols);
for(int y = 0; y < mat.rows; ++y) {
for(int x = 0; x < mat.cols; ++x) {
// INFO cv::imread loads image BGR
for(int c = 0; c < mat.channels(); ++c) {
data[c * (mat.rows * mat.cols) + y * mat.cols + x] =
static_cast<float>(
mat.data[y * mat.step + x * mat.elemSize() + c]);
}
}
}
return data;
}

template <typename InIter>
std::vector<typename std::iterator_traits<InIter>::difference_type>
extract_top_k_index_list(
InIter first, InIter last,
typename std::iterator_traits<InIter>::difference_type k) {
using diff_t = typename std::iterator_traits<InIter>::difference_type;
std::priority_queue<
std::pair<typename std::iterator_traits<InIter>::value_type, diff_t>>
q;
for(diff_t i = 0; first != last; ++first, ++i) {
q.push({*first, i});
}
std::vector<diff_t> indices;
for(diff_t i = 0; i < k; ++i) {
indices.push_back(q.top().second);
q.pop();
}
return indices;
}


std::vector<std::string> load_category_list(std::string const& synset_words_path) {
std::ifstream ifs(synset_words_path);
if(!ifs) {
throw std::runtime_error("File open error: " + synset_words_path);
}
std::vector<std::string> categories;
std::string line;
while(std::getline(ifs, line)) {
categories.push_back(std::move(line));
}
return categories;
}

void ImageInputNodelet::onInit() {
auto private_nh = getPrivateNodeHandle();
// "input image width and height size"
private_nh.param<int>("input_size", input_size_, 224);
auto height = input_size_;
auto width = input_size_;

private_nh.param("scale", scale_, 1.0);
const int batch_size = 1;
const int channel_num = 3;

sub_ = private_nh.subscribe("input", 1, &ImageInputNodelet::imageCallback, this);
pub_ = private_nh.advertise<std_msgs::Float32MultiArray>("output", 1);
}

void ImageInputNodelet::imageCallback(const sensor_msgs::Image::ConstPtr& msg) {
auto cv_image = cv_bridge::toCvShare(msg);
auto image_mat = cv_image->image;
auto height = input_size_;
auto width = input_size_;

image_mat = crop_and_resize(std::move(image_mat), cv::Size(width, height));

std_msgs::Float32MultiArray tensor_msg;
tensor_msg.data = reorder_to_nchw(image_mat);
tensor_msg.layout.data_offset = 0;
tensor_msg.layout.dim.resize(4);
tensor_msg.layout.dim[0].label = "batch";
tensor_msg.layout.dim[0].size = 1;
tensor_msg.layout.dim[1].label = "channel";
tensor_msg.layout.dim[1].size = image_mat.channels();
tensor_msg.layout.dim[2].label = "height";
tensor_msg.layout.dim[2].size = image_mat.cols;
tensor_msg.layout.dim[3].label = "width";
tensor_msg.layout.dim[3].size = image_mat.rows;
pub_.publish(tensor_msg);
}


void CategoryOutputNodelet::onInit() {
auto private_nh = getPrivateNodeHandle();
private_nh.param<std::string>("category_names_path", category_names_path_, "not set");
sub_ = private_nh.subscribe("input", 1, &CategoryOutputNodelet::resultCallback, this);
pub_ = private_nh.advertise<std_msgs::String>("output", 1);
}


void CategoryOutputNodelet::resultCallback(const std_msgs::Float32MultiArray::ConstPtr& msg) {
// Get output
auto categories = load_category_list(category_names_path_);
auto top_k = 5;
auto top_k_indices = extract_top_k_index_list(
begin(msg->data),
end(msg->data),
top_k);
ROS_INFO_STREAM("top " << top_k << " categories:");
for(auto ki : top_k_indices) {
ROS_INFO_STREAM(" " << ki << " " << msg->data[ki] << " "
<< categories.at(ki));
}

std_msgs::String result_msg;
result_msg.data = categories.at(top_k_indices[0]);
pub_.publish(result_msg);
}

} // namespace menoh_ros

PLUGINLIB_EXPORT_CLASS(menoh_ros::ImageInputNodelet, nodelet::Nodelet);
PLUGINLIB_EXPORT_CLASS(menoh_ros::CategoryOutputNodelet, nodelet::Nodelet);
Loading

0 comments on commit f6acb3b

Please sign in to comment.