Skip to content

Commit

Permalink
增加取encoder输出每个位置的平均值来作为分类器的输入
Browse files Browse the repository at this point in the history
  • Loading branch information
moon-hotel committed Jul 1, 2021
1 parent c7dbdd2 commit eb3c8bd
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ClassificationModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ def forward(self,
src_embed = self.pos_embedding(src_embed) # [src_len, batch_size, embed_dim]
memory = self.encoder(src=src_embed,
mask=src_mask,
src_key_padding_mask=src_key_padding_mask) # [src_len,batch_size,embed_dim]
src_key_padding_mask=src_key_padding_mask)
# [src_len,batch_size,embed_dim]
if concat_type == 'sum':
memory = torch.sum(memory, dim=0)
elif concat_type == 'avg':
memory = torch.sum(memory, dim=0) / memory.size(0)
else:
memory = memory[-1, ::] # 取最后一个时刻
# [src_len, batch_size, num_heads * kdim] <==> [src_len,batch_size,embed_dim]
Expand Down

0 comments on commit eb3c8bd

Please sign in to comment.