Skip to content

Commit

Permalink
Fix #1
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentin Gabeur committed Oct 8, 2020
1 parent 0d848cd commit e24a84d
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,15 @@ def __init__(self,
self.text_gu = nn.ModuleDict()
for mod in self.modalities:
if self.txt_pro == 'gbn':
self.text_gu[mod] = GatedEmbeddingUnit(
text_dim, same_dim, use_bn=True, normalize=self.normalize_experts)
self.text_gu[mod] = GatedEmbeddingUnit(text_dim,
same_dim,
use_bn=True,
normalize=self.normalize_experts)
elif self.txt_pro == 'gem':
self.text_gu[mod] = GatedEmbeddingUnit(
text_dim, same_dim, use_bn=False, normalize=self.normalize_experts)
self.text_gu[mod] = GatedEmbeddingUnit(text_dim,
same_dim,
use_bn=False,
normalize=self.normalize_experts)
elif self.txt_pro == 'lin':
self.text_gu[mod] = ReduceDim(text_dim, same_dim)

Expand Down Expand Up @@ -364,12 +368,11 @@ def forward(self,
position_ids = th.stack(position_ids_list, dim=1).to(device)
attention_mask = th.stack(attention_mask_list, dim=1).to(device)

txt_bert_output = self.txt_bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=None)
txt_bert_output = self.txt_bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=None)
last_layer = txt_bert_output[0]

if self.post_agg == 'cls':
Expand Down Expand Up @@ -571,12 +574,11 @@ def forward(self,
token_ids = token_ids.view(b * captions_per_video, max_text_words,
feat_dim)

vid_bert_output = self.vid_bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
features=features)
vid_bert_output = self.vid_bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
features=features)

last_layer = vid_bert_output[0]
vid_embd = last_layer[:, 0]
Expand Down Expand Up @@ -619,8 +621,8 @@ def forward(self,
if self.normalize_experts:
for _, modality in enumerate(self.modalities):
experts[modality] = nn.functional.normalize(experts[modality], dim=-1)
text_embd[modality] = nn.functional.normalize(
text_embd[modality], dim=-1)
text_embd[modality] = nn.functional.normalize(text_embd[modality],
dim=-1)

if self.training:
merge_caption_similiarities = 'avg'
Expand Down Expand Up @@ -696,7 +698,7 @@ def forward(self, x):
x = self.fc(x)
x = self.cg(x)
if self.normalize:
x = F.normalize(x)
x = F.normalize(x, dim=-1)
return x


Expand All @@ -708,7 +710,7 @@ def __init__(self, input_dimension, output_dimension, use_bn):

def forward(self, x):
x = self.cg(x)
x = F.normalize(x)
x = F.normalize(x, dim=-1)
return x


Expand All @@ -720,7 +722,7 @@ def __init__(self, input_dimension, output_dimension):

def forward(self, x):
x = self.fc(x)
x = F.normalize(x)
x = F.normalize(x, dim=-1)
return x


Expand Down Expand Up @@ -756,7 +758,7 @@ def __init__(self, output_dimension):

def forward(self, x, mask):
x = self.cg(x, mask)
x = F.normalize(x)
x = F.normalize(x, dim=-1)
return x


Expand Down

0 comments on commit e24a84d

Please sign in to comment.