Skip to content

Commit

Permalink
perf(triplet_margin_loss): update
Browse files Browse the repository at this point in the history
1. use input_dict[KEY_OUTPUT] to compute loss
2. reshape embeddings be [N, D]
  • Loading branch information
zjykzj committed Aug 3, 2022
1 parent 426c64c commit 21be529
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions simpleir/criterion/triplet_margin_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch.nn as nn
from torch import Tensor

from simpleir.configs.key_words import KEY_INPUT, KEY_FEAT
from zcls2.config.key_word import KEY_OUTPUT

__all__ = ['triplet_margin_loss', 'TripletMarginLoss']

Expand All @@ -36,10 +36,11 @@ def _forward(self, embeddings, labels):
return self.loss_fn(labels, embeddings, self.margin, self.p)

def forward(self, input_dict: Dict, target: Tensor) -> Tensor:
embeddings = input_dict[KEY_FEAT]
labels = input_dict[KEY_INPUT]
embeddings = input_dict[KEY_OUTPUT]
if len(embeddings.shape) != 2:
embeddings = embeddings.reshape(embeddings.shape[0], -1)

triplet_loss, fraction_positive_triplets = self._forward(embeddings, labels)
triplet_loss, fraction_positive_triplets = self._forward(embeddings, target)
return triplet_loss


Expand Down

0 comments on commit 21be529

Please sign in to comment.