Loading checkpoints #29

Open
huanghmingyue opened this issue Mar 6, 2024 · 9 comments

@huanghmingyue

Hello!
After compression, the model structure also changes. Does that mean the original build model function can no longer be used to load the model's checkpoints?

@sdc17
Owner

sdc17 commented Mar 6, 2024

Hello, yes.
To load a compressed checkpoint, the code first creates the model with the original build model function, then uses the prune_if_compressed method to change the shapes of the original model's parameters so that they match the checkpoint, and finally loads the checkpoint's parameters.

If you want to load compressed checkpoints directly with the original build model function, you can rewrite each model's compress method so that, when the checkpoint is saved after compression, the redundant parameters are not actually removed from the model but are instead set to 0, as in unstructured pruning. The parameter matrices then keep their original shapes, so the checkpoint can be loaded with the original build model function; the drawback is that the model size is not actually reduced.
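A minimal sketch of the loading flow described above, assuming the checkpoint stores its weights under the 'model' key; build_model is a hypothetical stand-in for the original build model function, and the exact signature of prune_if_compressed may differ from this sketch:

import torch

model = build_model(config)                                   # original (uncompressed) architecture
checkpoint = torch.load('compressed_checkpoint.pth', map_location='cpu')
state_dict = checkpoint['model']                              # assumed checkpoint layout
model.prune_if_compressed(state_dict)                         # reshape parameters to match the compressed checkpoint
model.load_state_dict(state_dict)                             # shapes now agree, so loading succeeds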

@huanghmingyue
Author

Thank you for the explanation! If I want to apply UPop to other models, how should the following parameters be set?
1. parser.add_argument('--w_sp_attn', default=4.8e-3, type=float, help='regularization coefficient for attn') parser.add_argument('--w_sp_mlp', default=2e-4, type=float, help='regularization coefficient for mlp')
How are the values of --w_sp_attn and --w_sp_mlp chosen?
2. How should epochs-search and interval be set?
3. In compression_weight[indices < alpha_grad_attn.numel()] = 9234/769, what does 9234/769 mean, and what should it be set to for other models?

@sdc17
Owner

sdc17 commented Mar 10, 2024

  1. --w_sp_attn and --w_sp_mlp control the loss values of the learnable masks on attention and FFN, respectively. Two rules were followed: (1) make the two loss values equal at the start of training (both values are also printed during training); (2) keep them on the same order of magnitude as the model's original loss at the end of the search stage. The exact values were not tuned carefully, so other settings are worth trying.

  2. epochs-search depends on the task. The multimodal models were originally trained for relatively few epochs, so the search epochs were simply set equal to the training epochs. The unimodal models were originally trained for many epochs, so only about 1/5 of the training epochs were used for search. When resources allow, more epochs-search is generally better. interval is explained in the paper: it is the number of iterations between consecutive updates of the learnable mask parameters. A reasonable choice is the number of iterations corresponding to 1% of the compression ratio. For example, to reach 50% compression within 1000 iterations, set it to around 1000/50 = 20.

  3. UPop performs structured pruning, and this number makes each position of the attention and FFN learnable masks cover the same number of actual parameters. 9234/769 is derived as follows (see the sketch after this list):

    • Numerator (attn): [384 (parameters in one input row of qkv in attn) + 1 (the bias parameter for that row)] $\times$ [1 (query) + 1 (key) + 1 (value)] $\times$ 6 (number of heads) + 384 (parameters in one output column of proj in attn) $\times$ 6 (number of heads) = (384+1) $\times$ (1+1+1) $\times$ 6 + 384 $\times$ 6 = 9234
    • Denominator (ffn): 384 (parameters in one input row of fc1 in ffn) + 1 (the bias parameter for that row) + 384 (parameters in one output column of fc2 in ffn) = 384 + 1 + 384 = 769

    So the ratio is 9234/769. For a different model, follow the same procedure and adjust this ratio according to the number of heads and the input/output sizes of the attention and FFN parameter matrices, so that each position of the learnable masks for the different structures covers the same number of actual parameters.
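A small Python sketch of the same calculation, so the ratio can be recomputed for other architectures; embed_dim and num_heads here correspond to the 384 and 6 above, and the variable names are illustrative:

embed_dim, num_heads = 384, 6
# one attention-mask position covers one qkv input row (plus its bias) for q, k and v in every head,
# plus one proj output column in every head
attn_params_per_position = (embed_dim + 1) * 3 * num_heads + embed_dim * num_heads   # 9234
# one ffn-mask position covers one fc1 input row (plus its bias) and one fc2 output column
ffn_params_per_position = (embed_dim + 1) + embed_dim                                # 769
ratio = attn_params_per_position / ffn_params_per_position                           # 9234 / 769 ≈ 12.0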

@huanghmingyue
Author

huanghmingyue commented Mar 23, 2024

Thank you for the explanation!
Is applying the UPop method to Multi_Scale_Deformable_Attention supported, for example by modifying __init__ and forward as below?
1. Modify __init__

class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, search=False):
      
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
 
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 ")

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)
# 1. Add the following ------------------------------------------------------------------------------
        if search:
            self.alpha = nn.Parameter(torch.ones(1, 1, 1, _d_per_head))
# ------------------------------------------------------------------------------------------------

        self._reset_parameters()


2. Modify forward

def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        """
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                     (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, -1)
        
        # --- 2. Add the following --------------------------------------------------------------------------------
        if hasattr(self, 'alpha'):
            value = value * self.alpha
        # --------------------------------------------------------------------------------------------------

        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))

Looking forward to your reply.

@sdc17
Owner

sdc17 commented Mar 27, 2024

Any Transformer-based model should work.
Just make sure the shape of the initialized mask parameter is such that it can be multiplied onto value normally during the forward pass.
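For the MSDeformAttn example above, value has shape (N, Len_in, n_heads, d_per_head) after the view, so an alpha of shape (1, 1, 1, _d_per_head) broadcasts over the batch, token and head dimensions. A quick illustrative check (the concrete sizes are examples, not taken from the repo):

import torch

N, Len_in, n_heads, d_per_head = 2, 100, 8, 32
value = torch.randn(N, Len_in, n_heads, d_per_head)
alpha = torch.nn.Parameter(torch.ones(1, 1, 1, d_per_head))
assert (value * alpha).shape == value.shape   # broadcasts cleanly over batch, tokens and heads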

@huanghmingyue
Author

Got it, thank you very much for taking the time to give such a detailed reply. Your answers have been a great help!

@huanghmingyue
Author

The paper says the accumulated gradient is the sum of the gradients from iteration 0 to the current iteration t. Where in the code is this implemented? The code uses .grad to obtain gradients, so how is the accumulated gradient obtained? Thanks in advance for your answer.

@sdc17
Owner

sdc17 commented Mar 29, 2024

Simply do not clear their gradients. The mask parameters are not included in the original model's optimizer, for example:

optimizer = torch.optim.AdamW(
    params=[{'params': [param for name, param in list(search_model.named_parameters()) if not ('alpha' in name)]}],
    lr=config['init_lr'],
    weight_decay=config['weight_decay']
)

so their gradients in .grad accumulate automatically as the iterations proceed.
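A minimal illustration of that behaviour with toy tensors (not the repo's training loop): optimizer.zero_grad() only clears the gradients of parameters registered with the optimizer, so an excluded alpha keeps accumulating gradients across iterations.

import torch

w = torch.nn.Parameter(torch.randn(4))
alpha = torch.nn.Parameter(torch.ones(4))
optimizer = torch.optim.AdamW([w], lr=1e-3)     # alpha is intentionally left out

for step in range(3):
    optimizer.zero_grad()                       # zeroes w.grad only, never alpha.grad
    loss = ((w * alpha) ** 2).sum()
    loss.backward()                             # adds this iteration's gradient to alpha.grad
    optimizer.step()

print(alpha.grad)                               # sum of gradients from iteration 0 to t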

@huanghmingyue
Author

huanghmingyue commented Apr 1, 2024

Thank you for your reply! I'd like to ask a few more questions:
In Progressive Pruning, each iteration's $M^{t}$ is obtained from the accumulated gradients and then smoothed into the mask $\zeta^{t}$. But the accumulated gradients of different parts may differ across iterations, so different iterations may update different positions of the mask. Does that mean the following two situations can arise?

(1) Some positions drop from 1 to a fractional value during the first few search epochs but never go on to reach 0; if the mask has already been applied to the network at that point, does that layer have to re-correct/re-learn its parameters in subsequent training?
(2) Some positions only develop large accumulated gradients in the last few search epochs, so they do not go through many smoothed mask updates; could this cause performance fluctuations of the model in the final epochs?
I'm not sure whether this understanding is correct; I'm new to this area, so please correct anything I've misunderstood. Thank you very much!

Also, I would like to apply UPop to a Transformer-based object detection model (a DETR-style model) and prune its Transformer part; that should also be applicable, right?
In this model, the encoder's attention modules use deformable attention, while the decoder's self-attention modules use standard multi-head self-attention and its cross-attention modules use deformable attention. In the get_sparsity_loss function, should the two kinds of attention return their losses separately (as in the code below), or should sparsity_loss_multi_attn and sparsity_loss_dattn be added together?

def get_sparsity_loss(model):
    sparsity_loss_dattn, sparsity_loss_ffn, sparsity_loss_multi_attn = 0, 0, 0
    for i in range(model.transformer.encoder.num_layers):
        sparsity_loss_dattn += torch.sum(torch.abs(getattr(model.transformer.encoder.layers, str(i)).self_attn.alpha))
        sparsity_loss_ffn += torch.sum(torch.abs(getattr(model.transformer.encoder.layers, str(i)).ffn.alpha))
    for i in range(model.transformer.decoder.num_layers):
        sparsity_loss_dattn += torch.sum(torch.abs(getattr(model.transformer.decoder.layers, str(i)).cross_attn.alpha))
        sparsity_loss_ffn += torch.sum(torch.abs(getattr(model.transformer.decoder.layers, str(i)).ffn.alpha))
        sparsity_loss_multi_attn += torch.sum(torch.abs(getattr(model.transformer.decoder.layers, str(i)).self_attn.alpha))
    return sparsity_loss_dattn, sparsity_loss_ffn, sparsity_loss_multi_attn

Thank you very much for your answer!
