Skip to content

Commit

Permalink
Fixed small import errors in s2sd.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Confusezius committed Aug 9, 2022
1 parent 71f5c77 commit b07ac5b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions criteria/s2sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, opt):
# Initialize all target criteria. As each criterion may require its
# separate set of trainable parameters, several instances have to be created.
old_embed_dim = copy.deepcopy(opt.embed_dim)
self.target_criteria = nn.ModuleList()
self.target_criteria = torch.nn.ModuleList()
for t_dim in opt.loss_s2sd_target_dims:
opt.embed_dim = t_dim

Expand Down Expand Up @@ -127,8 +127,8 @@ class [0,...,C-1], shape: (BS x 1)
# If required, use combined global max- and average pooling to produce
# the feature space.
if self.pool_aggr:
avg_batch_features = nn.AdaptiveAvgPool2d(1)(batch_features).view(
bs, -1) + nn.AdaptiveMaxPool2d(1)(batch_features).view(bs, -1)
avg_batch_features = torch.nn.AdaptiveAvgPool2d(1)(batch_features).view(
bs, -1) + torch.nn.AdaptiveMaxPool2d(1)(batch_features).view(bs, -1)
else:
avg_batch_features = avg_batch_features.view(bs, -1)

Expand Down

0 comments on commit b07ac5b

Please sign in to comment.