
Commit

Update clip-sym
gushiqiao committed Jun 4, 2024
1 parent 4eca2e7 commit b824375
Showing 5 changed files with 68 additions and 28 deletions.
6 changes: 5 additions & 1 deletion llmc/compression/quantization/awq.py
@@ -25,7 +25,10 @@ def __init__(self, model, quant_config, input, config):
self.trans_version = self.quant_config["special"]["trans_version"]
else:
self.trans_version = "v2"
if "special" in self.quant_config and "weight_clip" in self.quant_config["special"]:
if (
"special" in self.quant_config
and "weight_clip" in self.quant_config["special"]
):
self.weight_clip = self.quant_config["special"]["weight_clip"]
else:
self.weight_clip = True
@@ -183,6 +186,7 @@ def subset_transform(
):
logger.info("Cannot apply scale. Do not transform this subset.")
return

scale = self.search_scale_subset(
layers, input_feat[input_name], inspect_module, subset_kwargs
)
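The awq.py hunks above only reflow the optional-config lookups. As an illustrative aside (not code from llmc; special_opt is a hypothetical helper), the repeated membership test is equivalent to a dict .get() chain with a default:

# Illustrative only, not part of llmc. The repeated pattern
#     if "special" in cfg and "key" in cfg["special"]: ... else: default
# reads as "use special.key if present, otherwise fall back to a default".
def special_opt(quant_config: dict, key: str, default):
    return quant_config.get("special", {}).get(key, default)

quant_config = {"special": {"weight_clip": False, "trans_version": "v1"}}
weight_clip = special_opt(quant_config, "weight_clip", True)      # False
trans_version = special_opt(quant_config, "trans_version", "v2")  # "v1"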
46 changes: 34 additions & 12 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -57,30 +57,50 @@ def set_quant_config(self):
self.w_only = True

# set special quant config
if "special" in self.quant_config and "weight_clip" in self.quant_config["special"]:
if (
"special" in self.quant_config
and "weight_clip" in self.quant_config["special"]
):
self.weight_clip = self.quant_config["special"]["weight_clip"]
else:
self.weight_clip = True

if "special" in self.quant_config and "save_scale" in self.quant_config["special"]:
if (
"special" in self.quant_config
and "save_scale" in self.quant_config["special"]
):
self.save_scale = self.quant_config["special"]["save_scale"]
self.scale_path = self.quant_config["special"]["scale_path"]
self.act_scales = {}
else:
self.save_scale = False

if "special" in self.quant_config and "save_clip" in self.quant_config["special"]:
if (
"special" in self.quant_config
and "save_clip" in self.quant_config["special"]
):
self.save_clip = self.quant_config["special"]["save_clip"]
self.clip_path = self.quant_config["special"]["clip_path"]
self.weight_clips = {}
else:
self.save_clip = False

if "special" in self.quant_config and "clip_version" in self.quant_config["special"]:
if (
"special" in self.quant_config
and "clip_version" in self.quant_config["special"]
):
self.clip_version = self.quant_config["special"]["clip_version"]
else:
self.clip_version = "v1"

if (
"special" in self.quant_config
and "clip_sym" in self.quant_config["special"]
):
self.clip_sym = self.quant_config["special"]["clip_sym"]
else:
self.clip_sym = self.wquantizer.sym

if self.clip_version == "v2":
assert self.wquantizer.calib_algo == "learnable"

@@ -360,7 +380,7 @@ def apply_clip(self, layer, min_val, max_val, idx, n):
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
if self.wquantizer.sym:
if self.clip_sym:
min_val = -max_val

layer.weight.data = torch.clamp(layer.weight.data, min_val, max_val)
@@ -384,7 +404,7 @@ def get_clip_factor(self, layer, min_val, max_val):
)
org_val_shape = org_max_val.shape

if self.wquantizer.sym:
if self.clip_sym:
abs_max_val = torch.max(org_max_val.abs(), org_min_val.abs())
abs_max_val = abs_max_val.clamp(min=1e-5)
abs_max_val = abs_max_val.reshape(*max_val.shape[:2], -1)
@@ -404,7 +424,9 @@ def get_clip_factor(layer, min_val, max_val):
return up_factor, low_factor

@torch.no_grad()
def auto_clip_layer(self, w, input, n_grid=20, max_shrink=0.5, n_sample_token=512, eps=0.0):
def auto_clip_layer(
self, w, input, n_grid=20, max_shrink=0.5, n_sample_token=512, eps=0.0
):
assert w.dim() == 2

if self.wquantizer.granularity == "per_group":
@@ -423,7 +445,7 @@ def auto_clip_layer(self, w, input, n_grid=20, max_shrink=0.5, n_sample_token=51
for i_b in range(w.shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

if self.wquantizer.sym:
if self.clip_sym:
org_max_val = w.abs().amax(dim=-1, keepdim=True)
else:
org_max_val = w.amax(dim=-1, keepdim=True)
@@ -436,7 +458,7 @@ def auto_clip_layer(self, w, input, n_grid=20, max_shrink=0.5, n_sample_token=51
org_out_dict = {}
for i_s in range(int(max_shrink * n_grid)):
if i_s == 0:
if self.clip_version=='v2' and not self.w_only:
if self.clip_version == "v2" and not self.w_only:
i_s += eps
err_mean = 0
for i in range(len(input)):
@@ -453,7 +475,7 @@ def auto_clip_layer(self, w, input, n_grid=20, max_shrink=0.5, n_sample_token=51

max_val = org_max_val * (1 - i_s / n_grid)

if self.wquantizer.sym:
if self.clip_sym:
min_val = -max_val
else:
min_val = org_min_val * (1 - i_s / n_grid)
@@ -541,9 +563,9 @@ def copy_tokenizer(self, path):
for substring in self.config.save.get("tokenizer_file_substring", ["token"]):
copy_files(self.config.model.path, path, substring)
logger.info(f"copy tokenizer done --")

@torch.no_grad()
def save_model(self, path):
self.model.get_model().save_pretrained(path)
logger.info(f"save model done --")
self.copy_tokenizer(path)
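The substance of this commit is the new clip_sym switch: clipping symmetry is now read from special.clip_sym (falling back to the weight quantizer's sym flag) and used in apply_clip, get_clip_factor, and auto_clip_layer in place of self.wquantizer.sym. A minimal sketch of what that switch controls, with assumed names and shapes rather than the repository's implementation:

import torch

def clip_weight(w, max_val, min_val, clip_sym):
    # Clamp each output-channel row of w into its searched clip range.
    if clip_sym:
        # Symmetric clipping: the range is [-max_val, +max_val]; min_val is ignored.
        min_val = -max_val
    return torch.clamp(w, min_val, max_val)

w = torch.randn(4, 16)
max_val = w.abs().amax(dim=-1, keepdim=True) * 0.8  # shrunken per-row maximum
min_val = w.amin(dim=-1, keepdim=True) * 0.8        # separate lower bound for the asymmetric case
w_sym = clip_weight(w, max_val, min_val, clip_sym=True)
w_asym = clip_weight(w, max_val, min_val, clip_sym=False)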
18 changes: 10 additions & 8 deletions llmc/compression/quantization/omniq.py
@@ -65,7 +65,7 @@ def add_quant_config(self):
self.dtype = torch.float
self.traincast = nullcontext
else:
self.dtype = self.model_dtype
self.dtype = torch.float16
self.traincast = torch.cuda.amp.autocast

self.epochs = self.quant_config["special"]["epochs"]
@@ -318,15 +318,15 @@ def register_lwc_parameters(self, block, input_feat, idx, init_value=4.0):
torch.ones(
(dim, 1),
device=self.dev,
dtype=self.dtype,
# dtype=self.dtype,
)
* init_value
)
up_param = nn.Parameter(
torch.ones(
(dim, 1),
device=self.dev,
dtype=self.dtype,
# dtype=self.dtype,
)
* init_value
)
@@ -649,6 +649,13 @@ def smooth_q_k_tmp(self, q_proj, k_proj, scales):
k_proj.tmp_bias = k_proj.tmp_bias * scales.view(-1)

def smooth_q_k_inplace(self, block):

for name, module in block.named_modules():
if isinstance(
module, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm)
):
module.use_tmp_parameter = False

if block.self_attn.q_proj.weight.shape != block.self_attn.k_proj.weight.shape:
return

@@ -660,11 +667,6 @@ def smooth_q_k_inplace(self, block):
block.self_attn.k_proj.weight.mul_(scales.view(-1, 1))
if block.self_attn.k_proj.bias is not None:
block.self_attn.k_proj.bias.mul_(scales.view(-1))
for name, module in block.named_modules():
if isinstance(
module, (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm)
):
module.use_tmp_parameter = False

def w_qdq(self, module):
args = {"lowbound_factor": None, "upbound_factor": None}
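In omniq.py the dtype argument for the learnable clipping bounds is commented out so the parameters take the default dtype, and the loop that clears use_tmp_parameter now runs before the early return in smooth_q_k_inplace. A hedged sketch of that reordering (the intent is an assumption; the names mirror the diff but the body is illustrative):

def smooth_q_k_inplace_sketch(block, norm_types):
    # Reset the temporary-parameter flag on every norm module first ...
    for _, module in block.named_modules():
        if isinstance(module, norm_types):
            module.use_tmp_parameter = False
    # ... so the reset still happens even when the early return below is taken,
    # e.g. when q_proj and k_proj have different shapes (grouped-query attention).
    if block.self_attn.q_proj.weight.shape != block.self_attn.k_proj.weight.shape:
        return
    # (in-place scaling of the q_proj / k_proj weights would follow here)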
6 changes: 5 additions & 1 deletion llmc/compression/quantization/osplus.py
@@ -117,7 +117,11 @@ def block_transform(self, block, input_feat, idx, block_kwargs):
params_dict["w_qdq"] = self.w_qdq
self.model.replace_module_block(FakeQuantLinear, block, idx, params_dict)
self.auto_clip(
block, idx, clip_input_feat, n_sample_token=self.config.calib.seq_len, eps=3e-1
block,
idx,
clip_input_feat,
n_sample_token=self.config.calib.seq_len,
eps=3e-1,
)
logger.info(f"auto_clip finished")
else:
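The osplus.py change only wraps the auto_clip call across lines; the arguments (n_sample_token set to the calibration sequence length, eps=3e-1) are unchanged. For orientation, a tiny illustrative loop (assumed names, not the repo's API) showing how a nonzero eps nudges the first candidate of the clip search on the "v2" path seen in the auto_clip_layer hunk above:

n_grid, max_shrink, eps = 20, 0.5, 3e-1
for i_s in range(int(max_shrink * n_grid)):
    step = i_s + eps if i_s == 0 else i_s  # v2 path: even the first candidate is slightly shrunk
    shrink = 1 - step / n_grid             # factor applied to the original per-row maximum
    # candidate clip range: max_val = org_max_val * shrink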
20 changes: 14 additions & 6 deletions llmc/compression/quantization/quant.py
@@ -15,12 +15,20 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
else:
self.calib_algo = "minmax"

if self.sym:
self.max_int = 2 ** (self.bit - 1) - 1
self.min_int = -(2 ** (self.bit - 1))
if "qmax_to_tensor" in self.kwargs and self.kwargs["qmax_to_tensor"]:
if self.sym:
self.max_int = torch.tensor(2 ** (self.bit - 1) - 1).cuda()
self.min_int = torch.tensor(-(2 ** (self.bit - 1))).cuda()
else:
self.max_int = torch.tensor(2**self.bit - 1).cuda()
self.min_int = torch.tensor(0.0).cuda()
else:
self.max_int = 2**self.bit - 1
self.min_int = 0.0
if self.sym:
self.max_int = 2 ** (self.bit - 1) - 1
self.min_int = -(2 ** (self.bit - 1))
else:
self.max_int = 2**self.bit - 1
self.min_int = 0.0

if self.granularity == "per_group":
self.group_size = self.kwargs["group_size"]
@@ -173,7 +181,7 @@ def reshape_tensor(self, tensor):
else:
t = tensor
elif self.granularity == "per_head":
t = tensor.reshape(self.head_num, -1)
t = tensor.reshape(self.heda_num, -1)
else:
t = tensor
return t
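The quant.py hunk adds an (assumed) qmax_to_tensor option that materialises the integer range as CUDA tensors instead of Python scalars, presumably so later clamp and compare operations stay on-device; the symmetric and asymmetric ranges themselves are unchanged. A standalone sketch of that range setup (illustrative, not the repository's class):

import torch

def quant_range(bit, sym, qmax_to_tensor=False):
    if sym:
        max_int, min_int = 2 ** (bit - 1) - 1, -(2 ** (bit - 1))
    else:
        max_int, min_int = 2 ** bit - 1, 0
    if qmax_to_tensor:
        # Mirror of the new branch: keep the bounds as GPU tensors.
        max_int = torch.tensor(float(max_int)).cuda()
        min_int = torch.tensor(float(min_int)).cuda()
    return min_int, max_int

print(quant_range(4, sym=True))    # (-8, 7)
print(quant_range(8, sym=False))   # (0, 255)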
