Implement multimodal models (LLaVA) #3436

Merged: 36 commits, merged on Oct 12, 2023.
Changes shown below are from 1 commit.

Commits (36)
59aa1ac
WIP: start implementing LLaVA
monatis Oct 2, 2023
0f0e7c6
rm scratch buf for now, will revert after cleanup
monatis Oct 2, 2023
7e9120f
LLaVA image encoder is working. will combine with llama
monatis Oct 2, 2023
d37ed47
Add llava inference code, but it's buggy. debugging
monatis Oct 3, 2023
8690f42
LLaVA is working e2e, needs to optimize memory allocation + cleanup
monatis Oct 7, 2023
94eeac3
Use ggml_allocr + rm unnecessary code
monatis Oct 8, 2023
0c2bd79
fix: crlf -> lf
monatis Oct 8, 2023
204d08b
fix: new line at EoF
monatis Oct 8, 2023
95da79e
fix: trailing whitespace
monatis Oct 8, 2023
2a04d0b
Merge branch 'master' into llava
monatis Oct 8, 2023
444dbce
Add readme
monatis Oct 9, 2023
8af7e21
Update readme
monatis Oct 9, 2023
54495c9
Some cleanup
monatis Oct 9, 2023
9b0ec4d
Are you happy editorconfig?
monatis Oct 9, 2023
8278a73
rm unused batch image preprocessing
monatis Oct 9, 2023
d78e816
rm unused import
monatis Oct 9, 2023
4759bfd
fix: rm designated initializers
monatis Oct 9, 2023
325d240
introduce pad-to-square mode for non-square images
monatis Oct 9, 2023
d75a031
are you happy editorconfig?
monatis Oct 9, 2023
ae01c85
gitignore /llava
monatis Oct 9, 2023
5009ae9
Handle cases where image file does not exist
monatis Oct 9, 2023
96171de
add llava target to Makefile
monatis Oct 9, 2023
d640aae
add support for 13b model variant
monatis Oct 10, 2023
587bde8
Maybe seed is unlucky?
monatis Oct 11, 2023
f1564bb
Merge branch 'master' into llava
monatis Oct 11, 2023
ab21587
Check if apples are compared to apples
monatis Oct 11, 2023
0409ae0
are you happy editorconfig?
monatis Oct 11, 2023
f0f7834
Use temperature = 0.1 by default
monatis Oct 11, 2023
2bc1710
command line: use gpt_params_parse()
monatis Oct 11, 2023
1403d87
Merge master and fix conflicts
monatis Oct 11, 2023
dc913ea
minor
monatis Oct 12, 2023
56ccf97
handle default n_predict
monatis Oct 12, 2023
e9534ea
fix typo
monatis Oct 12, 2023
346e3c1
Merge branch 'master' into llava
ggerganov Oct 12, 2023
4bc5c9c
llava : code formatting, rename files, fix compile warnings
ggerganov Oct 12, 2023
0bd7e69
do not use Wno-cast-qual for MSVC
monatis Oct 12, 2023
LLaVA image encoder is working. will combine with llama
monatis committed Oct 2, 2023
commit 7e9120f7b1b86b76f300c6162eeceede92ed3e99
15 changes: 11 additions & 4 deletions examples/llava/clip-test.cpp
@@ -1,5 +1,6 @@
#include "clip.h"
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char ** argv) {
const char * model_path = argv[1];
@@ -8,14 +9,20 @@ int main(int argc, char ** argv) {

auto ctx_clip = clip_model_load(model_path, 1);
clip_image_u8 img;
//clip_tokens tokens;
//clip_tokenize(ctx_clip, text, &tokens);
//float vec[512];
//clip_text_encode(ctx_clip, 4, &tokens, vec, false);
clip_image_f32 img_res;
clip_image_load_from_file(img_path, &img);
clip_image_preprocess(ctx_clip, &img, &img_res);
float * vec = (float *)malloc(4096 * 257 * sizeof(float));
clip_image_encode(ctx_clip, 4, &img_res, vec, false);

/*
float score;
clip_compare_text_and_image(ctx_clip, 4, text, &img, &score);
printf("score: %f\n", score);
*/

clip_free(ctx_clip);
free(vec);


return 0;
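For context, the test above boils down to the following image-encode flow. This is a commented sketch, not part of the diff: the usage check and the inline comments are additions, and the 4096 * 257 buffer size is simply the value hard-coded in the test (presumably 4096 floats per position for the 7B LLaVA projector, times 256 image patches plus one class position); other model variants would need a different size.

#include "clip.h"
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char ** argv) {
    if (argc < 3) {
        fprintf(stderr, "usage: %s <model.gguf> <image>\n", argv[0]);
        return 1;
    }

    // load the CLIP/LLaVA image encoder produced by convert_hf_to_gguf.py
    auto ctx_clip = clip_model_load(argv[1], 1);

    clip_image_u8  img;     // raw 8-bit image as read from disk
    clip_image_f32 img_res; // resized and normalized float image
    clip_image_load_from_file(argv[2], &img);
    clip_image_preprocess(ctx_clip, &img, &img_res);

    // one projected embedding per position: 257 positions x 4096 floats (7B variant)
    float * vec = (float *)malloc(4096 * 257 * sizeof(float));
    clip_image_encode(ctx_clip, 4, &img_res, vec, false);

    free(vec);
    clip_free(ctx_clip);
    return 0;
}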
112 changes: 72 additions & 40 deletions examples/llava/clip.cpp
@@ -43,6 +43,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_DESCRIPTION "general.description"
#define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
@@ -77,6 +78,7 @@ static std::string format(const char * fmt, ...) {
#define TN_LN_POST "%s.post_ln.%s"
#define TN_TEXT_PROJ "text_projection.weight"
#define TN_VIS_PROJ "visual_projection.weight"
#define TN_LLAVA_PROJ "llava_projector.%s"

//
// utilities to get data from a gguf file
@@ -221,6 +223,10 @@ struct clip_vision_model {
struct ggml_tensor * post_ln_b;

struct ggml_tensor * projection;

// LLaVA projection
struct ggml_tensor * llava_proj_w;
struct ggml_tensor * llava_proj_b;
};

// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
@@ -240,6 +246,7 @@ struct clip_buffer {
struct clip_ctx {
bool has_text_encoder = false;
bool has_vision_encoder = false;
bool has_llava_projector = false;
struct clip_text_model text_model;
struct clip_vision_model vision_model;
struct clip_vocab vocab;
@@ -270,16 +277,17 @@ size_t get_mem_req_by_size(struct clip_ctx * ctx) {
if (vision_hparams->patch_size == 32) { // patch size = 32
return 96 * mb;
} else { // patch size = 16
return 256 * mb;
return 128 * mb;
}
case 197: // base or large, text-only
return 16 * mb;
return 96 * mb;
case 589: // large, two-tower
case 392: // large, vision-only
if (n_positions == 257) { // input image size = 224
return 60 * mb;
case 375: // large, LLaVA encoder
if (vision_hparams->image_size == 224) { // input image size = 224
return 1200 * mb;
} else { // input image size = 336
return 96 * mb;
return 1800 * mb;
}
case 909: // huge, two-tower
case 520: // huge, vision-only
@@ -313,6 +321,7 @@ size_t get_scr_buf_req_by_size(struct clip_ctx * ctx) {
return 32 * mb;
case 589:
case 392:
case 377:
if (n_positions <= 257) {
return 96 * mb;
} else {
@@ -406,12 +415,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
idx = get_key_idx(ctx, KEY_HAS_VIS_ENC);
new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx);

idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ);
if (idx != -1) {
new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx);
}

idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);

if (verbosity >= 1) {
printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
printf("%s: model size: %.2f MB\n", __func__, (ctx_size / 1024.0 / 1024.0));
printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
}
@@ -556,10 +571,14 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
vision_model.class_embedding = get_tensor(new_clip->ctx, TN_CLASS_EMBD);
vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
if (new_clip->has_llava_projector) {
vision_model.llava_proj_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "weight"));
vision_model.llava_proj_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "bias"));
} else {
vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
}
vision_model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = vision_model.layers[il];
@@ -1004,8 +1023,9 @@ bool clip_text_encode(const clip_ctx * ctx, const int n_threads, const clip_toke
cplan.work_data = (uint8_t *)malloc(cplan.work_size);
}
ggml_graph_compute(&gf, &cplan);
*/
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
*/

ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

// print
#ifdef CLIP_DEBUG
@@ -1053,11 +1073,12 @@ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
printf("used_mem = %zu\n", ggml_used_mem(ctx0));
#endif
memcpy(vec, ggml_get_data_f32(embeddings), sizeof(float) * projection_dim);
/*

/*
if (cplan.work_size != 0) {
free(cplan.work_data);
}
*/
*/

ggml_free(ctx0);
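A side note on the commented-out block above: both code paths execute the same graph. The sketch below is an illustration rather than part of the diff, contrasting the explicit-plan API being commented out with the ggml_graph_compute_with_ctx() call that replaces it here and in clip_image_batch_encode() further down.

// explicit plan: the caller owns (and must free) the scratch work buffer
ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
if (cplan.work_size != 0) {
    cplan.work_data = (uint8_t *)malloc(cplan.work_size);
}
ggml_graph_compute(&gf, &cplan);
if (cplan.work_size != 0) {
    free(cplan.work_data);
}

// convenience wrapper: the work buffer is allocated from ctx0 instead
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);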

@@ -1254,50 +1275,60 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
embeddings = cur;
}

// get the output of cls token, e.g., 0th index
struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
for (int b = 0; b < batch_size; b++) {
ggml_set_i32_1d(cls, b, b * num_positions);
}
embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);

// post-layernorm
{
embeddings = ggml_norm(ctx0, embeddings, eps);
//ggml_set_scratch(ctx0, {0, 0, nullptr});

embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
ggml_repeat(ctx0, model.post_ln_b, embeddings));
}
struct ggml_tensor * output = NULL;
if (ctx->has_llava_projector) {
output = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
embeddings = ggml_mul_mat(ctx0, model.llava_proj_w, embeddings);
output = ggml_add(ctx0, ggml_repeat(ctx0, model.llava_proj_b, embeddings), embeddings);
} else {
// get the output of cls token, e.g., 0th index
struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
for (int b = 0; b < batch_size; b++) {
ggml_set_i32_1d(cls, b, b * num_positions);
}
embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);

//ggml_set_scratch(ctx0, {0, 0, nullptr});
// post-layernorm
{
embeddings = ggml_norm(ctx0, embeddings, eps);

// final visual projection
embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
ggml_repeat(ctx0, model.post_ln_b, embeddings));
}

// normalize output embeddings
struct ggml_tensor * output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
// final visual projection
embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);

// normalize output embeddings
output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);

for (int b = 0; b < batch_size; b++) {
struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
if (normalize) {
ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
for (int b = 0; b < batch_size; b++) {
struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
if (normalize) {
ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
}
output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
}
output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
}
ggml_set_name(output, "check");

// run the computation
ggml_build_forward_expand(&gf, output);

/*
ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
cplan.work_size *= batch_size;
if (cplan.work_size != 0) {
cplan.work_data = (uint8_t *)malloc(cplan.work_size);
}
ggml_graph_compute(&gf, &cplan);
*/
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
*/

ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

// print
#ifdef CLIP_DEBUG
@@ -1347,11 +1378,12 @@
#endif

memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);
/*

/*
if (cplan.work_size != 0) {
free(cplan.work_data);
}
*/
*/

ggml_free(ctx0);

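Summing up the clip.cpp changes: the vision graph now ends in one of two branches. A condensed sketch for reference — the standalone helper name llava_project is illustrative only; in the diff this logic is inlined in clip_image_batch_encode() using the llava_projector tensors loaded in clip_model_load().

#include "ggml.h"

// LLaVA mode: keep every patch position and project it into the language
// model's embedding space with llava_projector.{weight,bias}.
static struct ggml_tensor * llava_project(struct ggml_context * ctx0,
                                          struct ggml_tensor  * embeddings, // per-position vision hidden states
                                          struct ggml_tensor  * proj_w,
                                          struct ggml_tensor  * proj_b) {
    struct ggml_tensor * cur = ggml_mul_mat(ctx0, proj_w, embeddings);  // project each position
    return ggml_add(ctx0, ggml_repeat(ctx0, proj_b, cur), cur);         // add the broadcast bias
}

// CLIP mode (unchanged behaviour, now in the else-branch): take the class
// token, apply the post layernorm and the visual projection, and optionally
// L2-normalize each embedding in the batch.

In other words, LLaVA consumes a sequence of projected patch embeddings rather than a single pooled image vector, which is what the llama side is wired up to receive in the later commits of this PR.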
35 changes: 19 additions & 16 deletions examples/llava/convert_hf_to_gguf.py
@@ -10,32 +10,35 @@
TEXT = "clip.text"
VISION = "clip.vision"


def k(raw_key: str, arch: str) -> str:
return raw_key.format(arch=arch)


def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
if name in (
"logit_scale",
"text_model.embeddings.position_ids",
"vision_model.embeddings.position_ids",
):
return True
if name == "visual_projection.weight" and has_llava:

if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]:
return True

if name.startswith("v") and not has_vision:
return True

if name.startswith("t") and not has_text:
return True

return False


def get_tensor_name(name: str) -> str:
if "projection" in name:
return name

return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")


@@ -64,19 +67,22 @@ def bytes_to_unicode():
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py")
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False, help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False, help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--text-only", action="store_true", required=False,
help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--llava-projector", help="Path to projector.pt file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)

args = ap.parse_args()


if args.text_only and args.vision_only:
print("--text-only and --image-only arguments cannot be specified at the same time.")
exit(1)
@@ -91,7 +97,7 @@ def bytes_to_unicode():
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
vocab = json.load(f)
tokens = [key for key in vocab]

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
config = json.load(f)
v_hparams = config["vision_config"]
@@ -108,7 +114,7 @@ def bytes_to_unicode():
if args.use_f32:
ftype = 0


model = CLIPModel.from_pretrained(dir_model)
processor = CLIPProcessor.from_pretrained(dir_model)

@@ -182,8 +188,6 @@ def bytes_to_unicode():
fout.add_bool("clip.use_gelu", use_gelu)




if has_llava_projector:
model.vision_model.encoder.layers.pop(-1)
projector = torch.load(args.llava_projector)
@@ -203,7 +207,7 @@ def bytes_to_unicode():

name = get_tensor_name(name)
data = data.squeeze().numpy()

n_dims = len(data.shape)

# ftype == 0 -> float32, ftype == 1 -> float16
Expand All @@ -229,8 +233,7 @@ def bytes_to_unicode():

print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
fout.add_tensor(name, data)




fout.write_header_to_file()
fout.write_kv_data_to_file()