diff --git a/README.md b/README.md
index 0f0b46a67..d32dc27e0 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 
 Tensor library for machine learning
 
-***Note that this project is under active development. \
+***Note that this project is under active development.
 Some of the development is currently happening in the [llama.cpp](https://github.com/ggerganov/llama.cpp) and [whisper.cpp](https://github.com/ggerganov/whisper.cpp) repos***
 
 ## Features
@@ -39,6 +39,7 @@ Some of the development is currently happening in the [llama.cpp](https://github
 - [X] Example of ChatGLM inference [li-plus/chatglm.cpp](https://github.com/li-plus/chatglm.cpp)
 - [X] Example of Stable Diffusion inference [leejet/stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp)
 - [X] Example of Qwen inference [QwenLM/qwen.cpp](https://github.com/QwenLM/qwen.cpp)
+- [X] Example of SuperPoint inference [examples/superpoint](https://github.com/ggerganov/ggml/tree/master/examples/superpoint)
 - [X] Example of YOLO inference [examples/yolo](https://github.com/ggerganov/ggml/tree/master/examples/yolo)
 - [X] Example of ViT inference [staghado/vit.cpp](https://github.com/staghado/vit.cpp)
 - [X] Example of multiple LLMs inference [foldl/chatllm.cpp](https://github.com/foldl/chatllm.cpp)
@@ -51,8 +52,8 @@ With ggml you can efficiently run [Whisper](examples/whisper) inference on the C
 Memory requirements:
 
 | Model  | Disk   | Mem     |
-| --- | --- | --- |
-| tiny | 75 MB | ~280 MB |
+| ------ | ------ | ------- |
+| tiny   | 75 MB  | ~280 MB |
 | base   | 142 MB | ~430 MB |
 | small  | 466 MB | ~1.0 GB |
 | medium | 1.5 GB | ~2.6 GB |
@@ -92,13 +93,13 @@ python3 ../examples/gpt-2/convert-cerebras-to-ggml.py /path/to/Cerebras-GPT-111M
 The inference speeds that I get for the different models on my 32GB MacBook M1 Pro are as follows:
 
 | Model | Size  | Time / Token |
-| --- | --- | --- |
-| GPT-2 | 117M | 5 ms |
-| GPT-2 | 345M | 12 ms |
-| GPT-2 | 774M | 23 ms |
-| GPT-2 | 1558M | 42 ms |
-| --- | --- | --- |
-| GPT-J | 6B | 125 ms |
+| ----- | ----- | ------------ |
+| GPT-2 | 117M  | 5 ms         |
+| GPT-2 | 345M  | 12 ms        |
+| GPT-2 | 774M  | 23 ms        |
+| GPT-2 | 1558M | 42 ms        |
+| --- | --- | --- |
+| GPT-J | 6B    | 125 ms       |
 
 For more information, checkout the corresponding programs in the [examples](examples) folder.
 
@@ -127,6 +128,7 @@ cmake -DGGML_CUBLAS=ON -DCMAKE_CUDA_COMPILER=/usr/local/cuda-12.1/bin/nvcc ..
 ```bash
 cmake -DGGML_CLBLAST=ON ..
 ```
+
 ## Compiling for Android
 
 Download and unzip the NDK from this download [page](https://developer.android.com/ndk/downloads). Set the NDK_ROOT_PATH environment variable or provide the absolute path to the CMAKE_ANDROID_NDK in the command below.
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 666821611..850bea8ef 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -24,5 +24,6 @@ add_subdirectory(whisper)
 add_subdirectory(mnist)
 add_subdirectory(sam)
 add_subdirectory(yolo)
+add_subdirectory(superpoint)
 add_subdirectory(simple)
 add_subdirectory(magika)
diff --git a/examples/superpoint/CMakeLists.txt b/examples/superpoint/CMakeLists.txt
new file mode 100644
index 000000000..c6944eeaf
--- /dev/null
+++ b/examples/superpoint/CMakeLists.txt
@@ -0,0 +1,6 @@
+#
+# superpoint
+
+set(TEST_TARGET superpoint)
+add_executable(${TEST_TARGET} run_superpoint.cpp superpoint-image.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
diff --git a/examples/superpoint/README.md b/examples/superpoint/README.md
new file mode 100644
index 000000000..8d3fbfbaf
--- /dev/null
+++ b/examples/superpoint/README.md
@@ -0,0 +1,63 @@
+# Superpoint.cpp
+
+This example shows how to run SuperPoint keypoint extraction and description with ggml, using the pretrained model weights.
+
+The main benefit is that no third-party dependencies are needed.
+
+If you want to integrate SuperPoint into your project without pulling in third-party libraries such as libtorch or TensorRT, superpoint.cpp is an option.
+
+TODO:
+
+* the input image currently has to be preprocessed to a size of 480x640 (height x width)
+* image loading is complex and messy
+* acceleration ...
+
+Download the model weights (optional):
+
+```bash
+$ wget https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/superpoint_v1.pth
+```
+
+Compile the project and generate the executable:
+
+```bash
+$ mkdir build
+$ cd build
+$ cmake ..
+$ make
+$ mv bin/superpoint ../examples/superpoint
+```
+
+Convert the weights to GGUF format (optional): since superpoint.gguf is already provided in this folder, this step can be skipped.
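+
+If you do run the conversion yourself, note that `convert_pth_ggml.py` does `from demo_superpoint import *` (the demo script shipped with the SuperPointPretrainedNetwork repo, which provides `SuperPointFrontend`) and `from benchmark import *`, so those scripts and `superpoint_v1.pth` need to sit next to the converter. A rough sketch of the setup (the source paths below are placeholders):
+
+```bash
+$ cp /path/to/SuperPointPretrainedNetwork/demo_superpoint.py examples/superpoint/
+$ cp /path/to/superpoint_v1.pth examples/superpoint/
+```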
+
+```bash
+$ cd ../examples/superpoint
+$ ./convert_pth_ggml.py
+```
+
+Run inference:
+
+```bash
+$ ./superpoint -i dog_color.jpg
+```
+
+# Result
+
+Feature extraction:
+
+![keypoints](result.jpg)
+
+Matching performance:
+
+![matches](matches.png)
+
+# Reference
+
+https://github.com/ggerganov/ggml
+
+https://github.com/magicleap/SuperPointPretrainedNetwork
+
+https://github.com/adityamwagh/SuperSLAM
diff --git a/examples/superpoint/convert_pth_ggml.py b/examples/superpoint/convert_pth_ggml.py
new file mode 100644
index 000000000..19fa57576
--- /dev/null
+++ b/examples/superpoint/convert_pth_ggml.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+'''
+Convert SuperPoint .pth parameters to GGUF.
+'''
+
+import sys
+import gguf
+import numpy as np
+import torch
+from demo_superpoint import *
+from benchmark import *
+
+
+def list_params(net):
+    params_list = []
+    list_vars = net.state_dict()
+    for name in list_vars.keys():
+        data = list_vars[name].numpy()
+        print("Processing variable: " + name + " with shape: ", data.shape)
+        params_list.append((name, data))
+    return params_list
+
+
+def onnx_export(net):
+    # Create a dummy input tensor
+    dummy_input = torch.randn(1, 1, 480, 640, requires_grad=True)
+
+    # Export the model
+    torch.onnx.export(net,                       # model being run
+        dummy_input,                             # model input (or a tuple for multiple inputs)
+        "superpoint.onnx",                       # where to save the model
+        export_params=True,                      # store the trained parameter weights inside the model file
+        opset_version=10,                        # the ONNX version to export the model to
+        do_constant_folding=True,                # whether to execute constant folding for optimization
+        input_names=['modelInput'],              # the model's input names
+        output_names=['output1', 'output2'],     # the model's output names
+        dynamic_axes={'modelInput': {0: 'batch_size'},   # variable length axes
+                      'modelOutput': {0: 'batch_size'}})
+    return
+
+def isBias(name):
+    return name.strip().split(".")[1] == "bias"
+
+def isConv(name):
+    return name.strip().split(".")[1] == "weight"
+
+def save_conv2d_layer(name, conv, index, gguf_writer):
+    layername = "l" + str(index) + "_weights"
+    print(f"save {layername} with shape {conv.shape}")
+    # ggml doesn't support f32 convolution yet, so store the weights as f16
+    flat_conv = conv.astype(np.float16).flatten()
+    gguf_writer.add_tensor(layername, flat_conv, raw_shape=conv.shape)
+    return
+
+def save_bias_layer(name, biases, index, gguf_writer):
+    filters = biases.shape[0]
+    layername = "l" + str(index) + "_biases"
+    print(f"save {layername} with shape {biases.shape}")
+    gguf_writer.add_tensor(layername, biases, raw_shape=(1, filters, 1, 1))
+    return
+
+
+if __name__ == '__main__':
+    # hyperparameters provided by SuperPoint
+    weights_path = 'superpoint_v1.pth'
+    nms_dist = 4
+    conf_thresh = 0.015
+    nn_thresh = 0.7
+
+    outfile = "superpoint.gguf"
+    gguf_writer = gguf.GGUFWriter(outfile, 'superpoint')
+
+    fe = SuperPointFrontend(weights_path=weights_path,
+                            nms_dist=nms_dist,
+                            conf_thresh=conf_thresh,
+                            nn_thresh=nn_thresh,
+                            cuda=False)
+    conv_list = list_params(fe.net)
+    conv_idx = 0
+    bias_idx = 0
+    for name, layer in conv_list:
+        print(f"processing {name}")
+        if isConv(name):
+            save_conv2d_layer(name, layer, conv_idx, gguf_writer)
+            conv_idx += 1
+        elif isBias(name):
+            save_bias_layer(name, layer, bias_idx, gguf_writer)
+            bias_idx += 1
+
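+    # The GGUF writer only buffers tensors in add_tensor(); the calls below write
+    # the header, the key/value metadata and the tensor data to superpoint.gguf,
+    # in that order, before closing the file.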
gguf_writer.write_header_to_file() + gguf_writer.write_kv_data_to_file() + gguf_writer.write_tensors_to_file() + gguf_writer.close() + print("In total {} conv layers, {} bias ".format(conv_idx, bias_idx)) + print("{} converted to {}".format(weights_path, outfile)) diff --git a/examples/superpoint/dog_color.jpg b/examples/superpoint/dog_color.jpg new file mode 100644 index 000000000..c7f27d998 Binary files /dev/null and b/examples/superpoint/dog_color.jpg differ diff --git a/examples/superpoint/matches.png b/examples/superpoint/matches.png new file mode 100644 index 000000000..fdb76e68d Binary files /dev/null and b/examples/superpoint/matches.png differ diff --git a/examples/superpoint/result.jpg b/examples/superpoint/result.jpg new file mode 100644 index 000000000..6f548ce28 Binary files /dev/null and b/examples/superpoint/result.jpg differ diff --git a/examples/superpoint/run_superpoint.cpp b/examples/superpoint/run_superpoint.cpp new file mode 100644 index 000000000..db0180ae8 --- /dev/null +++ b/examples/superpoint/run_superpoint.cpp @@ -0,0 +1,664 @@ +#include "ggml/ggml.h" +#include "superpoint-image.h" +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + + + +int clip(int val, int max) +{ + if (val < 0) + return 0; + return std::min(val, max - 1); +} + + +static std::vector softmax(const std::vector & logits) { + std::vector probs(logits.size()); + float max_logit = logits[0]; + for (float v : logits) max_logit = std::max(max_logit, v); + double sum_exp = 0.0; + for (size_t i = 0; i < logits.size(); i++) { + // Subtract the maximum logit value from the current logit value for numerical stability + const float logit = logits[i] - max_logit; + const float exp_logit = expf(logit); + sum_exp += exp_logit; + probs[i] = exp_logit; + } + for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp; + return probs; +} + +static struct ggml_tensor* brute_permute(ggml_context * ctx, struct ggml_tensor *input, int d0, int d1, int d2, int d3) +{ + //assert it is contigous + //STEP 1: get each stride at src tensor + int dims[4]; + int strides[4]; + dims[d0] = input->ne[0]; + dims[d1] = input->ne[1]; + dims[d2] = input->ne[2]; + dims[d3] = input->ne[3]; + + //STEP 2: based on the permute result, recalcute the stride + //get element_size + strides[d0] = input->nb[0]/sizeof(float); + strides[d1] = input->nb[1]/sizeof(float); + strides[d2] = input->nb[2]/sizeof(float); + strides[d3] = input->nb[3]/sizeof(float); + //create a new tensor with the wanted shape + auto new_tensor = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, + dims[0], dims[1], dims[2], dims[3]); + + float* data_src = ggml_get_data_f32(input); + float* data_dst = ggml_get_data_f32(new_tensor); + //do the permutation + //use for loop, to reallocate the data + int cnt = 0; + for (int h =0; h ne[0]; + int h = input->ne[1]; + int c = input->ne[2]; + int num = w * h *c; + + + printf("before softmax Shape: %3d x %3d x %4d x %3d\n", w, h, c, (int)input->ne[3]); + //well, it seems like channel is at the 3rd channel, to iterate over it, shift it to the first dim +#define NMS_V1 +#ifdef NMS_V1 + auto _tensor = brute_permute(ctx, input,2,1,0,3); +#else + auto _tensor = brute_permute(ctx, input,1,2,0,3); + +#endif + int ne0 = _tensor->ne[0]; + int ne1 = _tensor->ne[1]; + int ne2 = _tensor->ne[2]; + int ne3 = _tensor->ne[3]; + int tensor_size = ne0 * ne1 * ne2 * ne3; + + struct 
ggml_tensor* nodust = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, _tensor->ne[1], _tensor->ne[2], (int)_tensor->ne[3]); + int channel = 65;//c + + for(int i =0; i < num/channel; i+=1) + { + int in_offset = i * channel; + int out_offset = i * (channel -1); + float* in_arr = ggml_get_data_f32(_tensor)+ in_offset; + float* out_arr = ggml_get_data_f32(nodust)+ out_offset; + std::vector in_vec(in_arr, in_arr + channel); + std::vector vec = softmax(in_vec); + std::copy(vec.begin(), vec.end() - 1, out_arr); + } + //TODO: the operation above seems not usual? + + printf("Final Shape: %3lld x %3lld x %4lld x %3d\n", nodust->ne[0], nodust->ne[1], nodust->ne[2], (int)nodust->ne[3]); + return nodust; +} + +/* +A interesting bug, you have to flatten the tensor(view it as 1d) to make permutation work... +*/ +static struct ggml_tensor* heatmap_semi(ggml_context * ctx, struct ggml_tensor *scores) +{ + int cell =8; + int Hc = scores->ne[2]; + int Wc = scores->ne[1]; + + scores = ggml_reshape_4d(ctx, scores, 8, 8, Wc, Hc);//h,w,8,8 + + auto heatmap = brute_permute(ctx, scores, 0,2,1,3); + + print_shape(100, heatmap); + // write_array("heatmap.txt", heatmap); + + + return heatmap; + +} + +void screen_keypoints(struct ggml_tensor* tensor, + int H, + int W, + int cell, + std::vector& pts /* OUTPUT */) +{ + float * data = ggml_get_data_f32(tensor); + + const int HO = H/cell; + const int WO = W/cell; + const int border_remove = 10; + const float conf_thresh = 0.015; + int output_semi_dim0 = cell * cell; + int output_semi_dim1 = HO; + int output_semi_dim2 = WO; + + + const int left_margin = border_remove; + const int right_margin = W - border_remove; + const int top_margin = border_remove; + const int bottom_margin = H - border_remove; + + auto get_coord = [H, W](int a, int b, int c, int &row, int &col) -> bool { + row = (b << 3) + (a >> 3); + col = (c << 3) + (a bitand 0x07); + return ((row >= 0) && (row < H) && (col >= 0) && (col < W)); + }; + auto get_offset = [HO, WO](int a, int b, int c) -> int { + return ( (a * HO * WO) + (b * WO) + c ); + }; + auto action_on_each = [&](int k, int i, int j) { + float prob = data[get_offset(k, i, j)]; + if (prob > conf_thresh) + { + int row = -1, col = -1; + get_coord(k, i, j, row, col); + + if ( row < top_margin || row >= bottom_margin || + col < left_margin || col >= right_margin ) + { + ; /* drop point along border */ + } + else + { + /* qualified points */ + pts.emplace_back( col /* x */, row /* y */, prob ); + } + } + }; + + pts.clear(); + pts.reserve(4000); + + /* measurement util */ + // Timer t; + + for (int k = 0; k < output_semi_dim0; k++ ) + { + for (int i = 0; i < output_semi_dim1; i++) + { + for (int j = 0; j < output_semi_dim2; j++) + { + action_on_each(k, i, j); + } + } + } + +} + + +void nms_fast(std::vector& corners, int H, int W, int keypoints_num, int nms_dist) +{ + if ( corners.empty() ) return; + + std::sort( corners.begin(), corners.end(), + [](const PointT& a, const PointT& b) -> bool { + return (a.conf > b.conf); } + ); + + const int pad = nms_dist; + const int grid_H = H + 2 * pad; + const int grid_W = W + 2 * pad; + std::vector> grid(grid_H, std::vector(grid_W, 0)); + + + for (auto & pt : corners) + { + const int row = pad + pt.y; + const int col = pad + pt.x; + grid[row][col] = 1; + } + + /* reserve survivors */ + std::vector survivors; + survivors.reserve(1000); + + for (auto & pt : corners) + { + const int row = pad + pt.y; + const int col = pad + pt.x; + + if (grid[row][col] == 1) + { + for(int c = col - pad; c< std::min(col + pad, grid_W); c++) + 
for (int r = row - pad; r< std::min(row + pad, grid_H); r++) + { + grid[r][c] = 0; + } + /* keep the corner */ + grid[row][col] == -1; + survivors.push_back(pt); + } + } + + corners.clear(); + corners = std::move(survivors); + + while (corners.size() > (size_t)keypoints_num) + { + corners.pop_back(); + } +} + + +static void postprocess_semi(ggml_context * ctx, superpoint_image & img, struct ggml_tensor *input, std::vector& pts) +{ + auto scores = softmax_semi(ctx, input); + printf("screen keypoints\n"); + //TODO: improve efficiency of permutation + scores = brute_permute(ctx, scores, 2,1,0,3); + // brute_permute() + screen_keypoints( scores, img.h, img.w, 8, pts); + nms_fast(pts, img.h, img.w, 4000, 4); + // printf("points size %d\n", pts.size()); + // write_points("points.txt", pts); + +} + +void normalize_keypoints(const std::vector &keypoints, std::vector &keypoints_norm, + int h, int w) +{ + for (auto &kp : keypoints) + { + PointT keypoint; + keypoint.conf = kp.conf; + + keypoint.x = (float)kp.x / (0.5 *w) - 1; + keypoint.y = (float)kp.y / (0.5 *h) - 1; + keypoints_norm.push_back(keypoint); + // std::cout<<"keypoint: "< &grid, + std::vector> &output, int dim, int h, int w) +{ + for (auto &g : grid) + { + double ix = ((g.x + 1) / 2) * (w - 1); + double iy = ((g.y + 1) / 2) * (h - 1); + // std::cout<<"ix: "< descriptor; + for (int i = 0; i < dim; ++i) + { + // 256x60x106 dhw + //TODO: check this index + // x * height * depth + y * depth + z + float nw_val = input[i * h * w + iy_nw * w + ix_nw]; + float ne_val = input[i * h * w + iy_ne * w + ix_ne]; + float sw_val = input[i * h * w + iy_sw * w + ix_sw]; + float se_val = input[i * h * w + iy_se * w + ix_se]; + descriptor.push_back(nw_val * nw + ne_val * ne + sw_val * sw + se_val * se); + } + output.push_back(descriptor); + } + // exit(0); +} + + + + +template +double vector_normalize(Iter_T first, Iter_T last) +{ + return sqrt(std::inner_product(first, last, first, 0.0)); +} + +void normalize_descriptors(std::vector> &dest_descriptors) +{ + for (auto &descriptor : dest_descriptors) + { + double norm_inv = 1.0 / vector_normalize(descriptor.begin(), descriptor.end()); + std::transform(descriptor.begin(), descriptor.end(), descriptor.begin(), + std::bind1st(std::multiplies(), norm_inv)); + } +} + +static void postprocess_desc(ggml_context * ctx, + superpoint_image & img, + struct ggml_tensor *desc, + std::vector& pts, + std::vector>& descriptors) +{ + std::vector keypoints_norm; + int h = img.h; + int w = img.w; + normalize_keypoints(pts, keypoints_norm, h, w); + float* desc_ptr = ggml_get_data_f32(desc); + int dim = 256; + grid_sample(desc_ptr, keypoints_norm, descriptors, dim, h/8., w/8.); + normalize_descriptors(descriptors); +} + + +struct conv2d_layer { + struct ggml_tensor * weights; + struct ggml_tensor * biases; + // struct ggml_tensor * scales; + // struct ggml_tensor * rolling_mean; + // struct ggml_tensor * rolling_variance; + int padding = 1; + bool batch_normalize = false; + bool activate = true; // true for relu, false for linear +}; + +struct superpoint_model +{ + int width = 640; + int height = 480; + std::vector conv2d_layers; + struct ggml_context * ctx; +}; + +static bool load_model(const std::string & fname, superpoint_model& model) { + struct gguf_init_params params = { + /*.no_alloc =*/ false, + /*.ctx =*/ &model.ctx, + }; + gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + if (!ctx) { + fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__); + return false; + } + model.width = 640; + 
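+    // Note: the network input is a fixed 640x480 single-channel image;
+    // main() asserts that the loaded image already has exactly this size.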
model.height = 480; + //TODO: switch the size 13 too! + model.conv2d_layers.resize(12); + + for (int i = 0; i < (int)model.conv2d_layers.size(); i++) { + char name[256]; + snprintf(name, sizeof(name), "l%d_weights", i); + //ggml_fp32_to_fp16 + model.conv2d_layers[i].weights = ggml_get_tensor(model.ctx, name); + //The weights are loaded as fp16 + model.conv2d_layers[i].weights->type = ggml_type::GGML_TYPE_F16; + snprintf(name, sizeof(name), "l%d_biases", i); + model.conv2d_layers[i].biases = ggml_get_tensor(model.ctx, name); + } + + //layers without relu + model.conv2d_layers[9].activate = false; + model.conv2d_layers[9].padding = 0; + model.conv2d_layers[11].activate = false; + model.conv2d_layers[11].padding = 0; + + + + + + return true; +} + +static ggml_tensor * apply_conv2d(ggml_context * ctx, ggml_tensor * input, const conv2d_layer & layer) +{ + // struct ggml_tensor * result = ggml_conv_1d(ctx, layer.weights, input, 1, 1, 1); + struct ggml_tensor * result = ggml_conv_2d(ctx, layer.weights, input, 1, 1, layer.padding, layer.padding, 1, 1); + + result = ggml_add(ctx, result, ggml_repeat(ctx, layer.biases, result)); + if (layer.activate) { + //implement normal relu + result = ggml_relu(ctx, result); + } + return result; +} + +static void activate_array(float * x, const int n) +{ + // logistic activation + for (int i = 0; i < n; i++) { + x[i] = 1./(1. + exp(-x[i])); + } +} + + +static float get_color(int c, int x, int max) +{ + float colors[6][3] = { {1,0,1}, {0,0,1}, {0,1,1}, {0,1,0}, {1,1,0}, {1,0,0} }; + float ratio = ((float)x/max)*5; + int i = floor(ratio); + int j = ceil(ratio); + ratio -= i; + float r = (1-ratio) * colors[i][c] + ratio*colors[j][c]; + return r; +} + + + +void inference(superpoint_image & img, + const superpoint_model & model, + float thresh, + std::vector& pts, + std::vector>& descriptors) + +{ + //TODO: modifiy the size, bc it is too large + static size_t buf_size = 20000000 * sizeof(float) * 40; + static void * buf = malloc(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf, + /*.no_alloc =*/ false, + }; + struct ggml_context * ctx0 = ggml_init(params); + // model.ctx = ctx0; + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + // std::vector detections; + //reshape the image + superpoint_image sized = letterbox_image(img, model.width, model.height); + + //allovate datasize + struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, model.width, model.height, 1, 1); + std::memcpy(input->data, img.data.data(), ggml_nbytes(input)); + ggml_set_name(input, "input"); + print_shape(0, input); + + //x = self.relu(self.conv1a(x)) + struct ggml_tensor * result = apply_conv2d(ctx0, input, model.conv2d_layers[0]); + result = apply_conv2d(ctx0, result, model.conv2d_layers[1]); + result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0); + print_shape(2, result); + //x = self.relu(self.conv2a(x)) + result = apply_conv2d(ctx0, result, model.conv2d_layers[2]); + print_shape(3, result); + //x = self.relu(self.conv2b(x)) + result = apply_conv2d(ctx0, result, model.conv2d_layers[3]); + print_shape(4, result); + //x = self.pool(x) + result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0); + print_shape(5, result); + //x = self.relu(self.conv3a(x)) + result = apply_conv2d(ctx0, result, model.conv2d_layers[4]); + // for further connections + // struct ggml_tensor * layer_8 = result; + print_shape(6, result); + // x = self.relu(self.conv3b(x)) + result = apply_conv2d(ctx0, result, 
model.conv2d_layers[5]); + print_shape(7, result); + // x = self.pool(x) + result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0); + print_shape(8, result); + result = apply_conv2d(ctx0, result, model.conv2d_layers[6]); + print_shape(9, result); + // x = self.relu(self.conv4b(x)) + struct ggml_tensor * encoder_output = apply_conv2d(ctx0, result, model.conv2d_layers[7]); + + result = apply_conv2d(ctx0, encoder_output, model.conv2d_layers[8]); + print_shape(11, result); + struct ggml_tensor * semi = apply_conv2d(ctx0, result, model.conv2d_layers[9]); + print_shape(12, semi); + + result = apply_conv2d(ctx0, encoder_output, model.conv2d_layers[10]); + print_shape(13, result); + + struct ggml_tensor * desc = apply_conv2d(ctx0, result, model.conv2d_layers[11]); + print_shape(14, desc); + ggml_build_forward_expand(gf, semi); + ggml_build_forward_expand(gf, desc); + ggml_graph_compute_with_ctx(ctx0, gf, 1); + const int64_t t_start_ms = ggml_time_ms(); + postprocess_semi(ctx0, sized, semi, pts); + postprocess_desc(ctx0, sized, desc, pts, descriptors); + const int64_t t_detect_ms = ggml_time_ms() - t_start_ms; + printf("superpoint postprocessing time: %f sec.)\n", t_detect_ms / 1000.0f); + + + + +} + +struct superpoint_params { + float thresh = 0.5; + std::string model = "superpoint.gguf"; + std::string fname_inp = "dog_color.jpg"; + std::string fname_out = "result.jpg"; +}; + +void superpoint_print_usage(int argc, char ** argv, const superpoint_params & params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -th T, --thresh T detection threshold (default: %.2f)\n", params.thresh); + fprintf(stderr, " -m FNAME, --model FNAME\n"); + fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); + fprintf(stderr, " -i FNAME, --inp FNAME\n"); + fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str()); + fprintf(stderr, " -o FNAME, --out FNAME\n"); + fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str()); + fprintf(stderr, "\n"); +} + +bool superpoint_params_parse(int argc, char ** argv, superpoint_params & params) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-th" || arg == "--thresh") { + params.thresh = std::stof(argv[++i]); + } else if (arg == "-m" || arg == "--model") { + params.model = argv[++i]; + } else if (arg == "-i" || arg == "--inp") { + params.fname_inp = argv[++i]; + } else if (arg == "-o" || arg == "--out") { + params.fname_out = argv[++i]; + } else if (arg == "-h" || arg == "--help") { + superpoint_print_usage(argc, argv, params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + superpoint_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +int main(int argc, char *argv[]) +{ + ggml_time_init(); + superpoint_model model; + + superpoint_params params; + if (!superpoint_params_parse(argc, argv, params)) { + return 1; + } + if (!load_model(params.model, model)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + superpoint_image img(0,0,0); + if (!load_image(params.fname_inp.c_str(), img, true)) { + fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str()); + return 1; + } + + superpoint_image rgb_img(0,0,0); + if (!load_image(params.fname_inp.c_str(), rgb_img, false)) { + fprintf(stderr, "%s: 
failed to load image from '%s'\n", __func__, params.fname_inp.c_str()); + return 1; + } + + /*resize image to 640 480 currently is needed*/ + assert(img.w == 640); + assert(img.h == 480); + + printf("start inference\n"); + std::vector pts; + std::vector>descriptors; + const int64_t t_start_ms = ggml_time_ms(); + inference(img, model, params.thresh,pts, descriptors); + const int64_t t_detect_ms = ggml_time_ms() - t_start_ms; + printf("superpoint inference time: %f sec.)\n", t_detect_ms / 1000.0f); + + //dump data + write_points("points.txt", pts); + write_descriptors("descs.txt", descriptors); + + //visualize points + for(auto& pt:pts) + { + float red = get_color(2,0,5); + float green = get_color(1,0,5); + float blue = get_color(0,0,5); + draw_point(rgb_img, pt.y, pt.x, red, green, blue); + } + if (!save_image(rgb_img, params.fname_out.c_str(), 80)) { + fprintf(stderr, "%s: failed to save image to '%s'\n", __func__, params.fname_out.c_str()); + return 1; + } + ggml_free(model.ctx); + + return 0; +} diff --git a/examples/superpoint/superpoint-image.cpp b/examples/superpoint/superpoint-image.cpp new file mode 100644 index 000000000..9b6bf23b1 --- /dev/null +++ b/examples/superpoint/superpoint-image.cpp @@ -0,0 +1,305 @@ +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" +#define STB_IMAGE_WRITE_IMPLEMENTATION +#include "stb_image_write.h" +#include +#include "superpoint-image.h" + + + +static void draw_box(superpoint_image & a, int x1, int y1, int x2, int y2, float r, float g, float b) +{ + if (x1 < 0) x1 = 0; + if (x1 >= a.w) x1 = a.w-1; + if (x2 < 0) x2 = 0; + if (x2 >= a.w) x2 = a.w-1; + + if (y1 < 0) y1 = 0; + if (y1 >= a.h) y1 = a.h-1; + if (y2 < 0) y2 = 0; + if (y2 >= a.h) y2 = a.h-1; + + for (int i = x1; i <= x2; ++i){ + a.data[i + y1*a.w + 0*a.w*a.h] = r; + a.data[i + y2*a.w + 0*a.w*a.h] = r; + + a.data[i + y1*a.w + 1*a.w*a.h] = g; + a.data[i + y2*a.w + 1*a.w*a.h] = g; + + a.data[i + y1*a.w + 2*a.w*a.h] = b; + a.data[i + y2*a.w + 2*a.w*a.h] = b; + } + for (int i = y1; i <= y2; ++i){ + a.data[x1 + i*a.w + 0*a.w*a.h] = r; + a.data[x2 + i*a.w + 0*a.w*a.h] = r; + + a.data[x1 + i*a.w + 1*a.w*a.h] = g; + a.data[x2 + i*a.w + 1*a.w*a.h] = g; + + a.data[x1 + i*a.w + 2*a.w*a.h] = b; + a.data[x2 + i*a.w + 2*a.w*a.h] = b; + } +} + +void draw_box_width(superpoint_image & a, int x1, int y1, int x2, int y2, int w, float r, float g, float b) +{ + for (int i = 0; i < w; ++i) { + draw_box(a, x1+i, y1+i, x2-i, y2-i, r, g, b); + } +} + +bool save_image(const superpoint_image & im, const char *name, int quality) +{ + uint8_t *data = (uint8_t*)calloc(im.w*im.h*im.c, sizeof(uint8_t)); + for (int k = 0; k < im.c; ++k) { + for (int i = 0; i < im.w*im.h; ++i) { + data[i*im.c+k] = (uint8_t) (255*im.data[i + k*im.w*im.h]); + } + } + int success = stbi_write_jpg(name, im.w, im.h, im.c, data, quality); + free(data); + if (!success) { + fprintf(stderr, "Failed to write image %s\n", name); + return false; + } + return true; +} + +/** + * force image to be in grayscale +*/ +void load_data(int w, int h, int c, const uint8_t * data, std::vector& img_data, bool is_grey) +{ + if(is_grey) + { + img_data.resize(w*h); + } + else + { + img_data.resize(c*w*h); + + } + + + if (c ==1) + { + for (int k = 0; k < c; ++k){ + for (int j = 0; j < h; ++j){ + for (int i = 0; i < w; ++i){ + //rgb, I guess + int dst_index = i + w*j + w*h*k; + int src_index = k + c*i + c*w*j; + img_data[dst_index] = (float)data[src_index]/255.; + } + } + } + } + else if (c == 3) + { + if(is_grey) + { + for (int j = 0; j < h; ++j) + { + for (int i = 
0; i < w; ++i) + { + //rgb, I guess + int dst_index = i + w*j; + // int src_index = k + c*i + c*w*j; + int red = 0 + 3 * i + 3 * w * j; + int green = 1 + 3 * i + 3 * w * j; + int blue = 2 + 3 * i + 3 * w * j; + + float grey = 0.299 * data[red] + 0.587 * data[green] + 0.114 * data[blue]; + // float grey = 0.333 * data[red] + 0.333 * data[green] + 0.333 * data[blue]; + + img_data[dst_index] = grey/255.; + + } + } + + } + else + { + for (int k = 0; k < c; ++k){ + for (int j = 0; j < h; ++j){ + for (int i = 0; i < w; ++i){ + int dst_index = i + w*j + w*h*k; + int src_index = k + c*i + c*w*j; + img_data[dst_index] = (float)data[src_index]/255.; + } + } + } + } + + + + } +} + +/*In superpoint, image is read as grey and normalized to 0 -- 1*/ +bool load_image(const char *fname, superpoint_image & img, bool be_grey) +{ + //assert channel is 3 + int w, h, c; + uint8_t * data = nullptr; + if(be_grey) + { + data = stbi_load(fname, &w, &h, &c, 1); + } + else + { + data = stbi_load(fname, &w, &h, &c, 3); + } + img.h = h; + img.w = w; + img.c = be_grey?1:3; + + if(c == 1) + { + printf("load grey image\n"); + } + else if (c == 3) + { + printf("load RGB image\n"); + + } + if (!data) { + return false; + } + if(c == 3 && be_grey) + { + load_data(w,h,1,data, img.data, be_grey); + } + else if (c ==3 &&(!be_grey)) + { + load_data(w,h,3,data, img.data, be_grey); + + } + + stbi_image_free(data); + return true; +} + +/* +https://en.wikipedia.org/wiki/Grayscale#Colorimetric_(perceptual_luminance-preserving)_conversion_to_grayscale +*/ + + +static superpoint_image resize_image(const superpoint_image & im, int w, int h) +{ + superpoint_image resized(w, h, im.c); + superpoint_image part(w, im.h, im.c); + float w_scale = (float)(im.w - 1) / (w - 1); + float h_scale = (float)(im.h - 1) / (h - 1); + for (int k = 0; k < im.c; ++k){ + for (int r = 0; r < im.h; ++r) { + for (int c = 0; c < w; ++c) { + float val = 0; + if (c == w-1 || im.w == 1){ + val = im.get_pixel(im.w-1, r, k); + } else { + float sx = c*w_scale; + int ix = (int) sx; + float dx = sx - ix; + val = (1 - dx) * im.get_pixel(ix, r, k) + dx * im.get_pixel(ix+1, r, k); + } + part.set_pixel(c, r, k, val); + } + } + } + for (int k = 0; k < im.c; ++k){ + for (int r = 0; r < h; ++r){ + float sy = r*h_scale; + int iy = (int) sy; + float dy = sy - iy; + for (int c = 0; c < w; ++c){ + float val = (1-dy) * part.get_pixel(c, iy, k); + resized.set_pixel(c, r, k, val); + } + if (r == h-1 || im.h == 1) continue; + for (int c = 0; c < w; ++c){ + float val = dy * part.get_pixel(c, iy+1, k); + resized.add_pixel(c, r, k, val); + } + } + } + return resized; +} + +static void embed_image(const superpoint_image & source, superpoint_image & dest, int dx, int dy) +{ + for (int k = 0; k < source.c; ++k) { + for (int y = 0; y < source.h; ++y) { + for (int x = 0; x < source.w; ++x) { + float val = source.get_pixel(x, y, k); + dest.set_pixel(dx+x, dy+y, k, val); + } + } + } +} + +superpoint_image letterbox_image(const superpoint_image & im, int w, int h) +{ + int new_w = im.w; + int new_h = im.h; + if (((float)w/im.w) < ((float)h/im.h)) { + new_w = w; + new_h = (im.h * w)/im.w; + } else { + new_h = h; + new_w = (im.w * h)/im.h; + } + superpoint_image resized = resize_image(im, new_w, new_h); + superpoint_image boxed(w, h, im.c); + boxed.fill(0.5); + embed_image(resized, boxed, (w-new_w)/2, (h-new_h)/2); + return boxed; +} + +static superpoint_image tile_images(const superpoint_image & a, const superpoint_image & b, int dx) +{ + if (a.w == 0) { + return b; + } + superpoint_image c(a.w + 
b.w + dx, (a.h > b.h) ? a.h : b.h, a.c); + c.fill(1.0f); + embed_image(a, c, 0, 0); + embed_image(b, c, a.w + dx, 0); + return c; +} + +static superpoint_image border_image(const superpoint_image & a, int border) +{ + superpoint_image b(a.w + 2*border, a.h + 2*border, a.c); + b.fill(1.0f); + embed_image(a, b, border, border); + return b; +} + +superpoint_image get_label(const std::vector & alphabet, const std::string & label, int size) +{ + size = size/10; + size = std::min(size, 7); + superpoint_image result(0,0,0); + for (int i = 0; i < (int)label.size(); ++i) { + int ch = label[i]; + superpoint_image img = alphabet[size*128 + ch]; + result = tile_images(result, img, -size - 1 + (size+1)/2); + } + return border_image(result, (int)(result.h*.25)); +} + +void draw_point(superpoint_image & im, int y, int x, float r, float g, float b) +{ + for(int row = y -2; row +#include +#include + +typedef struct PointType +{ + float x; /* col */ + float y; /* row */ + float conf; + + PointType() = default; + + PointType(int x_, int y_, float c_) + : x(x_), y(y_), conf(c_) { } +} PointT; + + +struct superpoint_image { + int w, h, c; + std::vector data; + + superpoint_image() : w(0), h(0), c(0) {} + superpoint_image(int w, int h, int c) : w(w), h(h), c(c), data(w*h*c) {} + + float get_pixel(int x, int y, int c) const { + assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c); + return data[c*w*h + y*w + x]; + } + + void set_pixel(int x, int y, int c, float val) { + assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c); + data[c*w*h + y*w + x] = val; + } + + void add_pixel(int x, int y, int c, float val) { + assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c); + data[c*w*h + y*w + x] += val; + } + + void fill(float val) { + std::fill(data.begin(), data.end(), val); + } +}; + +bool load_image(const char *fname, superpoint_image & img, bool be_grey); + +void draw_point(superpoint_image & im, int row, int col, float r, float g, float b); + + +void load_data(int w, int h, int c, const uint8_t * data, std::vector& img_data, bool is_grey); + +void draw_box_width(superpoint_image & a, int x1, int y1, int x2, int y2, int w, float r, float g, float b); +superpoint_image letterbox_image(const superpoint_image & im, int w, int h); +bool save_image(const superpoint_image & im, const char *name, int quality); +superpoint_image get_label(const std::vector & alphabet, const std::string & label, int size); +void draw_label(superpoint_image & im, int row, int col, const superpoint_image & label, const float * rgb); diff --git a/examples/superpoint/superpoint.gguf b/examples/superpoint/superpoint.gguf new file mode 100644 index 000000000..1297967fe Binary files /dev/null and b/examples/superpoint/superpoint.gguf differ diff --git a/examples/superpoint/utils.h b/examples/superpoint/utils.h new file mode 100644 index 000000000..2410ab441 --- /dev/null +++ b/examples/superpoint/utils.h @@ -0,0 +1,130 @@ +#pragma once + + +#include "ggml/ggml.h" +#include "superpoint-image.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +static void print_shape(int layer, const ggml_tensor * t) +{ + printf("Layer %2d output shape: %3d x %3d x %4d x %3d\n", layer, (int)t->ne[0], (int)t->ne[1], (int)t->ne[2], (int)t->ne[3]); +} + + +static void write_array(const std::string& fname, const ggml_tensor * t) +{ + GGML_ASSERT(ggml_is_contiguous(t)); + int size = t->ne[0] * t->ne[1] * t->ne[2] * t->ne[3]; + printf("write %lld data to %s\n", 
size, fname.c_str()); + // Open a file for writing + std::ofstream outfile(fname); + float* data = ggml_get_data_f32(t); + // Write the elements of the array to the file + if (outfile.is_open()) { + for (int i = 0; i < size; ++i) + { + float value = data[i]; + outfile << value << std::endl; + } + outfile.close(); + } +} + +static void write_points(const std::string& fname, const std::vector& pts) +{ + + printf("write %d points to %s\n", pts.size(), fname.c_str()); + // Open a file for writing + std::ofstream outfile(fname); + + if (outfile.is_open()) { + for (auto& pt: pts) + { + // float value = data[i]; + outfile << int(pt.x)<<" " << int(pt.y)<<" "<>& descs) +{ + + printf("write %d descs to %s\n", descs.size(), fname.c_str()); + // Open a file for writing + std::ofstream outfile(fname); + + if (outfile.is_open()) { + for (auto& desc: descs) + { + for(auto& digit: desc) + { + outfile << digit<<" "; + } + outfile<< std::endl; + // float value = data[i]; + } + outfile.close(); + } +} + +static void print_larger_number(struct ggml_tensor *input, float thre) +{ + // float thre = 0.015; + int num = input->ne[0] * input->ne[1] * input->ne[2] * input->ne[3]; + float* data = ggml_get_data_f32(input); + size_t cnt =0; + + for (size_t i = 0; i< num; i++) + { + float value = data[i]; + if (value > 0.015) + { + cnt++; + printf("index %zu: %f\n", cnt, value); + } + } +} + +static void print_data(struct ggml_tensor *input) +{ + int w = input->ne[0]; + int h = input->ne[1]; + int c = input->ne[2]; + printf("Shape: %3d x %3d x %4d x %3d\n", w, h, c, (int)input->ne[3]); + printf("nb: %3d x %3d x %4d x %3d\n", input->nb[0], input->nb[1], input->nb[2], (int)input->nb[3]); + + int num = 10; + if(input->type == GGML_TYPE_F16) + { + ggml_fp16_t* data = static_cast(ggml_get_data(input)); + // for (size_t i =0; i< w*h*c; i++) + for (size_t i = 0; i< num; i++) + { + float value = ggml_fp16_to_fp32(data[i]); + printf("index %zu: %f\n", i, value); + } + } + else + { + float* data = ggml_get_data_f32(input); + // for (size_t i =0; i< w*h*c; i++) + for (size_t i = 0; i< 0+num; i++) + { + float value = data[i]; + printf("index %zu: %f\n", i, value); + } + } +}