
mpt-1b fails with mpt_model_load: unknown tensor 'transformer.blocks.0.attn.k_ln.weight' in model file #499

OlegJakushkin opened this issue Sep 1, 2023 · 4 comments


@OlegJakushkin

I want to load mpt-1b-redpajama-200b-dolly.

I converted it to ggml with:

!git clone https://github.com/ggerganov/ggml
!cd ggml && rm -rf ./build && mkdir build ; cd build && cmake .. && make -j32
!cd ggml && cd build && python3 ../examples/mpt/convert-h5-to-ggml.py ../../mpt-1b-redpajama-200b-dolly 1

which logged:

* Loading part: pytorch_model.bin
Processing variable: transformer.blocks.0.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.0.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.0.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.0.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.1.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.1.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.1.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.1.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.10.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.10.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.10.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.10.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.11.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.11.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.11.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.11.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.12.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.12.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.12.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.12.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.13.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.13.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.13.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.13.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.14.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.14.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.14.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.14.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.15.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.15.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.15.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.15.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.16.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.16.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.16.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.16.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.17.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.17.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.17.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.17.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.18.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.18.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.18.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.18.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.19.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.19.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.19.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.19.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.2.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.2.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.2.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.2.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.20.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.20.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.20.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.20.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.21.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.21.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.21.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.21.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.22.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.22.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.22.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.22.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.23.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.23.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.23.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.23.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.3.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.3.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.3.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.3.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.4.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.4.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.4.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.4.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.5.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.5.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.5.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.5.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.6.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.6.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.6.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.6.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.7.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.7.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.7.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.7.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.8.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.8.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.8.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.8.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.9.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.9.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.9.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.9.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.ln_f.weight with shape:  (2048,) -> float32
Processing variable: transformer.wte.weight with shape:  (50432, 2048) -> float16
Done. Output file: ../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin

To do that, I changed two lines in convert-h5-to-ggml.py:

#fout.write(struct.pack("f", hparams["attn_config"]["alibi_bias_max"] or 2048))
fout.write(struct.pack("f", 2048))
#fout.write(struct.pack("f", hparams["attn_config"]["clip_qkv"] or 0.0))
fout.write(struct.pack("f", 0.0))

Now I try to run it with:

!cd ggml && cd build && ./bin/mpt -m  ../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin -p "I believe the meaning of life is" -t 8 -n 64

and get:

main: seed      = 1693544259
main: n_threads = 8
main: n_batch   = 8
main: n_ctx     = 512
main: n_predict = 64

mpt_model_load: loading model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin' - please wait ...
mpt_model_load: d_model        = 2048
mpt_model_load: max_seq_len    = 2048
mpt_model_load: n_ctx          = 512
mpt_model_load: n_heads        = 16
mpt_model_load: n_layers       = 24
mpt_model_load: n_vocab        = 50432
mpt_model_load: alibi_bias_max = 2048.000000
mpt_model_load: clip_qkv       = 0.000000
mpt_model_load: ftype          = 1
mpt_model_load: qntvr          = 0
mpt_model_load: ggml ctx size = 2597.45 MB
mpt_model_load: memory_size =    96.00 MB, n_mem = 12288
mpt_model_load: unknown tensor 'transformer.blocks.0.attn.k_ln.weight' in model file
main: failed to load model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin'
mpt_model_load: 

Could you please tell me what steps need to be performed to enable loading this model (it should be quite similar to MPT)?

@OlegJakushkin
Author

The 7b and 1b versions seem to have identical code around self.k_ln = layernorm_class(self.d_model, device=device), so I am wondering what could have gone wrong.

@OlegJakushkin
Author

OlegJakushkin commented Sep 2, 2023

So here is what I tried. It compiles and runs, yet the output quality suffers because of the QK LayerNorm (two extra layers) used in MPT-1b and missing in MPT-7b.

Blocks look like:

transformer.blocks.0.ln_1.weight torch.Size([2048]) torch.bfloat16
transformer.blocks.0.attn.Wqkv.weight torch.Size([6144, 2048]) torch.bfloat16
transformer.blocks.0.attn.q_ln.weight torch.Size([2048]) torch.bfloat16
transformer.blocks.0.attn.k_ln.weight torch.Size([2048]) torch.bfloat16
transformer.blocks.0.attn.out_proj.weight torch.Size([2048, 2048]) torch.bfloat16
transformer.blocks.0.ln_2.weight torch.Size([2048]) torch.bfloat16
transformer.blocks.0.mlp.mlp_up.weight torch.Size([8192, 2048]) torch.bfloat16
transformer.blocks.0.mlp.mlp_down.weight torch.Size([2048, 8192]) torch.bfloat16

q_ln and k_ln need to be wired in differently from how the current implementation goes from Wqkv to out_proj, somewhere in the L538-L615 range, roughly as in the sketch below.
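To make the intended wiring concrete, here is a minimal PyTorch-style sketch of the MPT-1b attention path (my own illustration, not the ggml code; ALiBi bias and qkv clipping are omitted for brevity). The point is that q_ln and k_ln sit between the Wqkv split and the per-head reshape:

import torch
import torch.nn.functional as F

def mpt1b_attention(x, Wqkv, q_ln, k_ln, out_proj, n_heads):
    # x: (seq, d_model); Wqkv, out_proj: nn.Linear; q_ln, k_ln: nn.LayerNorm
    qkv = Wqkv(x)                                  # (seq, 3 * d_model)
    q, k, v = qkv.chunk(3, dim=-1)                 # each (seq, d_model)

    # The MPT-1b-specific step: LayerNorm over the full d_model of q and k,
    # applied before splitting into heads. MPT-7b has no q_ln/k_ln tensors.
    q = q_ln(q)
    k = k_ln(k)

    seq, d_model = x.shape
    head_dim = d_model // n_heads
    q = q.view(seq, n_heads, head_dim).transpose(0, 1)   # (n_heads, seq, head_dim)
    k = k.view(seq, n_heads, head_dim).transpose(0, 1)
    v = v.view(seq, n_heads, head_dim).transpose(0, 1)

    # ALiBi bias is left out here; the real model adds it to the attention scores.
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    y = y.transpose(0, 1).reshape(seq, d_model)
    return out_proj(y)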

@OlegJakushkin
Author

OlegJakushkin commented Sep 2, 2023

Fixed it here. Now MPT-1b works correctly, yet all of its layers are named differently, so it is not cleanly mergeable... Should we add a new folder with only MPT-1b, or make the MPT example autodetect the variant?

Example output (32-bit):

main: seed      = 1693677012
main: n_threads = 32
main: n_batch   = 8
main: n_ctx     = 512
main: n_predict = 64

mpt_model_load: loading model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f32.bin' - please wait ...
mpt_model_load: d_model        = 2048
mpt_model_load: max_seq_len    = 2048
mpt_model_load: n_ctx          = 512
mpt_model_load: n_heads        = 16
mpt_model_load: n_layers       = 24
mpt_model_load: n_vocab        = 50432
mpt_model_load: alibi_bias_max = 8.000000
mpt_model_load: clip_qkv       = 0.000000
mpt_model_load: ftype          = 0
mpt_model_load: qntvr          = 0
mpt_model_load: ggml ctx size = 5098.85 MB
mpt_model_load: memory_size =    96.00 MB, n_mem = 12288
mpt_model_load: ........................ done
mpt_model_load: model size =  5002.76 MB / num tensors = 194
extract_tests_from_file : No test file found.
test_gpt_tokenizer : 0 tests failed out of 0 tests.

main: temp           = 0.800
main: top_k          = 50432
main: top_p          = 1.000
main: repeat_last_n  = 64
main: repeat_penalty = 1.020

main: number of tokens in prompt = 27
main: token[0] =     50
main: token[1] =     27
main: token[2] =   6758
main: token[3] =    637
main: token[4] =   5686
main: token[5] =    275
main: token[6] =  30953
main: token[7] =    369
main: token[8] =    253
main: token[9] =   3347
main: token[10] =    273
main: token[11] =    247
main: token[12] =  49298
main: token[13] =    382
main: token[14] =    275
main: token[15] =    253
main: token[16] =  21946
main: token[17] =   5553
main: token[18] =  37970
main: token[19] =    556
main: token[20] =    644
main: token[21] =  25546
main: token[22] =    323
main: token[23] =   8676
main: token[24] =  24962
main: token[25] =     84
main: token[26] =     32

Q: Which man born in 1932 was the son of a percussionist in the CBS radio orchestra has been nominated for 53 Oscars?
A: Hedy Lamarr. She married a musician and won an Oscar for Best Actress. The following pages provide background on her life before marriage, before she met her husband but also after marrying that man; finally, we find out why she was just selected as the greatest woman of all time by Forbes Magazine


main: sampled tokens =       64
main:  mem per token =   305340 bytes
main:      load time =  1328.41 ms
main:    sample time =   377.23 ms / 5.89 ms per token
main:      eval time =  9536.84 ms / 105.96 ms per token
main:     total time = 11404.35 ms

Example output (16-bit):

main: seed      = 1693677080
main: n_threads = 32
main: n_batch   = 8
main: n_ctx     = 512
main: n_predict = 64

mpt_model_load: loading model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin' - please wait ...
mpt_model_load: d_model        = 2048
mpt_model_load: max_seq_len    = 2048
mpt_model_load: n_ctx          = 512
mpt_model_load: n_heads        = 16
mpt_model_load: n_layers       = 24
mpt_model_load: n_vocab        = 50432
mpt_model_load: alibi_bias_max = 8.000000
mpt_model_load: clip_qkv       = 0.000000
mpt_model_load: ftype          = 1
mpt_model_load: qntvr          = 0
mpt_model_load: ggml ctx size = 2597.85 MB
mpt_model_load: memory_size =    96.00 MB, n_mem = 12288
mpt_model_load: ........................ done
mpt_model_load: model size =  2501.76 MB / num tensors = 194
extract_tests_from_file : No test file found.
test_gpt_tokenizer : 0 tests failed out of 0 tests.

main: temp           = 0.800
main: top_k          = 50432
main: top_p          = 1.000
main: repeat_last_n  = 64
main: repeat_penalty = 1.020

main: number of tokens in prompt = 27
main: token[0] =     50
main: token[1] =     27
main: token[2] =   6758
main: token[3] =    637
main: token[4] =   5686
main: token[5] =    275
main: token[6] =  30953
main: token[7] =    369
main: token[8] =    253
main: token[9] =   3347
main: token[10] =    273
main: token[11] =    247
main: token[12] =  49298
main: token[13] =    382
main: token[14] =    275
main: token[15] =    253
main: token[16] =  21946
main: token[17] =   5553
main: token[18] =  37970
main: token[19] =    556
main: token[20] =    644
main: token[21] =  25546
main: token[22] =    323
main: token[23] =   8676
main: token[24] =  24962
main: token[25] =     84
main: token[26] =     32

Q: Which man born in 1932 was the son of a percussionist in the CBS radio orchestra has been nominated for 53 Oscars?
Answer: Harry Belafonte was born in New York and as a child attended school at night while his parents worked. He earned an undergraduate degree from Simmons College, followed by a teaching credential from Columbia University before turning to acting on Broadway. Then he became associated with the CBS radio orchestra where he started singing when


main: sampled tokens =       64
main:  mem per token =   317628 bytes
main:      load time =   686.81 ms
main:    sample time =   377.17 ms / 5.89 ms per token
main:      eval time =  5091.44 ms / 56.57 ms per token
main:     total time =  6240.22 ms

@ggerganov
Owner

Maybe try to make the MPT example auto-detect it.
I guess in the long term we should just add MPT support to llama.cpp.
For example, here is ongoing work to add Starcoder to llama.cpp: ggerganov/llama.cpp#3187
It can be used as an example of what needs to be done.
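For the auto-detection, one minimal sketch of the convert-script side (variable names and the extra header field are assumptions for illustration, not existing code): check the checkpoint for the extra q_ln/k_ln tensors and emit a flag in the ggml header that the C++ loader can branch on when deciding which tensors to expect.

import struct
import torch

def uses_qk_ln(checkpoint_path: str) -> bool:
    # True for the MPT-1b layout (q_ln/k_ln present), False for the MPT-7b layout.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    return any(name.endswith(".attn.q_ln.weight") for name in state_dict)

# In convert-h5-to-ggml.py one could then write, e.g., alongside the other hparams:
#   fout.write(struct.pack("i", int(uses_qk_ln(dir_model + "/pytorch_model.bin"))))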
