-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add NanoDet-Plus head and Dynamic Soft-label Assigner. (#358)
* [Feature] Add NanoDet-Plus head. * fix param * fix log * fix load * refactor head output * fix load weight and logger * fix unit test and reuse code * fix num_cls * add unit test
- Loading branch information
Showing
11 changed files
with
902 additions
and
156 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
from ...loss.iou_loss import bbox_overlaps | ||
from .assign_result import AssignResult | ||
from .base_assigner import BaseAssigner | ||
|
||
|
||
class DynamicSoftLabelAssigner(BaseAssigner): | ||
"""Computes matching between predictions and ground truth with | ||
dynamic soft label assignment. | ||
Args: | ||
topk (int): Select top-k predictions to calculate dynamic k | ||
best matchs for each gt. Default 13. | ||
iou_factor (float): The scale factor of iou cost. Default 3.0. | ||
""" | ||
|
||
def __init__(self, topk=13, iou_factor=3.0): | ||
self.topk = topk | ||
self.iou_factor = iou_factor | ||
|
||
def assign( | ||
self, | ||
pred_scores, | ||
priors, | ||
decoded_bboxes, | ||
gt_bboxes, | ||
gt_labels, | ||
): | ||
"""Assign gt to priors with dynamic soft label assignment. | ||
Args: | ||
pred_scores (Tensor): Classification scores of one image, | ||
a 2D-Tensor with shape [num_priors, num_classes] | ||
priors (Tensor): All priors of one image, a 2D-Tensor with shape | ||
[num_priors, 4] in [cx, xy, stride_w, stride_y] format. | ||
decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape | ||
[num_priors, 4] in [tl_x, tl_y, br_x, br_y] format. | ||
gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor | ||
with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format. | ||
gt_labels (Tensor): Ground truth labels of one image, a Tensor | ||
with shape [num_gts]. | ||
Returns: | ||
:obj:`AssignResult`: The assigned result. | ||
""" | ||
INF = 100000000 | ||
num_gt = gt_bboxes.size(0) | ||
num_bboxes = decoded_bboxes.size(0) | ||
|
||
# assign 0 by default | ||
assigned_gt_inds = decoded_bboxes.new_full((num_bboxes,), 0, dtype=torch.long) | ||
|
||
prior_center = priors[:, :2] | ||
lt_ = prior_center[:, None] - gt_bboxes[:, :2] | ||
rb_ = gt_bboxes[:, 2:] - prior_center[:, None] | ||
|
||
deltas = torch.cat([lt_, rb_], dim=-1) | ||
is_in_gts = deltas.min(dim=-1).values > 0 | ||
valid_mask = is_in_gts.sum(dim=1) > 0 | ||
|
||
valid_decoded_bbox = decoded_bboxes[valid_mask] | ||
valid_pred_scores = pred_scores[valid_mask] | ||
num_valid = valid_decoded_bbox.size(0) | ||
|
||
if num_gt == 0 or num_bboxes == 0 or num_valid == 0: | ||
# No ground truth or boxes, return empty assignment | ||
max_overlaps = decoded_bboxes.new_zeros((num_bboxes,)) | ||
if num_gt == 0: | ||
# No truth, assign everything to background | ||
assigned_gt_inds[:] = 0 | ||
if gt_labels is None: | ||
assigned_labels = None | ||
else: | ||
assigned_labels = decoded_bboxes.new_full( | ||
(num_bboxes,), -1, dtype=torch.long | ||
) | ||
return AssignResult( | ||
num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels | ||
) | ||
|
||
pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes) | ||
iou_cost = -torch.log(pairwise_ious + 1e-7) | ||
|
||
gt_onehot_label = ( | ||
F.one_hot(gt_labels.to(torch.int64), pred_scores.shape[-1]) | ||
.float() | ||
.unsqueeze(0) | ||
.repeat(num_valid, 1, 1) | ||
) | ||
valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1) | ||
|
||
soft_label = gt_onehot_label * pairwise_ious[..., None] | ||
scale_factor = soft_label - valid_pred_scores | ||
|
||
cls_cost = F.binary_cross_entropy( | ||
valid_pred_scores, soft_label, reduction="none" | ||
) * scale_factor.abs().pow(2.0) | ||
|
||
cls_cost = cls_cost.sum(dim=-1) | ||
|
||
cost_matrix = cls_cost + iou_cost * self.iou_factor | ||
|
||
matched_pred_ious, matched_gt_inds = self.dynamic_k_matching( | ||
cost_matrix, pairwise_ious, num_gt, valid_mask | ||
) | ||
|
||
# convert to AssignResult format | ||
assigned_gt_inds[valid_mask] = matched_gt_inds + 1 | ||
assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1) | ||
assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long() | ||
max_overlaps = assigned_gt_inds.new_full( | ||
(num_bboxes,), -INF, dtype=torch.float32 | ||
) | ||
max_overlaps[valid_mask] = matched_pred_ious | ||
return AssignResult( | ||
num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels | ||
) | ||
|
||
def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask): | ||
"""Use sum of topk pred iou as dynamic k. Refer from OTA and YOLOX. | ||
Args: | ||
cost (Tensor): Cost matrix. | ||
pairwise_ious (Tensor): Pairwise iou matrix. | ||
num_gt (int): Number of gt. | ||
valid_mask (Tensor): Mask for valid bboxes. | ||
""" | ||
matching_matrix = torch.zeros_like(cost) | ||
# select candidate topk ious for dynamic-k calculation | ||
candidate_topk = min(self.topk, pairwise_ious.size(0)) | ||
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0) | ||
# calculate dynamic k for each gt | ||
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) | ||
for gt_idx in range(num_gt): | ||
_, pos_idx = torch.topk( | ||
cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False | ||
) | ||
matching_matrix[:, gt_idx][pos_idx] = 1.0 | ||
|
||
del topk_ious, dynamic_ks, pos_idx | ||
|
||
prior_match_gt_mask = matching_matrix.sum(1) > 1 | ||
if prior_match_gt_mask.sum() > 0: | ||
cost_min, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1) | ||
matching_matrix[prior_match_gt_mask, :] *= 0.0 | ||
matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0 | ||
# get foreground mask inside box and center prior | ||
fg_mask_inboxes = matching_matrix.sum(1) > 0.0 | ||
valid_mask[valid_mask.clone()] = fg_mask_inboxes | ||
|
||
matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) | ||
matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[fg_mask_inboxes] | ||
return matched_pred_ious, matched_gt_inds |
Oops, something went wrong.