Skip to content

Commit

Permalink
Add support for DeepseekV2ForCausalLM (#7519)
Browse files Browse the repository at this point in the history
* common : increase max number of experts to 160

* common : add tensors ATTN_Q_A, ATTN_Q_A_NORM, ATTN_Q_B, ATTN_KV_A_MQA, ATTN_KV_A_NORM, ATTN_KV_B needed by DeepSeek-V2 MLA (multi-head latent attention) architecture

* common : add model header parameters: leading_dense_block_count, expert_feed_forward_length, expert_shared_count, expert_weights_scale, attention.q_lora_rank, attention.kv_lora_rank, rope.scaling.yarn_log_multiplier

* convert-hf : add model conversion support for DeepseekV2ForCausalLM

* llama : add model types for DeepSeek-V2 and DeepSeek-V2-Lite models

* llama : add two new llm_build_moe_ffn() arguments: scale_w (whether to scale weights of selected MoE experts) and w_scale (numerical value of the scaling factor)

* llama : add inference support for LLM_ARCH_DEEPSEEK2

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
  • Loading branch information
fairydreaming and sszymczy committed May 28, 2024
1 parent edc2943 commit ee3dff6
Show file tree
Hide file tree
Showing 5 changed files with 599 additions and 26 deletions.
79 changes: 79 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2620,6 +2620,85 @@ def write_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("DeepseekV2ForCausalLM")
class DeepseekV2Model(Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2

def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams

self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(hparams["v_head_dim"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "yarn":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"])

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def write_tensors(self):
super().write_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######


Expand Down
74 changes: 63 additions & 11 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,21 @@ class General:
FILE_TYPE = "general.file_type"

class LLM:
VOCAB_SIZE = "{arch}.vocab_size"
CONTEXT_LENGTH = "{arch}.context_length"
EMBEDDING_LENGTH = "{arch}.embedding_length"
BLOCK_COUNT = "{arch}.block_count"
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count"
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
VOCAB_SIZE = "{arch}.vocab_size"
CONTEXT_LENGTH = "{arch}.context_length"
EMBEDDING_LENGTH = "{arch}.embedding_length"
BLOCK_COUNT = "{arch}.block_count"
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count"
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
Expand All @@ -55,6 +59,8 @@ class Attention:
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand All @@ -64,6 +70,7 @@ class Rope:
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"

class SSM:
CONV_KERNEL = "{arch}.ssm.conv_kernel"
Expand Down Expand Up @@ -140,6 +147,7 @@ class MODEL_ARCH(IntEnum):
DBRX = auto()
OLMO = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -185,6 +193,12 @@ class MODEL_TENSOR(IntEnum):
SSM_A = auto()
SSM_D = auto()
SSM_OUT = auto()
ATTN_Q_A = auto()
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
Expand Down Expand Up @@ -221,6 +235,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -266,6 +281,12 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
Expand Down Expand Up @@ -757,6 +778,33 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.DEEPSEEK2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
# TODO
}

Expand Down Expand Up @@ -790,6 +838,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.DEEPSEEK2: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}

#
Expand Down
21 changes: 21 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,15 @@ def add_embedding_length(self, length: int) -> None:
def add_block_count(self, length: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

def add_leading_dense_block_count(self, length: int) -> None:
self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)

def add_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)

def add_expert_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

def add_parallel_residual(self, use: bool) -> None:
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

Expand Down Expand Up @@ -407,6 +413,12 @@ def add_expert_count(self, count: int) -> None:
def add_expert_used_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)

def add_expert_shared_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)

def add_expert_weights_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)

def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

Expand All @@ -416,6 +428,12 @@ def add_layer_norm_rms_eps(self, value: float) -> None:
def add_causal_attention(self, value: bool) -> None:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

def add_q_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)

def add_kv_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)

def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

Expand All @@ -440,6 +458,9 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
def add_rope_scaling_finetuned(self, value: bool) -> None:
self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

def add_ssm_conv_kernel(self, value: int) -> None:
self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

Expand Down
29 changes: 28 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ class TensorNameMap:

MODEL_TENSOR.FFN_UP_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
),

# AWQ-activation gate
Expand Down Expand Up @@ -285,6 +286,7 @@ class TensorNameMap:

MODEL_TENSOR.FFN_GATE_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
),

# Feed-forward down
Expand Down Expand Up @@ -320,6 +322,7 @@ class TensorNameMap:

MODEL_TENSOR.FFN_DOWN_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_NORM: (
Expand Down Expand Up @@ -383,6 +386,30 @@ class TensorNameMap:
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
),

MODEL_TENSOR.ATTN_Q_A: (
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_B: (
"model.layers.{bid}.self_attn.q_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_KV_A_MQA: (
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
),

MODEL_TENSOR.ATTN_KV_B: (
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),

MODEL_TENSOR.ATTN_KV_A_NORM: (
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
),
}

# architecture-specific block mappings
Expand Down Expand Up @@ -415,7 +442,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
if tensor not in MODEL_TENSORS[arch]:
continue
# TODO: make this configurable
n_experts = 128
n_experts = 160
for xid in range(n_experts):
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
self.mapping[tensor_name] = (tensor, tensor_name)
Expand Down

0 comments on commit ee3dff6

Please sign in to comment.