support mixtral
helloyongyang committed Jun 16, 2024
1 parent 4c4e667 commit fbaa188
Showing 2 changed files with 63 additions and 0 deletions.
1 change: 1 addition & 0 deletions llmc/models/__init__.py
@@ -6,3 +6,4 @@
from .starcoder import Starcoder
from .internlm2 import InternLM2
from .qwen2 import Qwen2
from .mixtral import Mixtral
62 changes: 62 additions & 0 deletions llmc/models/mixtral.py
@@ -0,0 +1,62 @@
from .base_model import BaseModel
from llmc.utils.registry_factory import MODEL_REGISTRY


@MODEL_REGISTRY
class Mixtral(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)

def find_blocks(self):
self.blocks = self.model.model.layers

def find_embed_layers(self):
self.embed_tokens = self.model.model.embed_tokens

def find_block_name(self):
self.block_name_prefix = "model.layers"

def get_embed_layers(self):
return [self.embed_tokens]

def get_layers_except_blocks(self):
return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

def has_bias(self):
return False

def get_layernorms_in_block(self, block):
return {
"input_layernorm": block.input_layernorm,
"post_attention_layernorm": block.post_attention_layernorm,
}

def get_subsets_in_block(self, block):
return [
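            # Attention input projections (q/k/v); the preceding op is the block's input_layernorm.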
{
"layers": {
"self_attn.q_proj": block.self_attn.q_proj,
"self_attn.k_proj": block.self_attn.k_proj,
"self_attn.v_proj": block.self_attn.v_proj,
},
"prev_op": [block.input_layernorm],
"input": ["self_attn.q_proj"],
"inspect": block.self_attn,
"has_kwargs": True,
},
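            # Attention output projection; the preceding op is v_proj.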
{
"layers": {"self_attn.o_proj": block.self_attn.o_proj},
"prev_op": [block.self_attn.v_proj],
"input": ["self_attn.o_proj"],
"inspect": block.self_attn.o_proj,
"has_kwargs": False,
},
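            # MoE router gate; the preceding op is post_attention_layernorm.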
{
"layers": {"block_sparse_moe.gate": block.block_sparse_moe.gate},
"prev_op": [block.post_attention_layernorm],
"input": ["block_sparse_moe.gate"],
"inspect": block.block_sparse_moe,
"has_kwargs": True,
}
            # The MoE expert layers cannot be transformed, so they are not included here.
]
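
For readers unfamiliar with the pattern, the @MODEL_REGISTRY decorator above is what lets the framework build the model from a name in a config. The snippet below is a minimal sketch of that registry pattern, assuming a plain dict-backed registry; the register helper, the dict lookup, and the placeholder arguments are illustrative assumptions, not llmc's actual registry_factory code.

# Minimal, self-contained sketch of a name-to-class registry
# (an assumed pattern, not llmc's actual registry_factory implementation).
MODEL_REGISTRY = {}

def register(cls):
    # Store each decorated class under its own name so it can be resolved from a config string.
    MODEL_REGISTRY[cls.__name__] = cls
    return cls

@register
class Mixtral:
    # Constructor signature mirrors the class added in this commit.
    def __init__(self, model_path, torch_dtype):
        self.model_path = model_path
        self.torch_dtype = torch_dtype

# A config value such as "Mixtral" can then be resolved to the class and instantiated.
model = MODEL_REGISTRY["Mixtral"]("/path/to/Mixtral-8x7B-v0.1", "float16")  # placeholder arguments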
