diff --git a/models/transformer.py b/models/transformer.py index 6095e03..0c7796a 100644 --- a/models/transformer.py +++ b/models/transformer.py @@ -294,8 +294,7 @@ def forward(self, tgt, reference_points, srcs, src_padding_masks=None, adapt_pos adapt_pos1d=None, posemb_row=None, posemb_col=None, posemb_2d=None): tgt_len = tgt.shape[1] - query_pos = pos2posemb2d(reference_points.squeeze(2)) - query_pos = adapt_pos2d(query_pos) + query_pos = adapt_pos2d(pos2posemb2d(reference_points)) # self attention q = k = self.with_pos_embed(tgt, query_pos) tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)