Skip to content

Commit

Permalink
fix: line order in RobustLoss (Parskatt#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
lnexenl committed Jun 11, 2024
1 parent 884d991 commit 2d869bb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions roma/losses/robust_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
if not torch.any(cls_loss):
cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere

certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)

losses = {
f"gm_certainty_loss_{scale}": certainty_loss.mean(),
f"gm_cls_loss_{scale}": cls_loss.mean(),
Expand Down

0 comments on commit 2d869bb

Please sign in to comment.