Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ADD_REL_POS perf in SAM by doing it inplace #466

Merged
merged 7 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 67 additions & 132 deletions examples/sam/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,41 @@ struct sam_image_f32 {
std::vector<float> data;
};

void ggml_sam_sin(const int n, float * dst, const float * src) {
for (int i = 0; i < n; ++i) {
dst[i] = sinf(src[i]);
void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
GGML_ASSERT(userdata == NULL);
GGML_ASSERT(ggml_are_same_shape(dst, src));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src));

const float * src_data = ggml_get_data_f32(src);
float * dst_data = ggml_get_data_f32(dst);

const int ne = (int)ggml_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = std::min(ie0 + dr, ne);

for (int i = ie0; i < ie1; ++i) {
dst_data[i] = sinf(src_data[i]);
}
}

void ggml_sam_cos(const int n, float * dst, const float * src) {
for (int i = 0; i < n; ++i) {
dst[i] = cosf(src[i]);
void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
GGML_ASSERT(userdata == NULL);
GGML_ASSERT(ggml_are_same_shape(dst, src));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src));

const float * src_data = ggml_get_data_f32(src);
float * dst_data = ggml_get_data_f32(dst);

const int ne = (int)ggml_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = std::min(ie0 + dr, ne);

for (int i = ie0; i < ie1; ++i) {
dst_data[i] = cosf(src_data[i]);
}
}

Expand Down Expand Up @@ -888,13 +914,6 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
}
}

// key + value memory
{
// const auto & hparams = model.hparams;

// TODO
}

// load weights
{
int n_tensors = 0;
Expand Down Expand Up @@ -1037,8 +1056,8 @@ bool sam_fill_dense_pe(
// concat
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
{
struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);
struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);

cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);

Expand All @@ -1059,6 +1078,28 @@ bool sam_fill_dense_pe(
return true;
}

struct ggml_tensor* sam_layer_norm_2d(
struct ggml_context * ctx0,
struct ggml_tensor * layer,
int n_channels,
struct ggml_tensor * w,
struct ggml_tensor * b) {
// LayerNorm2d
// normalize along channel dimmension
// TODO: better implementation
layer = ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3))),
2, 0, 1, 3);

layer = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer),
layer),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer));

return layer;
}

bool sam_encode_image(
const sam_model & model,
sam_state & state,
Expand Down Expand Up @@ -1228,7 +1269,7 @@ bool sam_encode_image(
0, 2, 1, 3));
struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);

struct ggml_tensor * attn = ggml_add_rel_pos(ctx0, KQ_scaled, rel_w, rel_h);
struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);

struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);

Expand Down Expand Up @@ -1306,37 +1347,11 @@ bool sam_encode_image(

cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);

// LayerNorm2d
{
// normalize along channel dimmension
// TODO: better implementation
cur = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3))),
2, 0, 1, 3));

cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_0_w, 1, 1, n_enc_out_chans), cur),
cur),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_0_b, 1, 1, n_enc_out_chans), cur));
}
cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b);

cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);

// LayerNorm2d
{
// normalize along channel dimmension
// TODO: better implementation
cur = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3))),
2, 0, 1, 3));

cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_1_w, 1, 1, n_enc_out_chans), cur),
cur),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_1_b, 1, 1, n_enc_out_chans), cur));
}
cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b);

// TODO: avoid copy
cur = ggml_cpy(ctx0, cur, state.embd_img);
Expand All @@ -1349,6 +1364,8 @@ bool sam_encode_image(
ggml_build_forward_expand(&gf, cur);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

//ggml_graph_print(&gf);

ggml_free(ctx0);
return true;
}
Expand Down Expand Up @@ -1423,8 +1440,8 @@ bool sam_encode_prompt(
// concat
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
{
struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);
struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);

cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]);

Expand Down Expand Up @@ -1462,74 +1479,6 @@ bool sam_encode_prompt(
// run the computation
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

// print
{
// auto print_t_f32 = [&](struct ggml_tensor * t) {
// float * data = (float *)t->data;
// printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
// printf("data: ");
// for (int i = 0; i < std::min((int) t->ne[0], 256); i++) {
// printf("%f ", data[i]);
// }
// printf("\n");
// //for (int y = 0; y < 64; ++y) {
// // for (int x = 0; x < 64; ++x) {
// // printf("%5.2f ", data[y*64 + x]);
// // }
// // printf("\n");
// //}
// //printf("\n");
// // for (int y = 0; y < 64; ++y) {
// // for (int x = 0; x < 64; ++x) {
// // printf("%5.2f ", data[255*64*64 + y*64 + x]);
// // }
// // printf("\n");
// // }
// // printf("\n");
// //for (int y = 0; y < 64; ++y) {
// // for (int x = 0; x < 64; ++x) {
// // printf("%5.2f ", data[(y*64 + x)*768 + 231]);
// // }
// // printf("\n");
// //}
// //printf("\n");
// double sum = 0.0;
// for (int i = 0; i < ggml_nelements(t); i++) {
// sum += data[i];
// }
// printf("sum: %f\n", sum);
// };

// auto print_t_f16 = [&](struct ggml_tensor * t) {
// ggml_fp16_t * data = (ggml_fp16_t *)t->data;
// printf("dims: %jd %jd %jd %jd f16\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
// printf("data: ");
// for (int i = 0; i < std::min((int) t->ne[0], 256); i++) {
// printf("%f ", ggml_fp16_to_fp32(data[i]));
// }
// printf("\n");
// for (int y = 0; y < 14; ++y) {
// for (int x = 0; x < 14; ++x) {
// printf("%7.4f ", ggml_fp16_to_fp32(data[(y*14 + x)*64 + 23]));
// }
// printf("\n");
// }
// printf("\n");
// double sum = 0.0;
// for (int i = 0; i < ggml_nelements(t); i++) {
// sum += ggml_fp16_to_fp32(data[i]);
// }
// printf("sum: %f\n", sum);
// };

// auto * t = ggml_get_tensor(ctx0, "check");
// if (t->type == GGML_TYPE_F32) {
// print_t_f32(t);
// } else {
// print_t_f16(t);
// }
}

//printf("used_mem = %zu\n", ggml_used_mem(ctx0));

ggml_free(ctx0);
Expand Down Expand Up @@ -1595,7 +1544,7 @@ struct ggml_tensor* sam_decode_mask_transformer_attn(

struct ggml_tensor * KQV_merged = ggml_cont(ctx0, ggml_transpose(ctx0, KQV));
KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3));
KQV_merged = ggml_cont(ctx0, ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]));
KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]);
KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged);
KQV_merged = ggml_add(ctx0,
ggml_repeat(ctx0, attn.out_b, KQV_merged),
Expand Down Expand Up @@ -1859,21 +1808,7 @@ bool sam_decode_mask(
// ConvTranspose2d
keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_0_b, keys), keys);

// LayerNorm2d
{
// normalize along channel dimmension
// TODO: better implementation
keys = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, keys, 1, 2, 0, 3))),
2, 0, 1, 3));

keys = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_1_w, 1, 1, n_img_embd), keys),
keys),
ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_1_b, 1, 1, n_img_embd), keys));
}
keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b);

// GELU activation
keys = ggml_gelu(ctx0, keys);
Expand All @@ -1898,7 +1833,7 @@ bool sam_decode_mask(

struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding);
masks = ggml_cont(ctx0, ggml_transpose(ctx0, masks)); // TODO: Shouldn't be needed
masks = ggml_cont(ctx0, ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]));
masks = ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]);

// Generate mask quality predictions
// ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146
Expand Down Expand Up @@ -1941,7 +1876,7 @@ bool sam_decode_mask(
bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
if (state.low_res_masks->ne[2] == 0) return true;
if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
printf("Error: number of masks (%jd) does not match number of iou predictions (%jd)\n", state.low_res_masks->ne[2], state.iou_predictions->ne[0]);
printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
return false;
}

Expand Down
7 changes: 7 additions & 0 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,12 +1384,19 @@ extern "C" {
int kh);

// used in sam

GGML_API struct ggml_tensor * ggml_add_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * pw,
struct ggml_tensor * ph);

GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * pw,
struct ggml_tensor * ph);

// custom operators

typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
Expand Down
Loading
Loading