Add Qwen2 support
gushiqiao committed Jun 11, 2024
1 parent db989c3 commit 10cc471
Showing 5 changed files with 90 additions and 7 deletions.
llmc/compression/quantization/module_utils.py (8 additions, 0 deletions)
@@ -85,6 +85,14 @@ def __repr__(self):
        return f"LlmcLlamaRMSNorm()"


+class LlmcQwen2RMSNorm(LlmcLlamaRMSNorm):
+    def __init__(self, weight, eps=1e-6):
+        super().__init__(weight, eps)
+
+    def __repr__(self):
+        return f"LlmcQwen2RMSNorm()"
+
+
class LlmcMistralRMSNorm(LlmcLlamaRMSNorm):
    def __init__(self, weight, eps=1e-6):
        super().__init__(weight, eps)
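
For context, LlmcLlamaRMSNorm's implementation is not part of this diff; the new class only renames it for Qwen2. Below is a minimal sketch of the RMSNorm computation such a wrapper sits on top of, mirroring the Hugging Face Llama/Qwen2 formulation; it is illustrative only and omits the quantization bookkeeping the real llmc class carries.

import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    # Hypothetical stand-in for the math LlmcLlamaRMSNorm / LlmcQwen2RMSNorm wrap.
    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize by the root-mean-square over the hidden dimension, then rescale.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
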
llmc/compression/quantization/omniq.py (7 additions, 4 deletions)
@@ -19,6 +19,7 @@
    LlmcLayerNorm,
    LlmcLlamaRMSNorm,
    LlmcMistralRMSNorm,
+    LlmcQwen2RMSNorm
)
from .train_utils import NativeScalerWithGradNormCount, TruncateFunction, LossFunction
from llmc.utils.registry_factory import ALGO_REGISTRY
@@ -31,11 +31,11 @@ def __init__(self, model, quant_config, input, config):
        self.add_quant_config()

        if (
-            self.config["model"]["type"] not in ["Llama", "Opt", "Falcon", "Mistral"]
+            self.config["model"]["type"] not in ["Llama", "Opt", "Falcon", "Mistral", "Qwen2"]
            and self.let
        ):
            raise ValueError("Only support for opt/llama/Llama-2/falcon/Mistral now")
-        elif self.config["model"]["type"] in ("Llama", "Mistral"):
+        elif self.config["model"]["type"] in ("Llama", "Mistral", "Qwen2"):
            self.attention_mask = self.input["kwargs"][0]["attention_mask"]
            self.position_ids = self.input["kwargs"][0]["position_ids"]
        else:
@@ -412,13 +412,15 @@ def replace_layer_norms(self, block, idx):
            self.model.replace_module_block(LlmcMistralRMSNorm, block, idx, {})
        elif self.config["model"]["type"] == "Llama":
            self.model.replace_module_block(LlmcLlamaRMSNorm, block, idx, {})
+        elif self.config["model"]["type"] == "Qwen2":
+            self.model.replace_module_block(LlmcQwen2RMSNorm, block, idx, {})
        else:
            self.model.replace_module_block(LlmcLayerNorm, block, idx, {})

    def get_layer_norms(self, block):
        layer_norms = []
        for n, m in block.named_modules():
-            if isinstance(m, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm)):
+            if isinstance(m, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm)):
                layer_norms.append(m)
        return layer_norms

@@ -652,7 +655,7 @@ def smooth_q_k_inplace(self, block):

        for name, module in block.named_modules():
            if isinstance(
-                module, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm)
+                module, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm)
            ):
                module.use_tmp_parameter = False

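The widened type checks above matter because Llama, Mistral, and Qwen2 decoder layers expect position_ids alongside attention_mask when OmniQuant re-runs a block on calibration data. A minimal sketch of that call pattern, assuming the block follows the Hugging Face decoder-layer interface (this is not the llmc code):

def run_block(block, hidden_states, attention_mask=None, position_ids=None):
    # Llama / Mistral / Qwen2 style blocks take position_ids; OPT / Falcon paths do not.
    kwargs = {"attention_mask": attention_mask}
    if position_ids is not None:
        kwargs["position_ids"] = position_ids
    # Hugging Face decoder layers return a tuple; the hidden states come first.
    return block(hidden_states, **kwargs)[0]
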
llmc/models/__init__.py (1 addition, 0 deletions)
@@ -5,3 +5,4 @@
from .mistral import Mistral
from .starcoder import Starcoder
from .internlm2 import InternLM2
+from .qwen2 import Qwen2
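
Importing llmc.models is what runs the @MODEL_REGISTRY decorators, so the new Qwen2 class becomes reachable by name. A sketch of how it could be looked up, assuming a dict-style registry (registry_factory is not shown in this commit) and an arbitrary example checkpoint; the (model_path, torch_dtype) signature comes from qwen2.py below:

import torch

import llmc.models  # noqa: F401  (triggers the @MODEL_REGISTRY registrations)
from llmc.utils.registry_factory import MODEL_REGISTRY

Qwen2Cls = MODEL_REGISTRY["Qwen2"]  # assumed dict-style lookup
model = Qwen2Cls("Qwen/Qwen2-7B-Instruct", torch.bfloat16)  # example checkpoint
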
llmc/models/base_model.py (2 additions, 3 deletions)
@@ -13,9 +13,8 @@
    LlmcLayerNorm,
    LlmcLlamaRMSNorm,
    LlmcMistralRMSNorm,
+    LlmcQwen2RMSNorm
)
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
-from transformers.models.mistral.modeling_mistral import MistralRMSNorm


class BaseModel(metaclass=ABCMeta):
@@ -145,7 +144,7 @@ def replace_module_all(self, module, params_dict):
        logger.info(f"The Replaced model: {self.model}")

    def replace_module_block(self, module, block, i, params_dict):
-        if module in [LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm]:
+        if module in [LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm]:
            layer_norms = self.get_layernorms_in_block(block)
            self.replace_module_layernorm(module, block, layer_norms, i, params_dict)
        else:
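With the direct LlamaRMSNorm/MistralRMSNorm imports dropped, the dispatch above works purely on the Llmc* wrappers. The helper below sketches the wrap step replace_module_layernorm would end up performing for a Qwen2 norm, assuming transformers' Qwen2RMSNorm exposes weight and variance_epsilon the way LlamaRMSNorm does; wrap_qwen2_norm itself is illustrative, not an llmc function.

from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

from llmc.compression.quantization.module_utils import LlmcQwen2RMSNorm


def wrap_qwen2_norm(orig: Qwen2RMSNorm) -> LlmcQwen2RMSNorm:
    # Reuse the original scale weight and epsilon, matching the (weight, eps)
    # constructor added in module_utils.py above.
    return LlmcQwen2RMSNorm(orig.weight, eps=orig.variance_epsilon)
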
llmc/models/qwen2.py (new file, 72 additions)
@@ -0,0 +1,72 @@
from .base_model import BaseModel
from llmc.utils.registry_factory import MODEL_REGISTRY


@MODEL_REGISTRY
class Qwen2(BaseModel):
    def __init__(self, model_path, torch_dtype):
        super().__init__(model_path, torch_dtype)

    def find_blocks(self):
        self.blocks = self.model.model.layers

    def find_embed_layers(self):
        self.embed_tokens = self.model.model.embed_tokens

    def find_block_name(self):
        self.block_name_prefix = "model.layers"
        self.pairs = {"q_proj": "qkv", "o_proj": "out", "up_proj": "fc1"}

    def get_embed_layers(self):
        return [self.embed_tokens]

    def get_layers_except_blocks(self):
        return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

    def has_bias(self):
        return False

    def get_layernorms_in_block(self, block):
        return {
            "input_layernorm": block.input_layernorm,
            "post_attention_layernorm": block.post_attention_layernorm,
        }

    def get_subsets_in_block(self, block):
        return [
            {
                "layers": {
                    "self_attn.q_proj": block.self_attn.q_proj,
                    "self_attn.k_proj": block.self_attn.k_proj,
                    "self_attn.v_proj": block.self_attn.v_proj,
                },
                "prev_op": [block.input_layernorm],
                "input": ["self_attn.q_proj"],
                "inspect": block.self_attn,
                "has_kwargs": True,
            },
            {
                "layers": {"self_attn.o_proj": block.self_attn.o_proj},
                "prev_op": [block.self_attn.v_proj],
                "input": ["self_attn.o_proj"],
                "inspect": block.self_attn.o_proj,
                "has_kwargs": False,
            },
            {
                "layers": {
                    "mlp.gate_proj": block.mlp.gate_proj,
                    "mlp.up_proj": block.mlp.up_proj,
                },
                "prev_op": [block.post_attention_layernorm],
                "input": ["mlp.gate_proj"],
                "inspect": block.mlp,
                "has_kwargs": False,
            },
            {
                "layers": {"mlp.down_proj": block.mlp.down_proj},
                "prev_op": [block.mlp.up_proj],
                "input": ["mlp.down_proj"],
                "inspect": block.mlp.down_proj,
                "has_kwargs": False,
            },
        ]
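
A quick sanity check, outside the commit itself, that a Hugging Face Qwen2 checkpoint exposes the attribute paths the class above relies on; the checkpoint name is just an example.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B", torch_dtype=torch.float16  # any Qwen2 checkpoint
)
block = model.model.layers[0]                     # find_blocks()
assert hasattr(model.model, "embed_tokens")       # find_embed_layers()
assert hasattr(model.model, "norm")               # get_layers_except_blocks()
assert hasattr(block, "input_layernorm")          # get_layernorms_in_block()
assert hasattr(block.self_attn, "q_proj")         # attention subset
assert hasattr(block.mlp, "down_proj")            # final MLP subset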
