Skip to content

Commit

Permalink
fix: add key_padding_mask for the standard attention
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyingming committed Nov 22, 2021
1 parent 1f0b1c1 commit e9f5235
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(self, src, padding_mask=None, posemb_row=None, posemb_col=None,posem
else:
src2 = self.self_attn((src + posemb_2d).reshape(bz, h * w, c).transpose(0, 1),
(src + posemb_2d).reshape(bz, h * w, c).transpose(0, 1),
src.reshape(bz, h * w, c).transpose(0, 1))[0].transpose(0, 1).reshape(bz, h, w, c)
src.reshape(bz, h * w, c).transpose(0, 1), key_padding_mask=padding_mask.reshape(bz, h*w))[0].transpose(0, 1).reshape(bz, h, w, c)

src = src + self.dropout1(src2)
src = self.norm1(src)
Expand Down Expand Up @@ -318,7 +318,7 @@ def forward(self, tgt, reference_points, srcs, src_padding_masks=None, adapt_pos
else:
tgt2 = self.cross_attn((tgt + query_pos).repeat(l, 1, 1).transpose(0, 1),
(srcs + posemb_2d).reshape(bz * l, h * w, c).transpose(0,1),
srcs.reshape(bz * l, h * w, c).transpose(0, 1))[0].transpose(0,1)
srcs.reshape(bz * l, h * w, c).transpose(0, 1), key_padding_mask=src_padding_masks.reshape(bz*l, h*w))[0].transpose(0,1)

if l > 1:
tgt2 = self.level_fc(tgt2.reshape(bz, l, tgt_len, c).permute(0, 2, 3, 1).reshape(bz, tgt_len, c * l))
Expand Down

0 comments on commit e9f5235

Please sign in to comment.