Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
helloyongyang committed Jun 16, 2024
1 parent fbaa188 commit f4d48fa
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
6 changes: 6 additions & 0 deletions llmc/compression/quantization/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def __init__(self, weight, eps=1e-6):
def __repr__(self):
    """Return a terse, class-identifying representation.

    Fix: the original used an f-string with no interpolated fields
    (ruff F541); a plain string literal produces the same output.
    """
    return "LlmcQwen2RMSNorm()"

class LlmcMixtralRMSNorm(LlmcLlamaRMSNorm):
    """RMSNorm replacement module for Mixtral models.

    All behavior is inherited from ``LlmcLlamaRMSNorm``; this subclass
    exists so the Mixtral model type gets a distinct class identity (and
    repr) in the replace/isinstance machinery used by the quantizers.
    """

    def __init__(self, weight, eps=1e-6):
        # Pure delegation — kept explicit for symmetry with the sibling
        # per-model subclasses (Mistral, Qwen2) in this module.
        super().__init__(weight, eps)

    def __repr__(self):
        # Plain literal: nothing is interpolated, so no f-prefix
        # (original used a needless f-string — ruff F541).
        return "LlmcMixtralRMSNorm()"

class LlmcMistralRMSNorm(LlmcLlamaRMSNorm):
def __init__(self, weight, eps=1e-6):
Expand Down
9 changes: 6 additions & 3 deletions llmc/compression/quantization/omniq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
LlmcLayerNorm,
LlmcLlamaRMSNorm,
LlmcMistralRMSNorm,
LlmcQwen2RMSNorm
LlmcQwen2RMSNorm,
LlmcMixtralRMSNorm
)
from .train_utils import NativeScalerWithGradNormCount, TruncateFunction, LossFunction
from llmc.utils.registry_factory import ALGO_REGISTRY
Expand Down Expand Up @@ -415,13 +416,15 @@ def replace_layer_norms(self, block, idx):
self.model.replace_module_block(LlmcLlamaRMSNorm, block, idx, {})
elif self.config["model"]["type"] == "Qwen2":
self.model.replace_module_block(LlmcQwen2RMSNorm, block, idx, {})
elif self.config["model"]["type"] == "Mixtral":
self.model.replace_module_block(LlmcMixtralRMSNorm, block, idx, {})
else:
self.model.replace_module_block(LlmcLayerNorm, block, idx, {})

def get_layer_norms(self, block):
    """Collect every normalization submodule found inside *block*.

    Walks ``block.named_modules()`` and returns, in traversal order, each
    module that is one of the Llmc normalization wrapper types.
    """
    norm_types = (
        LlmcLayerNorm,
        LlmcLlamaRMSNorm,
        LlmcMistralRMSNorm,
        LlmcQwen2RMSNorm,
        LlmcMixtralRMSNorm,
    )
    return [
        module
        for _, module in block.named_modules()
        if isinstance(module, norm_types)
    ]

Expand Down Expand Up @@ -655,7 +658,7 @@ def smooth_q_k_inplace(self, block):

for name, module in block.named_modules():
if isinstance(
module, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm)
module, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm, LlmcMixtralRMSNorm)
):
module.use_tmp_parameter = False

Expand Down
5 changes: 3 additions & 2 deletions llmc/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
LlmcLayerNorm,
LlmcLlamaRMSNorm,
LlmcMistralRMSNorm,
LlmcQwen2RMSNorm
LlmcQwen2RMSNorm,
LlmcMixtralRMSNorm
)


Expand Down Expand Up @@ -144,7 +145,7 @@ def replace_module_all(self, module, params_dict):
logger.info(f"The Replaced model: {self.model}")

def replace_module_block(self, module, block, i, params_dict):
if module in [LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm]:
if module in [LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm, LlmcMixtralRMSNorm]:
layer_norms = self.get_layernorms_in_block(block)
self.replace_module_layernorm(module, block, layer_norms, i, params_dict)
else:
Expand Down

0 comments on commit f4d48fa

Please sign in to comment.