Skip to content

Commit

Permalink
yolo : add reading labels and alphabet from model file
Browse files Browse the repository at this point in the history
  • Loading branch information
katsu560 committed May 19, 2024
1 parent 193f48b commit 5f8f110
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 7 deletions.
27 changes: 26 additions & 1 deletion examples/yolo/yolo-image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,31 @@ bool load_image(const char *fname, yolo_image & img)
return true;
}

bool load_image_from_memory(const char *buffer, int len, yolo_image & img)
{
int w, h, c;
uint8_t * data = stbi_load_from_memory((uint8_t *)buffer, len, &w, &h, &c, 3);
if (!data) {
return false;
}
c = 3;
img.w = w;
img.h = h;
img.c = c;
img.data.resize(w*h*c);
for (int k = 0; k < c; ++k){
for (int j = 0; j < h; ++j){
for (int i = 0; i < w; ++i){
int dst_index = i + w*j + w*h*k;
int src_index = k + c*i + c*w*j;
img.data[dst_index] = (float)data[src_index]/255.;
}
}
}
stbi_image_free(data);
return true;
}

static yolo_image resize_image(const yolo_image & im, int w, int h)
{
yolo_image resized(w, h, im.c);
Expand Down Expand Up @@ -207,4 +232,4 @@ void draw_label(yolo_image & im, int row, int col, const yolo_image & label, con
}
}
}
}
}
1 change: 1 addition & 0 deletions examples/yolo/yolo-image.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct yolo_image {
};

bool load_image(const char *fname, yolo_image & img);
bool load_image_from_memory(const char *buffer, int len, yolo_image & img);
void draw_box_width(yolo_image & a, int x1, int y1, int x2, int y2, int w, float r, float g, float b);
yolo_image letterbox_image(const yolo_image & im, int w, int h);
bool save_image(const yolo_image & im, const char *name, int quality);
Expand Down
101 changes: 95 additions & 6 deletions examples/yolo/yolov3-tiny.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct yolo_model {
int height = 416;
std::vector<conv2d_layer> conv2d_layers;
struct ggml_context * ctx;
struct gguf_context * ggufctx;
};

struct yolo_layer {
Expand Down Expand Up @@ -71,6 +72,7 @@ static bool load_model(const std::string & fname, yolo_model & model) {
fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
return false;
}
model.ggufctx = ctx;
model.width = 416;
model.height = 416;
model.conv2d_layers.resize(13);
Expand Down Expand Up @@ -100,6 +102,47 @@ static bool load_model(const std::string & fname, yolo_model & model) {
return true;
}

// istream from memory
#include <streambuf>
#include <istream>

struct membuf : std::streambuf {
membuf(const char * begin, const char * end) {
char * b(const_cast<char *>(begin));
char * e(const_cast<char *>(end));
this->begin = b;
this->end = e;
this->setg(b, b, e);
}

membuf(const char * base, size_t size) {
char * b(const_cast<char *>(begin));
this->begin = b;
this->end = b + size;
this->setg(b, b, end);
}

virtual pos_type seekoff(off_type off, std::ios_base::seekdir dir, std::ios_base::openmode which = std::ios_base::in) override {
if(dir == std::ios_base::cur) {
gbump(off);
} else if(dir == std::ios_base::end) {
setg(begin, end + off, end);
} else if(dir == std::ios_base::beg) {
setg(begin, begin + off, end);
}

return gptr() - eback();
}

virtual pos_type seekpos(std::streampos pos, std::ios_base::openmode mode) override {
return seekoff(pos - pos_type(off_type(0)), std::ios_base::beg, mode);
}

char * begin;
char * end;
};


static bool load_labels(const char * filename, std::vector<std::string> & labels)
{
std::ifstream file_in(filename);
Expand All @@ -114,6 +157,25 @@ static bool load_labels(const char * filename, std::vector<std::string> & labels
return true;
}

static bool load_labels_kv(const struct gguf_context * ctx, const char * filename, std::vector<std::string> & labels)
{
struct gguf_nobj nobj = gguf_find_name_nobj(ctx, filename);
if (nobj.n == 0) {
return false;
}
membuf buf(nobj.data, nobj.data + nobj.n);
std::istream file_in(&buf);
if (!file_in) {
return false;
}
std::string line;
while (std::getline(file_in, line)) {
labels.push_back(line);
}
GGML_ASSERT(labels.size() == 80);
return true;
}

static bool load_alphabet(std::vector<yolo_image> & alphabet)
{
alphabet.resize(8 * 128);
Expand All @@ -130,6 +192,27 @@ static bool load_alphabet(std::vector<yolo_image> & alphabet)
return true;
}

static bool load_alphabet_kv(const struct gguf_context * ctx, std::vector<yolo_image> & alphabet)
{
alphabet.resize(8 * 128);
for (int j = 0; j < 8; j++) {
for (int i = 32; i < 127; i++) {
char fname[256];
sprintf(fname, "data/labels/%d_%d.png", i, j);
struct gguf_nobj nobj = gguf_find_name_nobj(ctx, fname);
if (nobj.n == 0) {
fprintf(stderr, "Cannot find '%s'\n", fname);
return false;
}
if (!load_image_from_memory(nobj.data, nobj.n, alphabet[j*128 + i])) {
fprintf(stderr, "Cannot load '%s'\n", fname);
return false;
}
}
}
return true;
}

static ggml_tensor * apply_conv2d(ggml_context * ctx, ggml_tensor * input, const conv2d_layer & layer)
{
struct ggml_tensor * result = ggml_conv_2d(ctx, layer.weights, input, 1, 1, layer.padding, layer.padding, 1, 1);
Expand Down Expand Up @@ -503,14 +586,20 @@ int main(int argc, char *argv[])
return 1;
}
std::vector<std::string> labels;
if (!load_labels("data/coco.names", labels)) {
fprintf(stderr, "%s: failed to load labels from 'data/coco.names'\n", __func__);
return 1;
if (!load_labels_kv(model.ggufctx, "data/coco.names", labels)) {
fprintf(stderr, "%s: failed to load labels from 'data/coco.names' in model\n", __func__);
if (!load_labels("data/coco.names", labels)) {
fprintf(stderr, "%s: failed to load labels from 'data/coco.names'\n", __func__);
return 1;
}
}
std::vector<yolo_image> alphabet;
if (!load_alphabet(alphabet)) {
fprintf(stderr, "%s: failed to load alphabet\n", __func__);
return 1;
if (!load_alphabet_kv(model.ggufctx, alphabet)) {
fprintf(stderr, "%s: failed to load alphabet from model\n", __func__);
if (!load_alphabet(alphabet)) {
fprintf(stderr, "%s: failed to load alphabet\n", __func__);
return 1;
}
}
const int64_t t_start_ms = ggml_time_ms();
detect(img, model, params.thresh, labels, alphabet);
Expand Down

0 comments on commit 5f8f110

Please sign in to comment.