Simplify the multi-head attention implementation
moon-hotel committed Jan 18, 2022
1 parent a976086 commit 007d33f
Showing 1 changed file with 24 additions and 28 deletions.
52 changes: 24 additions & 28 deletions MyTransformer.py
@@ -5,7 +5,7 @@
import copy
import torch

is_print_shape = False
is_print_shape = True


class MyTransformer(nn.Module):
@@ -85,8 +85,6 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
:param nhead: number of heads in the multi-head attention; the paper's default is 8
:param dim_feedforward: dimension of the feed-forward (fully connected) layer; the paper's default is 2048
:param dropout: dropout rate; the paper's default is 0.1
"""
self.self_attn = MyMultiheadAttention(d_model, nhead, dropout=dropout)

@@ -128,7 +126,6 @@ def __init__(self, encoder_layer, num_layers, norm=None):
encoder_layer: a single encoder layer containing the multi-head attention mechanism
num_layers: number of cloned encoder layers; the paper's default is 6
norm: normalization layer
"""
self.layers = _get_clones(encoder_layer, num_layers)  # clone encoder_layer num_layers times; the paper's default is 6
self.num_layers = num_layers
@@ -275,11 +272,10 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
# the constraint above is the paper's condition d_k = d_v = d_model / n_head

self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) # embed_dim = kdim * num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) # embed_dim = kdim * num_heads
# the second dimension is embed_dim because num_heads W_q matrices are initialized stacked together here, i.e. all num_heads heads at once
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) # W_k, embed_dim = kdim * num_heads
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) # W_v, embed_dim = vdim * num_heads

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) # W_k, embed_dim = kdim * num_heads
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) # W_v, embed_dim = vdim * num_heads
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# combining all the Z outputs at the end is likewise done in a single step, embed_dim = vdim * num_heads
self._reset_parameters()
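
This hunk is the core of the simplification: the bare Parameter matrices that previously had to be initialized and applied by hand are replaced with nn.Linear modules, which own their weight (and optional bias) and are applied by calling them. A minimal sketch of the two equivalent formulations (the tensor sizes here are illustrative, not taken from the repository):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_

embed_dim = 512                      # = kdim * num_heads
x = torch.rand(4, 2, embed_dim)      # [tgt_len, batch_size, embed_dim]

# Before: a raw weight matrix, initialized and applied manually
w_q = nn.Parameter(torch.empty(embed_dim, embed_dim))
xavier_uniform_(w_q)
q_old = F.linear(x, w_q)             # [tgt_len, batch_size, embed_dim]

# After: an nn.Linear that bundles the weight, the bias and the forward call
q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
q_new = q_proj(x)                    # same shape, same linear map

Besides being shorter, this lets _reset_parameters below initialize every 2-D weight with one loop over self.parameters() instead of naming each matrix.
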
@@ -289,9 +285,9 @@ def _reset_parameters(self):
Initialize the parameters in a specific way
:return:
"""
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
for p in self.parameters():
if p.dim() > 1:
xavier_uniform_(p)

def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
"""
@@ -308,12 +304,13 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
attn_output_weights: # [batch_size, tgt_len, src_len]
"""
return multi_head_attention_forward(query, key, value, self.num_heads,
self.dropout, self.out_proj.weight, self.out_proj.bias,
self.dropout,
out_proj=self.out_proj,
training=self.training,
key_padding_mask=key_padding_mask,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
q_proj=self.q_proj,
k_proj=self.k_proj,
v_proj=self.v_proj,
attn_mask=attn_mask)
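
For reference, a minimal usage sketch of the refactored attention module, assuming MyMultiheadAttention is imported from MyTransformer; the sizes are illustrative and the shapes follow the docstrings above:

import torch
from MyTransformer import MyMultiheadAttention

tgt_len, src_len, batch_size, embed_dim, num_heads = 4, 6, 2, 512, 8
mha = MyMultiheadAttention(embed_dim, num_heads, dropout=0.1)

query = torch.rand(tgt_len, batch_size, embed_dim)
key = torch.rand(src_len, batch_size, embed_dim)
value = torch.rand(src_len, batch_size, embed_dim)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([4, 2, 512])  -> [tgt_len, batch_size, embed_dim]
print(attn_weights.shape)  # torch.Size([2, 4, 6])    -> [batch_size, tgt_len, src_len]
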


@@ -322,40 +319,39 @@ def multi_head_attention_forward(query, # [tgt_len,batch_size, embed_dim]
value, # [src_len, batch_size, embed_dim]
num_heads,
dropout_p,
out_proj_weight, # [embed_dim = vdim * num_heads, embed_dim = vdim * num_heads]
out_proj_bias,
out_proj, # [embed_dim = vdim * num_heads, embed_dim = vdim * num_heads]
training=True,
key_padding_mask=None, # [batch_size,src_len/tgt_len]
q_proj_weight=None, # [embed_dim,kdim * num_heads]
k_proj_weight=None, # [embed_dim, kdim * num_heads]
v_proj_weight=None, # [embed_dim, vdim * num_heads]
q_proj=None, # [embed_dim,kdim * num_heads]
k_proj=None, # [embed_dim, kdim * num_heads]
v_proj=None, # [embed_dim, vdim * num_heads]
attn_mask=None, # [tgt_len,src_len] or [num_heads*batch_size,tgt_len, src_len]
):
q = F.linear(query, q_proj_weight)
q = q_proj(query)
# [tgt_len,batch_size, embed_dim] x [embed_dim,kdim * num_heads] = [tgt_len,batch_size,kdim * num_heads]

k = F.linear(key, k_proj_weight)
k = k_proj(key)
# [src_len, batch_size, embed_dim] x [embed_dim, kdim * num_heads] = [src_len, batch_size, kdim * num_heads]

v = F.linear(value, v_proj_weight)
v = v_proj(value)
# [src_len, batch_size, embed_dim] x [embed_dim, vdim * num_heads] = [src_len, batch_size, vdim * num_heads]
if is_print_shape:
print("" + "=" * 80)
print("进入多头注意力计算:")
print(
f"\t 多头num_heads = {num_heads}, d_model={query.size(-1)}, d_k = d_v = d_model/num_heads={query.size(-1) // num_heads}")
print(f"\t query的shape([tgt_len, batch_size, embed_dim]):{query.shape}")
print(f"\t W_q 的shape([embed_dim,kdim * num_heads]):{q_proj_weight.shape}")
print(f"\t W_q 的shape([embed_dim,kdim * num_heads]):{q_proj.weight.shape}")
print(f"\t Q 的shape([tgt_len, batch_size,kdim * num_heads]):{q.shape}")
print("\t" + "-" * 70)

print(f"\t key 的shape([src_len,batch_size, embed_dim]):{key.shape}")
print(f"\t W_k 的shape([embed_dim,kdim * num_heads]):{k_proj_weight.shape}")
print(f"\t W_k 的shape([embed_dim,kdim * num_heads]):{k_proj.weight.shape}")
print(f"\t K 的shape([src_len,batch_size,kdim * num_heads]):{k.shape}")
print("\t" + "-" * 70)

print(f"\t value的shape([src_len,batch_size, embed_dim]):{value.shape}")
print(f"\t W_v 的shape([embed_dim,vdim * num_heads]):{v_proj_weight.shape}")
print(f"\t W_v 的shape([embed_dim,vdim * num_heads]):{v_proj.weight.shape}")
print(f"\t V 的shape([src_len,batch_size,vdim * num_heads]):{v.shape}")
print("\t" + "-" * 70)
print("\t ***** 注意,这里的W_q, W_k, W_v是多个head同时进行计算的. 因此,Q,K,V分别也是包含了多个head的q,k,v堆叠起来的结果 *****")
@@ -409,11 +405,11 @@ def multi_head_attention_forward(query, # [tgt_len,batch_size, embed_dim]
# then view as [tgt_len,batch_size,num_heads*kdim]
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)

Z = F.linear(attn_output, out_proj_weight, out_proj_bias)
Z = out_proj(attn_output)
# this is where the individual z outputs are linearly combined into Z  [tgt_len,batch_size,embed_dim]
if is_print_shape:
print(f"\t 多头注意力中,多头计算结束后的形状(堆叠)为([tgt_len,batch_size,num_heads*kdim]){attn_output.shape}")
print(f"\t 多头计算结束后,再进行线性变换时的权重W_o的形状为([num_heads*vdim, num_heads*vdim ]){out_proj_weight.shape}")
print(f"\t 多头计算结束后,再进行线性变换时的权重W_o的形状为([num_heads*vdim, num_heads*vdim ]){out_proj.weight.shape}")
print(f"\t 多头线性变化后的形状为([tgt_len,batch_size,embed_dim]) {Z.shape}")
return Z, attn_output_weights.sum(dim=1) / num_heads # average attention weights over heads
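
The return statement averages the per-head attention weights over the head dimension; a tiny check of that arithmetic with illustrative sizes:

import torch

bsz, num_heads, tgt_len, src_len = 2, 8, 4, 6
attn_output_weights = torch.rand(bsz, num_heads, tgt_len, src_len)

avg = attn_output_weights.sum(dim=1) / num_heads             # what the return statement computes
print(avg.shape)                                             # torch.Size([2, 4, 6]) -> [batch_size, tgt_len, src_len]
print(torch.allclose(avg, attn_output_weights.mean(dim=1)))  # True: it is just the mean over heads
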

