From c8c353f88d8892b5de23a4f622da761130840188 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 16 May 2024 20:16:35 +0200 Subject: [PATCH 01/27] Added initial support for DeepseekV2ForCausalLM. --- convert-hf-to-gguf.py | 74 +++++++ gguf-py/gguf/constants.py | 44 +++++ gguf-py/gguf/tensor_mapping.py | 29 ++- llama.cpp | 340 ++++++++++++++++++++++++++++++++- 4 files changed, 482 insertions(+), 5 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index cd875fa4af6af..ef1d9a2bc934b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2389,6 +2389,80 @@ def set_vocab(self, *args, **kwargs): self.gguf_writer.add_add_eos_token(True) +@Model.register("DeepseekV2ForCausalLM") +class DeepseekV2Model(Model): + model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def write_tensors(self): + super().write_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 978fcada3b42c..0a732022d0ee1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum): COMMAND_R = auto() DBRX = auto() OLMO = auto() + DEEPSEEK2 = auto() class MODEL_TENSOR(IntEnum): @@ -181,6 +182,12 @@ class MODEL_TENSOR(IntEnum): SSM_A = auto() SSM_D = auto() SSM_OUT = auto() + ATTN_Q_A = auto() + ATTN_Q_B = auto() + ATTN_KV_A_MQA = auto() + ATTN_KV_B = auto() + ATTN_Q_A_NORM = auto() + ATTN_KV_A_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -217,6 +224,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.DEEPSEEK2: "deepseek2", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -259,6 +267,12 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", + MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", + MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", + MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", + MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -743,6 +757,32 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } @@ -779,6 +819,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } # diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8e1cac9152f55..383d8440bb5d9 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -255,6 +255,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_UP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2 ), # AWQ-activation gate @@ -283,6 +284,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2 ), # Feed-forward down @@ -317,6 +319,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_DOWN_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2 ), MODEL_TENSOR.ATTN_Q_NORM: ( @@ -380,6 +383,30 @@ class TensorNameMap: "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", ), + + MODEL_TENSOR.ATTN_Q_A: ( + "model.layers.{bid}.self_attn.q_a_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_B: ( + "model.layers.{bid}.self_attn.q_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_MQA: ( + "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_B: ( + "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_A_NORM: ( + "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_NORM: ( + "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 + ), } mapping: dict[str, tuple[MODEL_TENSOR, str]] @@ -398,7 +425,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): if tensor not in MODEL_TENSORS[arch]: continue # TODO: make this configurable - n_experts = 60 + n_experts = 160 for xid in range(n_experts): tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) self.mapping[tensor_name] = (tensor, tensor_name) diff --git a/llama.cpp b/llama.cpp index 7d26966e49110..53959d83a1f2c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -110,7 +110,7 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 60 +#define LLAMA_MAX_EXPERTS 160 // // logging @@ -229,6 +229,7 @@ enum llm_arch { LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_DEEPSEEK2, LLM_ARCH_UNKNOWN, }; @@ -266,6 +267,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -476,6 +478,12 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1052,6 +1060,34 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_DEEPSEEK2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1895,6 +1931,8 @@ struct llama_layer { struct ggml_tensor * attn_k_norm_b; struct ggml_tensor * attn_out_norm; struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * attn_q_a_norm; + struct ggml_tensor * attn_kv_a_norm; // attention struct ggml_tensor * wq; @@ -1902,6 +1940,10 @@ struct llama_layer { struct ggml_tensor * wv; struct ggml_tensor * wo; struct ggml_tensor * wqkv; + struct ggml_tensor * wq_a; + struct ggml_tensor * wq_b; + struct ggml_tensor * wkv_a_mqa; + struct ggml_tensor * wkv_b; // attention bias struct ggml_tensor * bq; @@ -4261,6 +4303,11 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + model.type = e_model::MODEL_UNKNOWN; + } break; default: (void)0; } @@ -4920,8 +4967,6 @@ static bool llm_load_tensors( throw std::runtime_error("model has expert layers but no expert layers are used"); } - GGML_ASSERT(n_embd_gqa == n_embd_k_gqa); - ggml_context * ctx_input = ctx_map.at(model.buft_input.buft); ggml_context * ctx_output = ctx_map.at(model.buft_output.buft); ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix); @@ -6060,6 +6105,67 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_DEEPSEEK2: + { + // TODO maybe move some of these to hparams + const uint32_t n_shared_experts = 2; + const uint32_t moe_intermediate_size = 1536; + const uint32_t q_lora_rank = 1536; + const uint32_t kv_lora_rank = 512; + const uint32_t first_k_dense_replace = 1; + + // kept original names of these parameters from HF transformers code for clarity + const uint32_t qk_rope_head_dim = hparams.n_rot; + const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); + + layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); + layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k}); + layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}); + layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (qk_nope_head_dim + hparams.n_embd_head_v)}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + if ((uint32_t) i < first_k_dense_replace) { + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } else { + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {moe_intermediate_size, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert}); + + // Shared expert branch + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { moe_intermediate_size * n_shared_experts, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts}); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6700,7 +6806,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); } @@ -10779,6 +10885,227 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_deepseek2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + // TODO maybe move some of these to hparams + const uint32_t first_k_dense_replace = 1; + const uint32_t kv_lora_rank = 512; + + // kept original names of these parameters from HF transformers code for clarity + const uint32_t qk_rope_head_dim = hparams.n_rot; + const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + struct ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = llm_build_norm(ctx0, q, hparams, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + + // split into {n_head * qk_nope_head_dim, n_tokens} + struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0); + // and {n_head * qk_rope_head_dim, n_tokens} + struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, qk_rope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * qk_nope_head_dim); + + q_nope = ggml_cont(ctx0, q_nope); + cb(q_nope, "q_nope", il); + + q_pe = ggml_cont(ctx0, q_pe); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + qk_rope_head_dim} * {n_embd, n_tokens} -> {kv_lora_rank + qk_rope_head_dim, n_tokens} + struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(compressed_kv_pe, "compressed_kv_pe", il); + + // split into {kv_lora_rank, n_tokens} + struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0); + // and {qk_rope_head_dim, n_tokens} + struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, qk_rope_head_dim, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank); + + k_pe = ggml_cont(ctx0, k_pe); + cb(k_pe, "k_pe", il); + + compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(compressed_kv, "compressed_kv", il); + + // {kv_lora_rank, n_head * (qk_nope_head_dim + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (qk_nope_head_dim + n_embd_head_v), n_tokens} + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv); + cb(kv, "kv", il); + + // split into {n_head * qk_nope_head_dim, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), 0); + // and {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * value_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * qk_nope_head_dim); + + value_states = ggml_dup(ctx0, value_states); + cb(value_states, "value_states", il); + + value_states = ggml_reshape_2d(ctx0, value_states, hparams.n_embd_head_v * n_head, n_tokens); + + q_pe = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, q_pe, qk_rope_head_dim, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + struct ggml_tensor * query_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); + cb(query_states, "query_states", il); + query_states = ggml_set_inplace(ctx0, query_states, q_nope, query_states->nb[1], query_states->nb[2], query_states->nb[3], 0); + query_states = ggml_set_inplace(ctx0, query_states, q_pe, query_states->nb[1], query_states->nb[2], query_states->nb[3], ggml_element_size(query_states) * qk_nope_head_dim); + + k_pe = ggml_repeat(ctx0, k_pe, q_pe); + cb(k_pe, "k_pe", il); + + struct ggml_tensor * key_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); + cb(key_states, "key_states", il); + key_states = ggml_set_inplace(ctx0, key_states, k_nope, key_states->nb[1], key_states->nb[2], key_states->nb[3], 0); + key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * qk_nope_head_dim); + + // TODO see if we can avoid these operations by permuting + // rows/columns of some model tensors during model conversion + query_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, query_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens))); + cb(query_states, "query_states", il); + + query_states = ggml_reshape_3d(ctx0, query_states, hparams.n_embd_head_k, n_head, n_tokens); + cb(query_states, "query_states", il); + + key_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, key_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens))); + cb(key_states, "key_states", il); + + key_states = ggml_reshape_3d(ctx0, key_states, hparams.n_embd_head_k, n_head, n_tokens); + cb(key_states, "key_states", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(hparams.n_embd_head_k)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + if ((uint32_t) il < first_k_dense_replace) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + cb, il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up_shexp, NULL, + model.layers[il].ffn_gate_shexp, NULL, + model.layers[il].ffn_down_shexp, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + moe_out = ggml_add(ctx0, moe_out, ffn_shexp); + cb(moe_out, "ffn_out", il); + + cur = moe_out; + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -10993,6 +11320,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_olmo(); } break; + case LLM_ARCH_DEEPSEEK2: + { + result = llm.build_deepseek2(); + } break; default: GGML_ASSERT(false); } @@ -16008,6 +16339,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: case LLM_ARCH_OLMO: + case LLM_ARCH_DEEPSEEK2: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From 039896407afd40e54321d47c5063c46a52da3e01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 18 May 2024 11:38:07 +0200 Subject: [PATCH 02/27] Removed unnecessary tensor operations. --- llama.cpp | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/llama.cpp b/llama.cpp index 0ba5d2e21c34e..1da020b288585 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10939,13 +10939,9 @@ struct llm_build_context { // split into {n_head * qk_nope_head_dim, n_tokens} struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0); + cb(q_nope, "q_nope", il); // and {n_head * qk_rope_head_dim, n_tokens} struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, qk_rope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * qk_nope_head_dim); - - q_nope = ggml_cont(ctx0, q_nope); - cb(q_nope, "q_nope", il); - - q_pe = ggml_cont(ctx0, q_pe); cb(q_pe, "q_pe", il); // {n_embd, kv_lora_rank + qk_rope_head_dim} * {n_embd, n_tokens} -> {kv_lora_rank + qk_rope_head_dim, n_tokens} @@ -10954,10 +10950,9 @@ struct llm_build_context { // split into {kv_lora_rank, n_tokens} struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0); + cb(compressed_kv, "compressed_kv", il); // and {qk_rope_head_dim, n_tokens} struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, qk_rope_head_dim, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank); - - k_pe = ggml_cont(ctx0, k_pe); cb(k_pe, "k_pe", il); compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams, @@ -10971,16 +10966,20 @@ struct llm_build_context { // split into {n_head * qk_nope_head_dim, n_tokens} struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), 0); + cb(k_nope, "k_nope", il); + // and {n_head * n_embd_head_v, n_tokens} struct ggml_tensor * value_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * qk_nope_head_dim); + cb(value_states, "value_states", il); - value_states = ggml_dup(ctx0, value_states); + value_states = ggml_cont(ctx0, value_states); cb(value_states, "value_states", il); - value_states = ggml_reshape_2d(ctx0, value_states, hparams.n_embd_head_v * n_head, n_tokens); + value_states = ggml_view_2d(ctx0, value_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0); + cb(value_states, "value_states", il); q_pe = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, q_pe, qk_rope_head_dim, n_head, n_tokens), inp_pos, + ctx0, q_pe, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -10988,7 +10987,7 @@ struct llm_build_context { // shared RoPE key k_pe = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens), inp_pos, + ctx0, ggml_view_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -11007,20 +11006,6 @@ struct llm_build_context { key_states = ggml_set_inplace(ctx0, key_states, k_nope, key_states->nb[1], key_states->nb[2], key_states->nb[3], 0); key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * qk_nope_head_dim); - // TODO see if we can avoid these operations by permuting - // rows/columns of some model tensors during model conversion - query_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, query_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens))); - cb(query_states, "query_states", il); - - query_states = ggml_reshape_3d(ctx0, query_states, hparams.n_embd_head_k, n_head, n_tokens); - cb(query_states, "query_states", il); - - key_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, key_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens))); - cb(key_states, "key_states", il); - - key_states = ggml_reshape_3d(ctx0, key_states, hparams.n_embd_head_k, n_head, n_tokens); - cb(key_states, "key_states", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(hparams.n_embd_head_k)), cb, il); From b50c07c247488736112240e0381e42a8333aaea8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 18 May 2024 14:40:09 +0200 Subject: [PATCH 03/27] Added five new DeepSeek-V2-specific parameters: - leading_dense_block_count => hparams.n_leading_dense_layer, - expert_feed_forward_length => hparams.n_expert_ff, - expert_shared_count => hparams.n_expert_shared, - attention.q_lora_rank => hparams.n_lora_q, - attention.kv_lora_rank => hparams.n_lora_kv --- convert-hf-to-gguf.py | 13 ++++++--- gguf-py/gguf/constants.py | 5 ++++ gguf-py/gguf/gguf_writer.py | 15 +++++++++++ llama.cpp | 53 ++++++++++++++++++++++++++----------- 4 files changed, 66 insertions(+), 20 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 156ac0c5322cb..cb1f01549e196 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2401,7 +2401,16 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams + + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -2410,10 +2419,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) - self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(hparams["v_head_dim"]) - self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) - _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0a732022d0ee1..9b6e568475046 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -37,11 +37,14 @@ class LLM: CONTEXT_LENGTH = "{arch}.context_length" EMBEDDING_LENGTH = "{arch}.embedding_length" BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" EXPERT_COUNT = "{arch}.expert_count" EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" @@ -55,6 +58,8 @@ class Attention: LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" CAUSAL = "{arch}.attention.causal" + Q_LORA_RANK = "{arch}.attention.q_lora_rank" + KV_LORA_RANK = "{arch}.attention.kv_lora_rank" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d5e323a52ef14..e82f4e9ab1a06 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -376,9 +376,15 @@ def add_embedding_length(self, length: int) -> None: def add_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) + def add_leading_dense_block_count(self, length: int) -> None: + self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + def add_feed_forward_length(self, length: int) -> None: self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_expert_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_parallel_residual(self, use: bool) -> None: self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) @@ -409,6 +415,9 @@ def add_expert_count(self, count: int) -> None: def add_expert_used_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count) + def add_expert_shared_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) @@ -418,6 +427,12 @@ def add_layer_norm_rms_eps(self, value: float) -> None: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_q_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length) + + def add_kv_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) diff --git a/llama.cpp b/llama.cpp index 1da020b288585..560fc7acf6c6c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -288,11 +288,14 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, @@ -305,6 +308,8 @@ enum llm_kv { LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -365,11 +370,14 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, @@ -382,6 +390,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -1803,6 +1813,12 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_leading_dense_layer = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_expert_ff = 0; + uint32_t n_expert_shared = 0; + float f_norm_eps; float f_norm_rms_eps; @@ -1842,6 +1858,12 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_leading_dense_layer != other.n_leading_dense_layer) return true; + if (this->n_lora_q != other.n_lora_q) return true; + if (this->n_lora_kv != other.n_lora_kv) return true; + if (this->n_expert_ff != other.n_expert_ff) return true; + if (this->n_expert_shared != other.n_expert_shared) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -4306,6 +4328,12 @@ static void llm_load_hparams( case LLM_ARCH_DEEPSEEK2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_leading_dense_layer); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_expert_ff); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + model.type = e_model::MODEL_UNKNOWN; } break; default: (void)0; @@ -6107,16 +6135,12 @@ static bool llm_load_tensors( } break; case LLM_ARCH_DEEPSEEK2: { - // TODO maybe move some of these to hparams - const uint32_t n_shared_experts = 2; - const uint32_t moe_intermediate_size = 1536; - const uint32_t q_lora_rank = 1536; - const uint32_t kv_lora_rank = 512; - const uint32_t first_k_dense_replace = 1; - // kept original names of these parameters from HF transformers code for clarity const uint32_t qk_rope_head_dim = hparams.n_rot; const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t q_lora_rank = hparams.n_lora_q; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + const uint32_t moe_intermediate_size = hparams.n_expert_ff; model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6144,7 +6168,7 @@ static bool llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - if ((uint32_t) i < first_k_dense_replace) { + if ((uint32_t) i < hparams.n_leading_dense_layer) { layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); @@ -6160,9 +6184,9 @@ static bool llm_load_tensors( layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert}); // Shared expert branch - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { moe_intermediate_size * n_shared_experts, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts}); + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, moe_intermediate_size * hparams.n_expert_shared}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { moe_intermediate_size * hparams.n_expert_shared, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, moe_intermediate_size * hparams.n_expert_shared}); } } } break; @@ -10893,13 +10917,10 @@ struct llm_build_context { // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; - // TODO maybe move some of these to hparams - const uint32_t first_k_dense_replace = 1; - const uint32_t kv_lora_rank = 512; - // kept original names of these parameters from HF transformers code for clarity const uint32_t qk_rope_head_dim = hparams.n_rot; const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -11022,7 +11043,7 @@ struct llm_build_context { struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - if ((uint32_t) il < first_k_dense_replace) { + if ((uint32_t) il < hparams.n_leading_dense_layer) { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); From 79f841778f23bb1d1abec3cb16d35d591ae6a558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 18 May 2024 16:05:16 +0200 Subject: [PATCH 04/27] Added initial support for DeepSeek-V2-Lite model. Added missing scaling of kq_scale parameter. --- convert-hf-to-gguf.py | 3 ++- gguf-py/gguf/constants.py | 1 + llama.cpp | 57 +++++++++++++++++++++++++++------------ 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index cb1f01549e196..2f962dd206899 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2404,7 +2404,8 @@ def set_gguf_parameters(self): self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_value_length(hparams["v_head_dim"]) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9b6e568475046..4c8280ade7d11 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -768,6 +768,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_A, MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, diff --git a/llama.cpp b/llama.cpp index 560fc7acf6c6c..1cba84126a4d7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1079,6 +1079,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, @@ -1825,6 +1826,8 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; + // TODO read from the model file + float mscale_all_dim = 0.707; // for State Space Models uint32_t ssm_d_conv = 0; @@ -4327,9 +4330,11 @@ static void llm_load_hparams( } break; case LLM_ARCH_DEEPSEEK2: { + bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_leading_dense_layer); - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + if (!is_lite) + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_expert_ff); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); @@ -6135,6 +6140,8 @@ static bool llm_load_tensors( } break; case LLM_ARCH_DEEPSEEK2: { + bool is_lite = (hparams.n_layer == 27); + // kept original names of these parameters from HF transformers code for clarity const uint32_t qk_rope_head_dim = hparams.n_rot; const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; @@ -6157,11 +6164,16 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + if (!is_lite) + layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); - layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); - layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k}); + if (!is_lite) { + layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); + layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k}); + } else { + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + } layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}); layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (qk_nope_head_dim + hparams.n_embd_head_v)}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd}); @@ -10917,6 +10929,11 @@ struct llm_build_context { // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; + bool is_lite = (hparams.n_layer == 27); + + const float mscale = hparams.mscale_all_dim * 1.0f + 0.1f * logf(1.0f / freq_scale); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + // kept original names of these parameters from HF transformers code for clarity const uint32_t qk_rope_head_dim = hparams.n_rot; const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; @@ -10945,18 +10962,24 @@ struct llm_build_context { // self_attention { - // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} - struct ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); - cb(q, "q", il); - - q = llm_build_norm(ctx0, q, hparams, - model.layers[il].attn_q_a_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(q, "q", il); - - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} - q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); - cb(q, "q", il); + struct ggml_tensor * q = NULL; + if (!is_lite) { + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = llm_build_norm(ctx0, q, hparams, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } // split into {n_head * qk_nope_head_dim, n_tokens} struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0); @@ -11029,7 +11052,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(hparams.n_embd_head_k)), cb, il); + key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } if (il == n_layer - 1) { From 6050941653a25d87a50a4fe63e3e04e8f1051a28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 18 May 2024 22:24:26 +0200 Subject: [PATCH 05/27] Corrected mscale calculation. --- llama.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1cba84126a4d7..5a8427f17d8e8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10931,7 +10931,7 @@ struct llm_build_context { bool is_lite = (hparams.n_layer == 27); - const float mscale = hparams.mscale_all_dim * 1.0f + 0.1f * logf(1.0f / freq_scale); + const float mscale = 1.0f + 0.1f * hparams.mscale_all_dim * logf(1.0f / freq_scale); const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); // kept original names of these parameters from HF transformers code for clarity @@ -11107,10 +11107,8 @@ struct llm_build_context { LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(ffn_shexp, "ffn_shexp", il); - moe_out = ggml_add(ctx0, moe_out, ffn_shexp); - cb(moe_out, "ffn_out", il); - - cur = moe_out; + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); } } From 7e4786bbfbbbb607df2e349e7aaa067876510489 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 19 May 2024 12:19:40 +0200 Subject: [PATCH 06/27] Added expert_weights_scale parameter for scaling MoE gate weights. --- convert-hf-to-gguf.py | 1 + gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ llama.cpp | 18 +++++++++++++++++- 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 2f962dd206899..d03f4d9e4c772 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2412,6 +2412,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4c8280ade7d11..1739f14d6c720 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -45,6 +45,7 @@ class LLM: EXPERT_COUNT = "{arch}.expert_count" EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e82f4e9ab1a06..da6e686a42a48 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -418,6 +418,9 @@ def add_expert_used_count(self, count: int) -> None: def add_expert_shared_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) + def add_expert_weights_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) diff --git a/llama.cpp b/llama.cpp index 5a8427f17d8e8..ac76d4c585903 100644 --- a/llama.cpp +++ b/llama.cpp @@ -296,6 +296,7 @@ enum llm_kv { LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, @@ -378,6 +379,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, @@ -1819,6 +1821,7 @@ struct llama_hparams { uint32_t n_lora_kv = 0; uint32_t n_expert_ff = 0; uint32_t n_expert_shared = 0; + float expert_weights_scale = 0.0; float f_norm_eps; float f_norm_rms_eps; @@ -1881,6 +1884,7 @@ struct llama_hparams { if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; + if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale,EPSILON)) return true; return false; } @@ -4338,6 +4342,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_expert_ff); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); model.type = e_model::MODEL_UNKNOWN; } break; @@ -6659,6 +6664,8 @@ static struct ggml_tensor * llm_build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, + bool scale_w, + float w_scale, const llm_build_cb & cb, int il) { int64_t n_embd = cur->ne[0]; @@ -6690,6 +6697,10 @@ static struct ggml_tensor * llm_build_moe_ffn( weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); } + if (scale_w) { + weights = ggml_scale(ctx, weights, w_scale); + cb(weights, "ffn_moe_weights_scaled", il); + } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] @@ -7306,6 +7317,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); } @@ -7787,6 +7799,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_GELU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -7930,6 +7943,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -9275,6 +9289,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, false, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -11093,7 +11108,8 @@ struct llm_build_context { model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, n_expert, n_expert_used, - LLM_FFN_SILU, true, + LLM_FFN_SILU, false, + true, hparams.expert_weights_scale, cb, il); cb(moe_out, "ffn_moe_out", il); From 71a742256ca032e4b26440402fa64fda0307ace5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 19 May 2024 12:22:54 +0200 Subject: [PATCH 07/27] Temporarily hard-coded mscale value for DeepSeek-V2 (FIXME!). --- ggml.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml.c b/ggml.c index 55152bce49ebe..2618edf5996a2 100644 --- a/ggml.c +++ b/ggml.c @@ -14073,6 +14073,8 @@ static void rope_yarn( // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); } + // TODO ugly hack for DeepSeek-V2 until we find a solution + mscale = 1.0; *cos_theta = cosf(theta) * mscale; *sin_theta = sinf(theta) * mscale; } From f99df46f982dba25cb250a54f55e5565a108694a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 19 May 2024 19:59:03 +0200 Subject: [PATCH 08/27] Replaced hardcoded mscale value with rescaling attn_factor that results in the final mscale value equal to 1.0. --- ggml.c | 2 -- llama.cpp | 12 ++++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ggml.c b/ggml.c index 2618edf5996a2..55152bce49ebe 100644 --- a/ggml.c +++ b/ggml.c @@ -14073,8 +14073,6 @@ static void rope_yarn( // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - // TODO ugly hack for DeepSeek-V2 until we find a solution - mscale = 1.0; *cos_theta = cosf(theta) * mscale; *sin_theta = sinf(theta) * mscale; } diff --git a/llama.cpp b/llama.cpp index ac76d4c585903..19b30ba5608cd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10949,6 +10949,14 @@ struct llm_build_context { const float mscale = 1.0f + 0.1f * hparams.mscale_all_dim * logf(1.0f / freq_scale); const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + // DeepSeek-V2 uses non-standard YaRN mscale calculation from mscale and mscale_all_dim + // config.json parameters. However, both of these are equal to 0.707 in released models, + // which results in the final mscale value equal to 1.0. To get the same value we + // pre-scale the attn_factor. + // TODO Get rid of this when other models start using DeepSeek-V2 + // variant of mscale calculation resulting in the API change. + const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + // kept original names of these parameters from HF transformers code for clarity const uint32_t qk_rope_head_dim = hparams.n_rot; const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; @@ -11040,7 +11048,7 @@ struct llm_build_context { q_pe = ggml_rope_custom( ctx0, q_pe, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow + ext_factor, attn_factor_scaled, beta_fast, beta_slow ); cb(q_pe, "q_pe", il); @@ -11048,7 +11056,7 @@ struct llm_build_context { k_pe = ggml_rope_custom( ctx0, ggml_view_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow + ext_factor, attn_factor_scaled, beta_fast, beta_slow ); cb(k_pe, "k_pe", il); From 3ae7235e9419085ec47dab72d38f8dcae9dd7e27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 19 May 2024 20:14:14 +0200 Subject: [PATCH 09/27] Whitespace formatting fixes. --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 19b30ba5608cd..c2799ecb7aa18 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1884,7 +1884,7 @@ struct llama_hparams { if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; - if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale,EPSILON)) return true; + if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true; return false; } From 68a5103026cf3b6a3b9de60f9af0568ce8719468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 20 May 2024 17:20:18 +0200 Subject: [PATCH 10/27] Referenced the relevant GitHub discussion instead of providing long comments. --- llama.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index c2799ecb7aa18..f563c52beada6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10946,15 +10946,10 @@ struct llm_build_context { bool is_lite = (hparams.n_layer == 27); + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. const float mscale = 1.0f + 0.1f * hparams.mscale_all_dim * logf(1.0f / freq_scale); const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); - - // DeepSeek-V2 uses non-standard YaRN mscale calculation from mscale and mscale_all_dim - // config.json parameters. However, both of these are equal to 0.707 in released models, - // which results in the final mscale value equal to 1.0. To get the same value we - // pre-scale the attn_factor. - // TODO Get rid of this when other models start using DeepSeek-V2 - // variant of mscale calculation resulting in the API change. const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); // kept original names of these parameters from HF transformers code for clarity From 7be56da99a903045bf1f29d93e7dfec7ab097f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 20 May 2024 18:51:23 +0200 Subject: [PATCH 11/27] Added YaRN log multiplier model header parameter corresponding to the multiplier of the ln(s) from the sqrt(1/t) = 0.1 ln(s) + 1 equation. --- convert-hf-to-gguf.py | 1 + gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ llama.cpp | 9 ++++++--- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index d03f4d9e4c772..b9f893cac5d8d 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2420,6 +2420,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1*hparams["rope_scaling"]["mscale_all_dim"]) _experts: list[dict[str, Tensor]] | None = None diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1739f14d6c720..1a579dc80d593 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -69,6 +69,7 @@ class Rope: SCALING_FACTOR = "{arch}.rope.scaling.factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index da6e686a42a48..c834efd7fc82b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -457,6 +457,9 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None: def add_rope_scaling_finetuned(self, value: bool) -> None: self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value) + def add_rope_scaling_yarn_log_mul(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value) + def add_ssm_conv_kernel(self, value: int) -> None: self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value) diff --git a/llama.cpp b/llama.cpp index f563c52beada6..d081f08204268 100644 --- a/llama.cpp +++ b/llama.cpp @@ -319,6 +319,7 @@ enum llm_kv { LLM_KV_ROPE_SCALING_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, LLM_KV_SPLIT_NO, LLM_KV_SPLIT_COUNT, @@ -402,6 +403,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -1829,8 +1831,7 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; - // TODO read from the model file - float mscale_all_dim = 0.707; + float rope_yarn_log_mul; // for State Space Models uint32_t ssm_d_conv = 0; @@ -1885,6 +1886,7 @@ struct llama_hparams { if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true; + if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true; return false; } @@ -4343,6 +4345,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_expert_ff); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); model.type = e_model::MODEL_UNKNOWN; } break; @@ -10948,7 +10951,7 @@ struct llm_build_context { // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float mscale = 1.0f + 0.1f * hparams.mscale_all_dim * logf(1.0f / freq_scale); + const float mscale = 1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale); const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); From 842ff3fed17e416d74227c39c9858fb14653f998 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 21 May 2024 17:12:14 +0200 Subject: [PATCH 12/27] Added 16B and 236B model types for DeepSeek-V2. --- llama.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index d081f08204268..b13290b048212 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1778,6 +1778,7 @@ enum e_model { MODEL_13B, MODEL_14B, MODEL_15B, + MODEL_16B, MODEL_20B, MODEL_30B, MODEL_34B, @@ -1785,6 +1786,7 @@ enum e_model { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_236B, MODEL_314B, MODEL_SMALL, MODEL_MEDIUM, @@ -3837,6 +3839,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; case MODEL_15B: return "15B"; + case MODEL_16B: return "16B"; case MODEL_20B: return "20B"; case MODEL_30B: return "30B"; case MODEL_34B: return "34B"; @@ -3844,6 +3847,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; + case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; @@ -4347,7 +4351,11 @@ static void llm_load_hparams( ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); - model.type = e_model::MODEL_UNKNOWN; + switch (hparams.n_layer) { + case 27: model.type = e_model::MODEL_16B; break; + case 60: model.type = e_model::MODEL_236B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } } break; default: (void)0; } From c033958d7c0055f89f0dcf738054342935c09652 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 21 May 2024 17:53:37 +0200 Subject: [PATCH 13/27] Removed usage of output bias tensor since it's not present in DeepSeek-V2 models. --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index b13290b048212..03bf77fa0c85e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11080,7 +11080,7 @@ struct llm_build_context { key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * qk_nope_head_dim); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, NULL, key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } From bb9c361802dc6546c32663d4910e770d7cc9e7bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Fri, 24 May 2024 16:37:29 +0200 Subject: [PATCH 14/27] gguf-py : re-add SCALING_YARN_LOG_MUL removed during merge by accident --- gguf-py/gguf/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9ed7fdf921fbe..e5e0c6b49c043 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -70,6 +70,7 @@ class Rope: SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" From f3b5e7d436e2bad1fda389fde4f13bd777b001c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 26 May 2024 20:36:28 +0200 Subject: [PATCH 15/27] llama : correct llm_build_moe_ffn() arguments in build_arctic() --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 84f884114de75..a3ee1d4c1f622 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11142,7 +11142,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, - true, hparams.expert_weights_scale, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); From abef8b26349a349cabe46d121961ade9aaef1e10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 12:47:53 +0200 Subject: [PATCH 16/27] llama : code style corrections --- llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index a3ee1d4c1f622..5f51275611189 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4466,8 +4466,9 @@ static void llm_load_hparams( bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_leading_dense_layer); - if (!is_lite) + if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_expert_ff); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); @@ -6330,8 +6331,9 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - if (!is_lite) + if (!is_lite) { layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + } layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); if (!is_lite) { From a654cd992b911da140439cd13f297d24a204bf3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 12:54:47 +0200 Subject: [PATCH 17/27] llama : rename n_expert_ff to n_ff_exp --- llama.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 5f51275611189..eeb2939f5279d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1839,7 +1839,7 @@ struct llama_hparams { uint32_t n_leading_dense_layer = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; - uint32_t n_expert_ff = 0; + uint32_t n_ff_exp = 0; uint32_t n_expert_shared = 0; float expert_weights_scale = 0.0; @@ -1887,7 +1887,7 @@ struct llama_hparams { if (this->n_leading_dense_layer != other.n_leading_dense_layer) return true; if (this->n_lora_q != other.n_lora_q) return true; if (this->n_lora_kv != other.n_lora_kv) return true; - if (this->n_expert_ff != other.n_expert_ff) return true; + if (this->n_ff_exp != other.n_ff_exp) return true; if (this->n_expert_shared != other.n_expert_shared) return true; if (this->rope_finetuned != other.rope_finetuned) return true; @@ -4470,7 +4470,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_expert_ff); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); @@ -6314,7 +6314,7 @@ static bool llm_load_tensors( const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; const uint32_t q_lora_rank = hparams.n_lora_q; const uint32_t kv_lora_rank = hparams.n_lora_kv; - const uint32_t moe_intermediate_size = hparams.n_expert_ff; + const uint32_t moe_intermediate_size = hparams.n_ff_exp; model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); From 5a3e6b6cd130ed60b950932ac8db6a7cd12ecac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 13:09:06 +0200 Subject: [PATCH 18/27] llama : rename qk_rope_head_dim, qk_nope_head_dim variables to n_embd_head_qk_rope, n_embd_head_qk_nope --- llama.cpp | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/llama.cpp b/llama.cpp index eeb2939f5279d..27380826c1908 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6309,9 +6309,8 @@ static bool llm_load_tensors( { bool is_lite = (hparams.n_layer == 27); - // kept original names of these parameters from HF transformers code for clarity - const uint32_t qk_rope_head_dim = hparams.n_rot; - const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t q_lora_rank = hparams.n_lora_q; const uint32_t kv_lora_rank = hparams.n_lora_kv; const uint32_t moe_intermediate_size = hparams.n_ff_exp; @@ -6342,8 +6341,8 @@ static bool llm_load_tensors( } else { layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); } - layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}); - layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (qk_nope_head_dim + hparams.n_embd_head_v)}); + layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}); + layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); @@ -11191,9 +11190,8 @@ struct llm_build_context { const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); - // kept original names of these parameters from HF transformers code for clarity - const uint32_t qk_rope_head_dim = hparams.n_rot; - const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; struct ggml_tensor * cur; @@ -11238,22 +11236,22 @@ struct llm_build_context { cb(q, "q", il); } - // split into {n_head * qk_nope_head_dim, n_tokens} - struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0); + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0); cb(q_nope, "q_nope", il); - // and {n_head * qk_rope_head_dim, n_tokens} - struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, qk_rope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * qk_nope_head_dim); + // and {n_head * n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope); cb(q_pe, "q_pe", il); - // {n_embd, kv_lora_rank + qk_rope_head_dim} * {n_embd, n_tokens} -> {kv_lora_rank + qk_rope_head_dim, n_tokens} + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); cb(compressed_kv_pe, "compressed_kv_pe", il); // split into {kv_lora_rank, n_tokens} struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0); cb(compressed_kv, "compressed_kv", il); - // and {qk_rope_head_dim, n_tokens} - struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, qk_rope_head_dim, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank); + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank); cb(k_pe, "k_pe", il); compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams, @@ -11261,16 +11259,16 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(compressed_kv, "compressed_kv", il); - // {kv_lora_rank, n_head * (qk_nope_head_dim + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (qk_nope_head_dim + n_embd_head_v), n_tokens} + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv); cb(kv, "kv", il); - // split into {n_head * qk_nope_head_dim, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), 0); + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * value_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * qk_nope_head_dim); + struct ggml_tensor * value_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope); cb(value_states, "value_states", il); value_states = ggml_cont(ctx0, value_states); @@ -11288,7 +11286,7 @@ struct llm_build_context { // shared RoPE key k_pe = ggml_rope_ext( - ctx0, ggml_view_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr, + ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor_scaled, beta_fast, beta_slow ); @@ -11297,7 +11295,7 @@ struct llm_build_context { struct ggml_tensor * query_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); cb(query_states, "query_states", il); query_states = ggml_set_inplace(ctx0, query_states, q_nope, query_states->nb[1], query_states->nb[2], query_states->nb[3], 0); - query_states = ggml_set_inplace(ctx0, query_states, q_pe, query_states->nb[1], query_states->nb[2], query_states->nb[3], ggml_element_size(query_states) * qk_nope_head_dim); + query_states = ggml_set_inplace(ctx0, query_states, q_pe, query_states->nb[1], query_states->nb[2], query_states->nb[3], ggml_element_size(query_states) * n_embd_head_qk_nope); k_pe = ggml_repeat(ctx0, k_pe, q_pe); cb(k_pe, "k_pe", il); @@ -11305,7 +11303,7 @@ struct llm_build_context { struct ggml_tensor * key_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); cb(key_states, "key_states", il); key_states = ggml_set_inplace(ctx0, key_states, k_nope, key_states->nb[1], key_states->nb[2], key_states->nb[3], 0); - key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * qk_nope_head_dim); + key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * n_embd_head_qk_nope); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, From 20769c0f7ff567edd82b48cacabce40b93712c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 13:11:31 +0200 Subject: [PATCH 19/27] llama : remove trailing whitespaces --- llama.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 27380826c1908..6cc1fe59616ca 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11243,7 +11243,7 @@ struct llm_build_context { struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope); cb(q_pe, "q_pe", il); - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); cb(compressed_kv_pe, "compressed_kv_pe", il); @@ -11340,7 +11340,7 @@ struct llm_build_context { model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - + ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, cur, model.layers[il].ffn_gate_inp, @@ -11352,7 +11352,7 @@ struct llm_build_context { true, hparams.expert_weights_scale, cb, il); cb(moe_out, "ffn_moe_out", il); - + // FFN shared expert { ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, @@ -11362,7 +11362,7 @@ struct llm_build_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(ffn_shexp, "ffn_shexp", il); - + cur = ggml_add(ctx0, moe_out, ffn_shexp); cb(cur, "ffn_out", il); } From fac1e804a1e2d49af73c9aa3a1e7893503d6dee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 13:17:49 +0200 Subject: [PATCH 20/27] llama : rename moe_intermediate_size variable to n_ff_exp --- llama.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index 6cc1fe59616ca..0e311f1f0603f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6313,7 +6313,7 @@ static bool llm_load_tensors( const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t q_lora_rank = hparams.n_lora_q; const uint32_t kv_lora_rank = hparams.n_lora_kv; - const uint32_t moe_intermediate_size = hparams.n_ff_exp; + const uint32_t n_ff_exp = hparams.n_ff_exp; model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6358,14 +6358,14 @@ static bool llm_load_tensors( GGML_ASSERT(hparams.n_expert_used > 0); // MoE branch - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert}); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {moe_intermediate_size, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert}); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); // Shared expert branch - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, moe_intermediate_size * hparams.n_expert_shared}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { moe_intermediate_size * hparams.n_expert_shared, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, moe_intermediate_size * hparams.n_expert_shared}); + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * hparams.n_expert_shared, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); } } } break; From 56f70112eb7a073fa1646eca6097d202dc853656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 13:39:06 +0200 Subject: [PATCH 21/27] llama : rename n_leading_dense_layer to n_layer_dense_lead --- llama.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index 0e311f1f0603f..7587c5d5928b5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1836,7 +1836,7 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types - uint32_t n_leading_dense_layer = 0; + uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; uint32_t n_ff_exp = 0; @@ -1884,11 +1884,11 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; - if (this->n_leading_dense_layer != other.n_leading_dense_layer) return true; - if (this->n_lora_q != other.n_lora_q) return true; - if (this->n_lora_kv != other.n_lora_kv) return true; - if (this->n_ff_exp != other.n_ff_exp) return true; - if (this->n_expert_shared != other.n_expert_shared) return true; + if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; + if (this->n_lora_q != other.n_lora_q) return true; + if (this->n_lora_kv != other.n_lora_kv) return true; + if (this->n_ff_exp != other.n_ff_exp) return true; + if (this->n_expert_shared != other.n_expert_shared) return true; if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -4465,7 +4465,7 @@ static void llm_load_hparams( { bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_leading_dense_layer); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } @@ -6347,7 +6347,7 @@ static bool llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - if ((uint32_t) i < hparams.n_leading_dense_layer) { + if ((uint32_t) i < hparams.n_layer_dense_lead) { layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); @@ -11321,7 +11321,7 @@ struct llm_build_context { struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - if ((uint32_t) il < hparams.n_leading_dense_layer) { + if ((uint32_t) il < hparams.n_layer_dense_lead) { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); From 82cec8b84b462da50857d08b4c1059a9cae98dd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 14:33:31 +0200 Subject: [PATCH 22/27] llama : use attn_factor in mscale calculation to match the rope_yarn() implementation --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 7587c5d5928b5..fcec63cfbc659 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11186,7 +11186,7 @@ struct llm_build_context { // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float mscale = 1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale); + const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); From 5cc7ec161c00aebed0b14cdd623166d7f3515d13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 14:42:27 +0200 Subject: [PATCH 23/27] llama : rename query_states, key_states, value_states to q_states, k_states, v_states --- llama.cpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/llama.cpp b/llama.cpp index fcec63cfbc659..d515a4f5a4b59 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11268,14 +11268,14 @@ struct llm_build_context { cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * value_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope); - cb(value_states, "value_states", il); + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope); + cb(v_states, "v_states", il); - value_states = ggml_cont(ctx0, value_states); - cb(value_states, "value_states", il); + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); - value_states = ggml_view_2d(ctx0, value_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0); - cb(value_states, "value_states", il); + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0); + cb(v_states, "v_states", il); q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, @@ -11292,22 +11292,22 @@ struct llm_build_context { ); cb(k_pe, "k_pe", il); - struct ggml_tensor * query_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); - cb(query_states, "query_states", il); - query_states = ggml_set_inplace(ctx0, query_states, q_nope, query_states->nb[1], query_states->nb[2], query_states->nb[3], 0); - query_states = ggml_set_inplace(ctx0, query_states, q_pe, query_states->nb[1], query_states->nb[2], query_states->nb[3], ggml_element_size(query_states) * n_embd_head_qk_nope); + struct ggml_tensor * q_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); + cb(q_states, "q_states", il); + q_states = ggml_set_inplace(ctx0, q_states, q_nope, q_states->nb[1], q_states->nb[2], q_states->nb[3], 0); + q_states = ggml_set_inplace(ctx0, q_states, q_pe, q_states->nb[1], q_states->nb[2], q_states->nb[3], ggml_element_size(q_states) * n_embd_head_qk_nope); k_pe = ggml_repeat(ctx0, k_pe, q_pe); cb(k_pe, "k_pe", il); - struct ggml_tensor * key_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); - cb(key_states, "key_states", il); - key_states = ggml_set_inplace(ctx0, key_states, k_nope, key_states->nb[1], key_states->nb[2], key_states->nb[3], 0); - key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * n_embd_head_qk_nope); + struct ggml_tensor * k_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); + cb(k_states, "k_states", il); + k_states = ggml_set_inplace(ctx0, k_states, k_nope, k_states->nb[1], k_states->nb[2], k_states->nb[3], 0); + k_states = ggml_set_inplace(ctx0, k_states, k_pe, k_states->nb[1], k_states->nb[2], k_states->nb[3], ggml_element_size(k_states) * n_embd_head_qk_nope); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } if (il == n_layer - 1) { From d02130d5499e358868b3d1f2cc21f7b8c68cd8ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 15:30:17 +0200 Subject: [PATCH 24/27] llama : print DeekSeek-V2-specific parameters in llm_load_print_meta() --- llama.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/llama.cpp b/llama.cpp index d515a4f5a4b59..c84f24a409f9b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4989,6 +4989,16 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); } if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); } if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } + + if (model.arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + } } // Returns false if cancelled by progress_callback From bde971a9ca60edc98ec7744382436fd67b1022cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 May 2024 18:25:47 +0200 Subject: [PATCH 25/27] convert-hf : fix flake8 Lint errors --- convert-hf-to-gguf.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 880f874b02c4a..4450b949e7b97 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2646,14 +2646,11 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) - self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1*hparams["rope_scaling"]["mscale_all_dim"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"]) _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - n_head = self.hparams["num_attention_heads"] - n_kv_head = self.hparams.get("num_key_value_heads") - # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] From 841cd47432affbca30a62aa0f9429a80599074d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 28 May 2024 11:15:17 +0200 Subject: [PATCH 26/27] llama : replace ggml_new_tensor_3d + ggml_set_inplace + ggml_set_inplace with single ggml_concat in build_deepseek2() --- llama.cpp | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index cef5bfdde2454..9c80a62119d44 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11305,18 +11305,11 @@ struct llm_build_context { ); cb(k_pe, "k_pe", il); - struct ggml_tensor * q_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); cb(q_states, "q_states", il); - q_states = ggml_set_inplace(ctx0, q_states, q_nope, q_states->nb[1], q_states->nb[2], q_states->nb[3], 0); - q_states = ggml_set_inplace(ctx0, q_states, q_pe, q_states->nb[1], q_states->nb[2], q_states->nb[3], ggml_element_size(q_states) * n_embd_head_qk_nope); - k_pe = ggml_repeat(ctx0, k_pe, q_pe); - cb(k_pe, "k_pe", il); - - struct ggml_tensor * k_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens); + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); - k_states = ggml_set_inplace(ctx0, k_states, k_nope, k_states->nb[1], k_states->nb[2], k_states->nb[3], 0); - k_states = ggml_set_inplace(ctx0, k_states, k_pe, k_states->nb[1], k_states->nb[2], k_states->nb[3], ggml_element_size(k_states) * n_embd_head_qk_nope); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, From 3efb6595ae6c4f5e10492c70172085e6827d30a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 28 May 2024 13:28:57 +0200 Subject: [PATCH 27/27] gguf-py, llama : whitespace formatting fixes --- gguf-py/gguf/constants.py | 28 ++++++++++++++-------------- llama.cpp | 28 ++++++++++++++-------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index e5e0c6b49c043..55ec2cb5c848a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -33,21 +33,21 @@ class General: FILE_TYPE = "general.file_type" class LLM: - VOCAB_SIZE = "{arch}.vocab_size" - CONTEXT_LENGTH = "{arch}.context_length" - EMBEDDING_LENGTH = "{arch}.embedding_length" - BLOCK_COUNT = "{arch}.block_count" - LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" - FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + VOCAB_SIZE = "{arch}.vocab_size" + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" - USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" - TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" - EXPERT_COUNT = "{arch}.expert_count" - EXPERT_USED_COUNT = "{arch}.expert_used_count" - EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" - EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" - POOLING_TYPE = "{arch}.pooling_type" - LOGIT_SCALE = "{arch}.logit_scale" + USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + POOLING_TYPE = "{arch}.pooling_type" + LOGIT_SCALE = "{arch}.logit_scale" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/llama.cpp b/llama.cpp index 9c80a62119d44..d4cbe6ddb8abc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -362,21 +362,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, - { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, - { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, - { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, - { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },