Skip to content

Commit

Permalink
refactor: remove invalid loss_cardinality
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyingming committed Jan 7, 2022
1 parent 9e7101d commit c39246b
Showing 1 changed file with 1 addition and 16 deletions.
17 changes: 1 addition & 16 deletions models/anchor_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,6 @@ def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
return losses

@torch.no_grad()
def loss_cardinality(self, outputs, targets, indices, num_boxes):
""" Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
"""
pred_logits = outputs['pred_logits']
device = pred_logits.device
tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
losses = {'cardinality_error': card_err}
return losses

def loss_boxes(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
Expand Down Expand Up @@ -247,7 +233,6 @@ def _get_tgt_permutation_idx(self, indices):
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
loss_map = {
'labels': self.loss_labels,
'cardinality': self.loss_cardinality,
'boxes': self.loss_boxes,
'masks': self.loss_masks
}
Expand Down Expand Up @@ -382,7 +367,7 @@ def build(args):
aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)

losses = ['labels', 'boxes', 'cardinality']
losses = ['labels', 'boxes']
if args.masks:
losses += ["masks"]
# num_classes, matcher, weight_dict, losses, focal_alpha=0.25
Expand Down

0 comments on commit c39246b

Please sign in to comment.