diff --git a/nanodet/model/head/__init__.py b/nanodet/model/head/__init__.py
index 1093fba1d..236a10a3a 100644
--- a/nanodet/model/head/__init__.py
+++ b/nanodet/model/head/__init__.py
@@ -2,6 +2,7 @@
 from .gfl_head import GFLHead
 from .nanodet_head import NanoDetHead
+from .nanodet_plus_head import NanoDetPlusHead
 
 
 def build_head(cfg):
@@ -11,5 +12,7 @@ def build_head(cfg):
         return GFLHead(**head_cfg)
     elif name == "NanoDetHead":
         return NanoDetHead(**head_cfg)
+    elif name == "NanoDetPlusHead":
+        return NanoDetPlusHead(**head_cfg)
     else:
         raise NotImplementedError
diff --git a/nanodet/model/head/assigner/dsl_assigner.py b/nanodet/model/head/assigner/dsl_assigner.py
new file mode 100644
index 000000000..e74dc0854
--- /dev/null
+++ b/nanodet/model/head/assigner/dsl_assigner.py
@@ -0,0 +1,154 @@
+import torch
+import torch.nn.functional as F
+
+from ...loss.iou_loss import bbox_overlaps
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+class DynamicSoftLabelAssigner(BaseAssigner):
+    """Computes matching between predictions and ground truth with
+    dynamic soft label assignment.
+
+    Args:
+        topk (int): Select top-k predictions to calculate dynamic k
+            best matches for each gt. Default 13.
+        iou_factor (float): The scale factor of iou cost. Default 3.0.
+    """
+
+    def __init__(self, topk=13, iou_factor=3.0):
+        self.topk = topk
+        self.iou_factor = iou_factor
+
+    def assign(
+        self,
+        pred_scores,
+        priors,
+        decoded_bboxes,
+        gt_bboxes,
+        gt_labels,
+    ):
+        """Assign gt to priors with dynamic soft label assignment.
+        Args:
+            pred_scores (Tensor): Classification scores of one image,
+                a 2D-Tensor with shape [num_priors, num_classes]
+            priors (Tensor): All priors of one image, a 2D-Tensor with shape
+                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
+            decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
+                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
+            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
+                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (Tensor): Ground truth labels of one image, a Tensor
+                with shape [num_gts].
+
+        Returns:
+            :obj:`AssignResult`: The assigned result.
+ """ + INF = 100000000 + num_gt = gt_bboxes.size(0) + num_bboxes = decoded_bboxes.size(0) + + # assign 0 by default + assigned_gt_inds = decoded_bboxes.new_full((num_bboxes,), 0, dtype=torch.long) + + prior_center = priors[:, :2] + lt_ = prior_center[:, None] - gt_bboxes[:, :2] + rb_ = gt_bboxes[:, 2:] - prior_center[:, None] + + deltas = torch.cat([lt_, rb_], dim=-1) + is_in_gts = deltas.min(dim=-1).values > 0 + valid_mask = is_in_gts.sum(dim=1) > 0 + + valid_decoded_bbox = decoded_bboxes[valid_mask] + valid_pred_scores = pred_scores[valid_mask] + num_valid = valid_decoded_bbox.size(0) + + if num_gt == 0 or num_bboxes == 0 or num_valid == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = decoded_bboxes.new_zeros((num_bboxes,)) + if num_gt == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = decoded_bboxes.new_full( + (num_bboxes,), -1, dtype=torch.long + ) + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels + ) + + pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes) + iou_cost = -torch.log(pairwise_ious + 1e-7) + + gt_onehot_label = ( + F.one_hot(gt_labels.to(torch.int64), pred_scores.shape[-1]) + .float() + .unsqueeze(0) + .repeat(num_valid, 1, 1) + ) + valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1) + + soft_label = gt_onehot_label * pairwise_ious[..., None] + scale_factor = soft_label - valid_pred_scores + + cls_cost = F.binary_cross_entropy( + valid_pred_scores, soft_label, reduction="none" + ) * scale_factor.abs().pow(2.0) + + cls_cost = cls_cost.sum(dim=-1) + + cost_matrix = cls_cost + iou_cost * self.iou_factor + + matched_pred_ious, matched_gt_inds = self.dynamic_k_matching( + cost_matrix, pairwise_ious, num_gt, valid_mask + ) + + # convert to AssignResult format + assigned_gt_inds[valid_mask] = matched_gt_inds + 1 + assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1) + assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long() + max_overlaps = assigned_gt_inds.new_full( + (num_bboxes,), -INF, dtype=torch.float32 + ) + max_overlaps[valid_mask] = matched_pred_ious + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels + ) + + def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask): + """Use sum of topk pred iou as dynamic k. Refer from OTA and YOLOX. + + Args: + cost (Tensor): Cost matrix. + pairwise_ious (Tensor): Pairwise iou matrix. + num_gt (int): Number of gt. + valid_mask (Tensor): Mask for valid bboxes. 
+ """ + matching_matrix = torch.zeros_like(cost) + # select candidate topk ious for dynamic-k calculation + candidate_topk = min(self.topk, pairwise_ious.size(0)) + topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0) + # calculate dynamic k for each gt + dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) + for gt_idx in range(num_gt): + _, pos_idx = torch.topk( + cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False + ) + matching_matrix[:, gt_idx][pos_idx] = 1.0 + + del topk_ious, dynamic_ks, pos_idx + + prior_match_gt_mask = matching_matrix.sum(1) > 1 + if prior_match_gt_mask.sum() > 0: + cost_min, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1) + matching_matrix[prior_match_gt_mask, :] *= 0.0 + matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0 + # get foreground mask inside box and center prior + fg_mask_inboxes = matching_matrix.sum(1) > 0.0 + valid_mask[valid_mask.clone()] = fg_mask_inboxes + + matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) + matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[fg_mask_inboxes] + return matched_pred_ious, matched_gt_inds diff --git a/nanodet/model/head/gfl_head.py b/nanodet/model/head/gfl_head.py index ae16e45e4..d1557f55d 100644 --- a/nanodet/model/head/gfl_head.py +++ b/nanodet/model/head/gfl_head.py @@ -1,3 +1,5 @@ +import math + import cv2 import numpy as np import torch @@ -59,8 +61,9 @@ def forward(self, x): x (Tensor): Integral result of box locations, i.e., distance offsets from the box center in four directions, shape (N, 4). """ - x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1) - x = F.linear(x, self.project.type_as(x)).reshape(-1, 4) + shape = x.size() + x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1) + x = F.linear(x, self.project.type_as(x)).reshape(*shape[:-1], 4) return x @@ -183,31 +186,39 @@ def init_weights(self): normal_init(self.gfl_reg, std=0.01) def forward(self, feats): - return multi_apply(self.forward_single, feats, self.scales) - - def forward_single(self, x, scale): - cls_feat = x - reg_feat = x - for cls_conv in self.cls_convs: - cls_feat = cls_conv(cls_feat) - for reg_conv in self.reg_convs: - reg_feat = reg_conv(reg_feat) - cls_score = self.gfl_cls(cls_feat) - bbox_pred = scale(self.gfl_reg(reg_feat)).float() - return cls_score, bbox_pred + outputs = [] + for x, scale in zip(feats, self.scales): + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.gfl_cls(cls_feat) + bbox_pred = scale(self.gfl_reg(reg_feat)).float() + output = torch.cat([cls_score, bbox_pred], dim=1) + outputs.append(output.flatten(start_dim=2)) + outputs = torch.cat(outputs, dim=2).permute(0, 2, 1) + return outputs def loss(self, preds, gt_meta): - cls_scores, bbox_preds = preds - batch_size = cls_scores[0].shape[0] - device = cls_scores[0].device + cls_scores, bbox_preds = preds.split( + [self.num_classes, 4 * (self.reg_max + 1)], dim=-1 + ) + device = cls_scores.device gt_bboxes = gt_meta["gt_bboxes"] gt_labels = gt_meta["gt_labels"] + input_height, input_width = gt_meta["img"].shape[2:] gt_bboxes_ignore = None - featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + featmap_sizes = [ + (math.ceil(input_height / stride), math.ceil(input_width) / stride) + for stride in self.strides + ] cls_reg_targets = self.target_assign( - batch_size, + cls_scores, + bbox_preds, featmap_sizes, gt_bboxes, gt_bboxes_ignore, @@ -218,6 +229,8 
@@ -183,31 +186,39 @@ def init_weights(self):
         normal_init(self.gfl_reg, std=0.01)
 
     def forward(self, feats):
-        return multi_apply(self.forward_single, feats, self.scales)
-
-    def forward_single(self, x, scale):
-        cls_feat = x
-        reg_feat = x
-        for cls_conv in self.cls_convs:
-            cls_feat = cls_conv(cls_feat)
-        for reg_conv in self.reg_convs:
-            reg_feat = reg_conv(reg_feat)
-        cls_score = self.gfl_cls(cls_feat)
-        bbox_pred = scale(self.gfl_reg(reg_feat)).float()
-        return cls_score, bbox_pred
+        outputs = []
+        for x, scale in zip(feats, self.scales):
+            cls_feat = x
+            reg_feat = x
+            for cls_conv in self.cls_convs:
+                cls_feat = cls_conv(cls_feat)
+            for reg_conv in self.reg_convs:
+                reg_feat = reg_conv(reg_feat)
+            cls_score = self.gfl_cls(cls_feat)
+            bbox_pred = scale(self.gfl_reg(reg_feat)).float()
+            output = torch.cat([cls_score, bbox_pred], dim=1)
+            outputs.append(output.flatten(start_dim=2))
+        outputs = torch.cat(outputs, dim=2).permute(0, 2, 1)
+        return outputs
 
     def loss(self, preds, gt_meta):
-        cls_scores, bbox_preds = preds
-        batch_size = cls_scores[0].shape[0]
-        device = cls_scores[0].device
+        cls_scores, bbox_preds = preds.split(
+            [self.num_classes, 4 * (self.reg_max + 1)], dim=-1
+        )
+        device = cls_scores.device
         gt_bboxes = gt_meta["gt_bboxes"]
         gt_labels = gt_meta["gt_labels"]
+        input_height, input_width = gt_meta["img"].shape[2:]
         gt_bboxes_ignore = None
 
-        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        featmap_sizes = [
+            (math.ceil(input_height / stride), math.ceil(input_width / stride))
+            for stride in self.strides
+        ]
 
         cls_reg_targets = self.target_assign(
-            batch_size,
+            cls_scores,
+            bbox_preds,
             featmap_sizes,
             gt_bboxes,
             gt_bboxes_ignore,
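Since `forward` now emits one flattened tensor per batch, the channel layout per point is `num_classes + 4 * (reg_max + 1)`, and `loss`/`post_process` recover the two halves with `split`. A quick shape check under the defaults assumed here (80 classes, reg_max 7, strides [8, 16, 32], 320x320 input):

```python
import torch

num_classes, reg_max, strides, size = 80, 7, [8, 16, 32], 320
num_points = sum((size // s) ** 2 for s in strides)  # 1600 + 400 + 100 = 2100

preds = torch.randn(2, num_points, num_classes + 4 * (reg_max + 1))
cls_scores, bbox_preds = preds.split([num_classes, 4 * (reg_max + 1)], dim=-1)
print(cls_scores.shape, bbox_preds.shape)
# torch.Size([2, 2100, 80]) torch.Size([2, 2100, 32])
```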
+ """ + device = cls_preds.device + b = cls_preds.shape[0] input_height, input_width = img_metas["img"].shape[2:] - input_shape = [input_height, input_width] + input_shape = (input_height, input_width) - result_list = [] - for img_id in range(cls_scores[0].shape[0]): - cls_score_list = [cls_scores[i][img_id].detach() for i in range(num_levels)] - bbox_pred_list = [bbox_preds[i][img_id].detach() for i in range(num_levels)] - scale_factor = 1 - dets = self.get_bboxes_single( - cls_score_list, - bbox_pred_list, - input_shape, - scale_factor, - device, - rescale, - ) - - result_list.append(dets) - return result_list - - def get_bboxes_single( - self, cls_scores, bbox_preds, img_shape, scale_factor, device, rescale=False - ): - """ - Decode output tensors to bboxes on one image. - :param cls_scores: classification prediction tensors of all stages - :param bbox_preds: regression prediction tensors of all stages - :param img_shape: shape of input image - :param scale_factor: scale factor of boxes - :param device: device of the tensor - :return: predict boxes and labels - """ - assert len(cls_scores) == len(bbox_preds) - mlvl_bboxes = [] - mlvl_scores = [] - for stride, cls_score, bbox_pred in zip(self.strides, cls_scores, bbox_preds): - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - featmap_size = cls_score.size()[-2:] + featmap_sizes = [ + (math.ceil(input_height / stride), math.ceil(input_width) / stride) + for stride in self.strides + ] + # get grid cells of one image + mlvl_center_priors = [] + for i, stride in enumerate(self.strides): y, x = self.get_single_level_center_point( - featmap_size, stride, cls_score.dtype, device, flatten=True + featmap_sizes[i], stride, torch.float32, device ) - center_points = torch.stack([x, y], dim=-1) - scores = ( - cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels).sigmoid() + strides = x.new_full((x.shape[0],), stride) + proiors = torch.stack([x, y, strides, strides], dim=-1) + mlvl_center_priors.append(proiors.unsqueeze(0).repeat(b, 1, 1)) + + center_priors = torch.cat(mlvl_center_priors, dim=1) + dis_preds = self.distribution_project(reg_preds) * center_priors[..., 2, None] + bboxes = distance2bbox(center_priors[..., :2], dis_preds, max_shape=input_shape) + scores = cls_preds.sigmoid() + result_list = [] + for i in range(b): + # add a dummy background class at the end of all labels + # same with mmdetection2.0 + score, bbox = scores[i], bboxes[i] + padding = score.new_zeros(score.shape[0], 1) + score = torch.cat([score, padding], dim=1) + results = multiclass_nms( + bbox, + score, + score_thr=0.05, + nms_cfg=dict(type="nms", iou_threshold=0.6), + max_num=100, ) - bbox_pred = bbox_pred.permute(1, 2, 0) - bbox_pred = self.distribution_project(bbox_pred) * stride - - nms_pre = 1000 - if scores.shape[0] > nms_pre: - max_scores, _ = scores.max(dim=1) - _, topk_inds = max_scores.topk(nms_pre) - center_points = center_points[topk_inds, :] - bbox_pred = bbox_pred[topk_inds, :] - scores = scores[topk_inds, :] - - bboxes = distance2bbox(center_points, bbox_pred, max_shape=img_shape) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - - mlvl_bboxes = torch.cat(mlvl_bboxes) - if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - - mlvl_scores = torch.cat(mlvl_scores) - # add a dummy background class at the end of all labels - # same with mmdetection2.0 - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - - det_bboxes, det_labels = multiclass_nms( - mlvl_bboxes, 
-            mlvl_scores,
-            score_thr=0.05,
-            nms_cfg=dict(type="nms", iou_threshold=0.6),
-            max_num=100,
-        )
-        return det_bboxes, det_labels
+            result_list.append(results)
+        return result_list
 
     def get_single_level_center_point(
         self, featmap_size, stride, dtype, device, flatten=True
diff --git a/nanodet/model/head/nanodet_head.py b/nanodet/model/head/nanodet_head.py
index 7aaf7d615..95eb43fd4 100755
--- a/nanodet/model/head/nanodet_head.py
+++ b/nanodet/model/head/nanodet_head.py
@@ -15,8 +15,6 @@
 import torch
 import torch.nn as nn
 
-from nanodet.util import multi_apply
-
 from ..module.conv import ConvModule, DepthwiseConvModule
 from ..module.init_weights import normal_init
 from .gfl_head import GFLHead
@@ -138,38 +136,22 @@ def init_weights(self):
         print("Finish initialize NanoDet Head.")
 
     def forward(self, feats):
-        return multi_apply(
-            self.forward_single,
-            feats,
-            self.cls_convs,
-            self.reg_convs,
-            self.gfl_cls,
-            self.gfl_reg,
-        )
-
-    def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg):
-        cls_feat = x
-        reg_feat = x
-        for cls_conv in cls_convs:
-            cls_feat = cls_conv(cls_feat)
-        for reg_conv in reg_convs:
-            reg_feat = reg_conv(reg_feat)
-        if self.share_cls_reg:
-            feat = gfl_cls(cls_feat)
-            cls_score, bbox_pred = torch.split(
-                feat, [self.cls_out_channels, 4 * (self.reg_max + 1)], dim=1
-            )
-        else:
-            cls_score = gfl_cls(cls_feat)
-            bbox_pred = gfl_reg(reg_feat)
-
-        if torch.onnx.is_in_onnx_export():
-            cls_score = (
-                torch.sigmoid(cls_score)
-                .reshape(1, self.num_classes, -1)
-                .permute(0, 2, 1)
-            )
-            bbox_pred = bbox_pred.reshape(1, (self.reg_max + 1) * 4, -1).permute(
-                0, 2, 1
-            )
-        return cls_score, bbox_pred
+        outputs = []
+        for x, cls_convs, reg_convs, gfl_cls, gfl_reg in zip(
+            feats, self.cls_convs, self.reg_convs, self.gfl_cls, self.gfl_reg
+        ):
+            cls_feat = x
+            reg_feat = x
+            for cls_conv in cls_convs:
+                cls_feat = cls_conv(cls_feat)
+            for reg_conv in reg_convs:
+                reg_feat = reg_conv(reg_feat)
+            if self.share_cls_reg:
+                output = gfl_cls(cls_feat)
+            else:
+                cls_score = gfl_cls(cls_feat)
+                bbox_pred = gfl_reg(reg_feat)
+                output = torch.cat([cls_score, bbox_pred], dim=1)
+            outputs.append(output.flatten(start_dim=2))
+        outputs = torch.cat(outputs, dim=2).permute(0, 2, 1)
+        return outputs
diff --git a/nanodet/model/head/nanodet_plus_head.py b/nanodet/model/head/nanodet_plus_head.py
new file mode 100644
index 000000000..6276e7ec4
--- /dev/null
+++ b/nanodet/model/head/nanodet_plus_head.py
@@ -0,0 +1,497 @@
+import math
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+
+from nanodet.util import bbox2distance, distance2bbox, multi_apply, overlay_bbox_cv
+
+from ...data.transform.warp import warp_boxes
+from ..loss.gfocal_loss import DistributionFocalLoss, QualityFocalLoss
+from ..loss.iou_loss import GIoULoss
+from ..module.conv import ConvModule, DepthwiseConvModule
+from ..module.init_weights import normal_init
+from ..module.nms import multiclass_nms
+from .assigner.dsl_assigner import DynamicSoftLabelAssigner
+from .gfl_head import Integral, reduce_mean
+
+
+class NanoDetPlusHead(nn.Module):
+    """Detection head used in NanoDet-Plus.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        loss (dict): Loss config.
+        input_channel (int): Number of channels of the input feature.
+        feat_channels (int): Number of channels of the feature.
+            Default: 96.
+        stacked_convs (int): Number of conv layers in the stacked convs.
+            Default: 2.
+        kernel_size (int): Size of the convolving kernel. Default: 5.
+        strides (list[int]): Strides of input multi-level feature maps.
+            Default: [8, 16, 32].
+        conv_type (str): Type of the convolution.
+            Default: "DWConv".
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN').
+        reg_max (int): The maximal value of the discrete set. Default: 7.
+        activation (str): Type of activation function. Default: "LeakyReLU".
+        assigner_cfg (dict): Config dict of the assigner. Default: dict(topk=13).
+    """
+
+    def __init__(
+        self,
+        num_classes,
+        loss,
+        input_channel,
+        feat_channels=96,
+        stacked_convs=2,
+        kernel_size=5,
+        strides=[8, 16, 32],
+        conv_type="DWConv",
+        norm_cfg=dict(type="BN"),
+        reg_max=7,
+        activation="LeakyReLU",
+        assigner_cfg=dict(topk=13),
+        **kwargs
+    ):
+        super(NanoDetPlusHead, self).__init__()
+        self.num_classes = num_classes
+        self.in_channels = input_channel
+        self.feat_channels = feat_channels
+        self.stacked_convs = stacked_convs
+        self.kernel_size = kernel_size
+        self.strides = strides
+        self.reg_max = reg_max
+        self.activation = activation
+        self.ConvModule = ConvModule if conv_type == "Conv" else DepthwiseConvModule
+
+        self.loss_cfg = loss
+        self.norm_cfg = norm_cfg
+
+        self.assigner = DynamicSoftLabelAssigner(**assigner_cfg)
+        self.distribution_project = Integral(self.reg_max)
+
+        self.loss_qfl = QualityFocalLoss(
+            beta=self.loss_cfg.loss_qfl.beta,
+            loss_weight=self.loss_cfg.loss_qfl.loss_weight,
+        )
+        self.loss_dfl = DistributionFocalLoss(
+            loss_weight=self.loss_cfg.loss_dfl.loss_weight
+        )
+        self.loss_bbox = GIoULoss(loss_weight=self.loss_cfg.loss_bbox.loss_weight)
+        self._init_layers()
+        self.init_weights()
+
+    def _init_layers(self):
+        self.cls_convs = nn.ModuleList()
+        for _ in self.strides:
+            cls_convs = self._build_not_shared_head()
+            self.cls_convs.append(cls_convs)
+
+        self.gfl_cls = nn.ModuleList(
+            [
+                nn.Conv2d(
+                    self.feat_channels,
+                    self.num_classes + 4 * (self.reg_max + 1),
+                    1,
+                    padding=0,
+                )
+                for _ in self.strides
+            ]
+        )
+
+    def _build_not_shared_head(self):
+        cls_convs = nn.ModuleList()
+        for i in range(self.stacked_convs):
+            chn = self.in_channels if i == 0 else self.feat_channels
+            cls_convs.append(
+                self.ConvModule(
+                    chn,
+                    self.feat_channels,
+                    self.kernel_size,
+                    stride=1,
+                    padding=self.kernel_size // 2,
+                    norm_cfg=self.norm_cfg,
+                    bias=self.norm_cfg is None,
+                    activation=self.activation,
+                )
+            )
+        return cls_convs
+
+    def init_weights(self):
+        for m in self.cls_convs.modules():
+            if isinstance(m, nn.Conv2d):
+                normal_init(m, std=0.01)
+        # init cls head with confidence = 0.01
+        bias_cls = -4.595
+        for i in range(len(self.strides)):
+            normal_init(self.gfl_cls[i], std=0.01, bias=bias_cls)
+        print("Finish initialize NanoDet-Plus Head.")
+
+    def forward(self, feats):
+        outputs = []
+        for feat, cls_convs, gfl_cls in zip(
+            feats,
+            self.cls_convs,
+            self.gfl_cls,
+        ):
+            for conv in cls_convs:
+                feat = conv(feat)
+            output = gfl_cls(feat)
+            outputs.append(output.flatten(start_dim=2))
+        outputs = torch.cat(outputs, dim=2).permute(0, 2, 1)
+        return outputs
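An aside on the constant in `init_weights` above: `bias_cls = -4.595` is the standard focal-loss prior initialization, chosen so every location starts at roughly 1% foreground confidence:

```python
import math

prior_prob = 0.01
bias_cls = -math.log((1 - prior_prob) / prior_prob)
print(round(bias_cls, 3))  # -4.595, i.e. sigmoid(bias_cls) ~= 0.01
```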
+ """ + gt_bboxes = gt_meta["gt_bboxes"] + gt_labels = gt_meta["gt_labels"] + device = preds[0].device + batch_size = preds[0].shape[0] + input_height, input_width = gt_meta["img"].shape[2:] + featmap_sizes = [ + (math.ceil(input_height / stride), math.ceil(input_width) / stride) + for stride in self.strides + ] + # get grid cells of one image + mlvl_center_priors = [ + self.get_single_level_center_priors( + batch_size, + featmap_sizes[i], + stride, + dtype=torch.float32, + device=device, + ) + for i, stride in enumerate(self.strides) + ] + center_priors = torch.cat(mlvl_center_priors, dim=1) + + cls_preds, reg_preds = preds.split( + [self.num_classes, 4 * (self.reg_max + 1)], dim=-1 + ) + dis_preds = self.distribution_project(reg_preds) * center_priors[..., 2, None] + decoded_bboxes = distance2bbox(center_priors[..., :2], dis_preds) + + if aux_preds is not None: + # use auxiliary head to assign + aux_cls_preds, aux_reg_preds = aux_preds.split( + [self.num_classes, 4 * (self.reg_max + 1)], dim=-1 + ) + aux_dis_preds = ( + self.distribution_project(aux_reg_preds) * center_priors[..., 2, None] + ) + aux_decoded_bboxes = distance2bbox(center_priors[..., :2], aux_dis_preds) + batch_assign_res = multi_apply( + self.target_assign_single_img, + aux_cls_preds.detach(), + center_priors, + aux_decoded_bboxes.detach(), + gt_bboxes, + gt_labels, + ) + else: + # use self prediction to assign + batch_assign_res = multi_apply( + self.target_assign_single_img, + cls_preds.detach(), + center_priors, + decoded_bboxes.detach(), + gt_bboxes, + gt_labels, + ) + + loss, loss_states = self._get_loss_from_assign( + cls_preds, reg_preds, decoded_bboxes, batch_assign_res + ) + + if aux_preds is not None: + aux_loss, aux_loss_states = self._get_loss_from_assign( + aux_cls_preds, aux_reg_preds, aux_decoded_bboxes, batch_assign_res + ) + loss = loss + aux_loss + for k, v in aux_loss_states.items(): + loss_states["aux_" + k] = v + return loss, loss_states + + def _get_loss_from_assign(self, cls_preds, reg_preds, decoded_bboxes, assign): + device = cls_preds.device + labels, label_scores, bbox_targets, dist_targets, num_pos = assign + num_total_samples = max( + reduce_mean(torch.tensor(sum(num_pos)).to(device)).item(), 1.0 + ) + + labels = torch.cat(labels, dim=0) + label_scores = torch.cat(label_scores, dim=0) + bbox_targets = torch.cat(bbox_targets, dim=0) + cls_preds = cls_preds.reshape(-1, self.num_classes) + reg_preds = reg_preds.reshape(-1, 4 * (self.reg_max + 1)) + decoded_bboxes = decoded_bboxes.reshape(-1, 4) + loss_qfl = self.loss_qfl( + cls_preds, (labels, label_scores), avg_factor=num_total_samples + ) + + pos_inds = torch.nonzero( + (labels >= 0) & (labels < self.num_classes), as_tuple=False + ).squeeze(1) + + if len(pos_inds) > 0: + weight_targets = cls_preds[pos_inds].detach().sigmoid().max(dim=1)[0] + bbox_avg_factor = max(reduce_mean(weight_targets.sum()).item(), 1.0) + + loss_bbox = self.loss_bbox( + decoded_bboxes[pos_inds], + bbox_targets[pos_inds], + weight=weight_targets, + avg_factor=bbox_avg_factor, + ) + + dist_targets = torch.cat(dist_targets, dim=0) + loss_dfl = self.loss_dfl( + reg_preds[pos_inds].reshape(-1, self.reg_max + 1), + dist_targets[pos_inds].reshape(-1), + weight=weight_targets[:, None].expand(-1, 4).reshape(-1), + avg_factor=4.0 * bbox_avg_factor, + ) + else: + loss_bbox = reg_preds.sum() * 0 + loss_dfl = reg_preds.sum() * 0 + + loss = loss_qfl + loss_bbox + loss_dfl + loss_states = dict(loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl) + return loss, loss_states + + 
+    @torch.no_grad()
+    def target_assign_single_img(
+        self, cls_preds, center_priors, decoded_bboxes, gt_bboxes, gt_labels
+    ):
+        """Compute classification, regression, and objectness targets for
+        priors in a single image.
+        Args:
+            cls_preds (Tensor): Classification predictions of one image,
+                a 2D-Tensor with shape [num_priors, num_classes]
+            center_priors (Tensor): All priors of one image, a 2D-Tensor with
+                shape [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
+            decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
+                a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
+                br_x, br_y] format.
+            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
+                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (Tensor): Ground truth labels of one image, a Tensor
+                with shape [num_gts].
+        """
+
+        num_priors = center_priors.size(0)
+        device = center_priors.device
+        gt_bboxes = torch.from_numpy(gt_bboxes).to(device)
+        gt_labels = torch.from_numpy(gt_labels).to(device)
+        num_gts = gt_labels.size(0)
+        gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype)
+
+        bbox_targets = torch.zeros_like(center_priors)
+        dist_targets = torch.zeros_like(center_priors)
+        labels = center_priors.new_full(
+            (num_priors,), self.num_classes, dtype=torch.long
+        )
+        label_scores = center_priors.new_zeros(labels.shape, dtype=torch.float)
+        # No target
+        if num_gts == 0:
+            return labels, label_scores, bbox_targets, dist_targets, 0
+
+        assign_result = self.assigner.assign(
+            cls_preds.sigmoid(), center_priors, decoded_bboxes, gt_bboxes, gt_labels
+        )
+        pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample(
+            assign_result, gt_bboxes
+        )
+        num_pos_per_img = pos_inds.size(0)
+        pos_ious = assign_result.max_overlaps[pos_inds]
+
+        if len(pos_inds) > 0:
+            bbox_targets[pos_inds, :] = pos_gt_bboxes
+            dist_targets[pos_inds, :] = (
+                bbox2distance(center_priors[pos_inds, :2], pos_gt_bboxes)
+                / center_priors[pos_inds, None, 2]
+            )
+            dist_targets = dist_targets.clamp(min=0, max=self.reg_max - 0.1)
+            labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
+            label_scores[pos_inds] = pos_ious
+        return (
+            labels,
+            label_scores,
+            bbox_targets,
+            dist_targets,
+            num_pos_per_img,
+        )
+
+    def sample(self, assign_result, gt_bboxes):
+        """Sample positive and negative bboxes."""
+        pos_inds = (
+            torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+            .squeeze(-1)
+            .unique()
+        )
+        neg_inds = (
+            torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+            .squeeze(-1)
+            .unique()
+        )
+        pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+        if gt_bboxes.numel() == 0:
+            # hack for index error case
+            assert pos_assigned_gt_inds.numel() == 0
+            pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
+        else:
+            if len(gt_bboxes.shape) < 2:
+                gt_bboxes = gt_bboxes.view(-1, 4)
+            pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
+        return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds
+
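A worked example of the distance-target construction above, with a hypothetical prior at (64, 48) on the stride-8 level matched to gt box [32, 16, 120, 96] and `reg_max = 7`; the clamp keeps targets strictly inside the DFL bin range:

```python
import torch
from nanodet.util import bbox2distance

center = torch.tensor([[64.0, 48.0]])
gt = torch.tensor([[32.0, 16.0, 120.0, 96.0]])
dist = bbox2distance(center, gt) / 8.0  # (l, t, r, b) in stride units
dist = dist.clamp(min=0, max=7 - 0.1)   # r would be 7.0, clamped to 6.9
print(dist)  # tensor([[4.0000, 4.0000, 6.9000, 6.0000]])
```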
+ """ + cls_scores, bbox_preds = preds.split( + [self.num_classes, 4 * (self.reg_max + 1)], dim=-1 + ) + result_list = self.get_bboxes(cls_scores, bbox_preds, meta) + det_results = {} + warp_matrixes = ( + meta["warp_matrix"] + if isinstance(meta["warp_matrix"], list) + else meta["warp_matrix"] + ) + img_heights = ( + meta["img_info"]["height"].cpu().numpy() + if isinstance(meta["img_info"]["height"], torch.Tensor) + else meta["img_info"]["height"] + ) + img_widths = ( + meta["img_info"]["width"].cpu().numpy() + if isinstance(meta["img_info"]["width"], torch.Tensor) + else meta["img_info"]["width"] + ) + img_ids = ( + meta["img_info"]["id"].cpu().numpy() + if isinstance(meta["img_info"]["id"], torch.Tensor) + else meta["img_info"]["id"] + ) + + for result, img_width, img_height, img_id, warp_matrix in zip( + result_list, img_widths, img_heights, img_ids, warp_matrixes + ): + det_result = {} + det_bboxes, det_labels = result + det_bboxes = det_bboxes.cpu().numpy() + det_bboxes[:, :4] = warp_boxes( + det_bboxes[:, :4], np.linalg.inv(warp_matrix), img_width, img_height + ) + classes = det_labels.cpu().numpy() + for i in range(self.num_classes): + inds = classes == i + det_result[i] = np.concatenate( + [ + det_bboxes[inds, :4].astype(np.float32), + det_bboxes[inds, 4:5].astype(np.float32), + ], + axis=1, + ).tolist() + det_results[img_id] = det_result + return det_results + + def show_result( + self, img, dets, class_names, score_thres=0.3, show=True, save_path=None + ): + result = overlay_bbox_cv(img, dets, class_names, score_thresh=score_thres) + if show: + cv2.imshow("det", result) + return result + + def get_bboxes(self, cls_preds, reg_preds, img_metas): + """Decode the outputs to bboxes. + Args: + cls_preds (Tensor): Shape (num_imgs, num_points, num_classes). + reg_preds (Tensor): Shape (num_imgs, num_points, 4 * (regmax + 1)). + img_metas (dict): Dict of image info. + + Returns: + results_list (list[tuple]): List of detection bboxes and labels. + """ + device = cls_preds.device + b = cls_preds.shape[0] + input_height, input_width = img_metas["img"].shape[2:] + input_shape = (input_height, input_width) + + featmap_sizes = [ + (math.ceil(input_height / stride), math.ceil(input_width) / stride) + for stride in self.strides + ] + # get grid cells of one image + mlvl_center_priors = [ + self.get_single_level_center_priors( + b, + featmap_sizes[i], + stride, + dtype=torch.float32, + device=device, + ) + for i, stride in enumerate(self.strides) + ] + center_priors = torch.cat(mlvl_center_priors, dim=1) + dis_preds = self.distribution_project(reg_preds) * center_priors[..., 2, None] + bboxes = distance2bbox(center_priors[..., :2], dis_preds, max_shape=input_shape) + scores = cls_preds.sigmoid() + result_list = [] + for i in range(b): + # add a dummy background class at the end of all labels + # same with mmdetection2.0 + score, bbox = scores[i], bboxes[i] + padding = score.new_zeros(score.shape[0], 1) + score = torch.cat([score, padding], dim=1) + results = multiclass_nms( + bbox, + score, + score_thr=0.05, + nms_cfg=dict(type="nms", iou_threshold=0.6), + max_num=100, + ) + result_list.append(results) + return result_list + + def get_single_level_center_priors( + self, batch_size, featmap_size, stride, dtype, device + ): + """Generate centers of a single stage feature map. + Args: + batch_size (int): Number of images in one batch. 
+    def get_single_level_center_priors(
+        self, batch_size, featmap_size, stride, dtype, device
+    ):
+        """Generate centers of a single stage feature map.
+        Args:
+            batch_size (int): Number of images in one batch.
+            featmap_size (tuple[int]): height and width of the feature map
+            stride (int): down sample stride of the feature map
+            dtype (obj:`torch.dtype`): data type of the tensors
+            device (obj:`torch.device`): device of the tensors
+        Return:
+            priors (Tensor): center priors of a single level feature map.
+        """
+        h, w = featmap_size
+        x_range = torch.arange(w, dtype=dtype, device=device) * stride
+        y_range = torch.arange(h, dtype=dtype, device=device) * stride
+        y, x = torch.meshgrid(y_range, x_range)
+        y = y.flatten()
+        x = x.flatten()
+        strides = x.new_full((x.shape[0],), stride)
+        priors = torch.stack([x, y, strides, strides], dim=-1)
+        return priors.unsqueeze(0).repeat(batch_size, 1, 1)
diff --git a/nanodet/util/box_transform.py b/nanodet/util/box_transform.py
index d9fa9ccad..4b82a8c19 100644
--- a/nanodet/util/box_transform.py
+++ b/nanodet/util/box_transform.py
@@ -13,10 +13,10 @@ def distance2bbox(points, distance, max_shape=None):
     Returns:
         Tensor: Decoded bboxes.
     """
-    x1 = points[:, 0] - distance[:, 0]
-    y1 = points[:, 1] - distance[:, 1]
-    x2 = points[:, 0] + distance[:, 2]
-    y2 = points[:, 1] + distance[:, 3]
+    x1 = points[..., 0] - distance[..., 0]
+    y1 = points[..., 1] - distance[..., 1]
+    x2 = points[..., 0] + distance[..., 2]
+    y2 = points[..., 1] + distance[..., 3]
     if max_shape is not None:
         x1 = x1.clamp(min=0, max=max_shape[1])
         y1 = y1.clamp(min=0, max=max_shape[0])
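The `[:, i]` to `[..., i]` change in `distance2bbox` is what lets the heads decode whole batches at once; the same call now accepts `(num_points, ...)` or `(batch, num_points, ...)` tensors:

```python
import torch
from nanodet.util import distance2bbox

points = torch.tensor([[[8.0, 8.0]], [[16.0, 16.0]]])  # (2, 1, 2), batched
distance = torch.tensor([[[4.0, 4.0, 4.0, 4.0]],
                         [[8.0, 8.0, 8.0, 8.0]]])      # (2, 1, 4)
print(distance2bbox(points, distance))
# tensor([[[ 4.,  4., 12., 12.]],
#         [[ 8.,  8., 24., 24.]]])
```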
diff --git a/nanodet/util/check_point.py b/nanodet/util/check_point.py
index fe6a8e63b..937c1fc32 100644
--- a/nanodet/util/check_point.py
+++ b/nanodet/util/check_point.py
@@ -22,7 +22,7 @@
 def load_model_weight(model, checkpoint, logger):
-    state_dict = checkpoint["state_dict"]
+    state_dict = checkpoint["state_dict"].copy()
     # strip prefix of state_dict
     if list(state_dict.keys())[0].startswith("module."):
         state_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()}
diff --git a/nanodet/util/logger.py b/nanodet/util/logger.py
index d28dea193..a871e5055 100644
--- a/nanodet/util/logger.py
+++ b/nanodet/util/logger.py
@@ -195,6 +195,10 @@ def _init_logger(self):
     def info(self, string):
         self.logger.info(string)
 
+    @rank_zero_only
+    def log(self, string):
+        self.logger.info(string)
+
     @rank_zero_only
     def dump_cfg(self, cfg_node):
         with open(os.path.join(self.log_dir, "train_cfg.yml"), "w") as f:
diff --git a/tests/test_models/test_head/test_gfl_head.py b/tests/test_models/test_head/test_gfl_head.py
index eb51eadb0..aa1988627 100644
--- a/tests/test_models/test_head/test_gfl_head.py
+++ b/tests/test_models/test_head/test_gfl_head.py
@@ -28,7 +28,11 @@ def test_gfl_head_loss():
     preds = head.forward(feat)
 
     # Test that empty ground truth encourages the network to predict background
-    meta = dict(gt_bboxes=[np.random.random((0, 4))], gt_labels=[np.array([])])
+    meta = dict(
+        img=torch.rand((2, 3, 64, 64)),
+        gt_bboxes=[np.random.random((0, 4))],
+        gt_labels=[np.array([])],
+    )
     loss, empty_gt_losses = head.loss(preds, meta)
     # When there is no truth, the cls loss should be nonzero but there should
     # be no box loss.
@@ -49,7 +53,9 @@ def test_gfl_head_loss():
         np.array([[23.6667, 23.8757, 238.6326, 151.8874]], dtype=np.float32),
     ]
     gt_labels = [np.array([2])]
-    meta = dict(gt_bboxes=gt_bboxes, gt_labels=gt_labels)
+    meta = dict(
+        img=torch.rand((2, 3, 64, 64)), gt_bboxes=gt_bboxes, gt_labels=gt_labels
+    )
     loss, one_gt_losses = head.loss(preds, meta)
     onegt_qfl_loss = one_gt_losses["loss_qfl"]
     onegt_box_loss = one_gt_losses["loss_bbox"]
diff --git a/tests/test_models/test_head/test_nanodet_head.py b/tests/test_models/test_head/test_nanodet_head.py
index 5cf1e17c9..d2a3f7981 100644
--- a/tests/test_models/test_head/test_nanodet_head.py
+++ b/tests/test_models/test_head/test_nanodet_head.py
@@ -27,10 +27,9 @@ def test_gfl_head_loss():
     head = build_head(cfg)
 
     feat = [torch.rand(1, 1, 320 // stride, 320 // stride) for stride in [8, 16, 32]]
-    cls_preds, reg_preds = head.forward(feat)
-    for cls, reg, stride in zip(cls_preds, reg_preds, [8, 16, 32]):
-        assert cls.shape == (1, 80, 320 // stride, 320 // stride)
-        assert reg.shape == (1, (8 + 1) * 4, 320 // stride, 320 // stride)
+    preds = head.forward(feat)
+    num_points = sum([(320 // stride) ** 2 for stride in [8, 16, 32]])
+    assert preds.shape == (1, num_points, 80 + (8 + 1) * 4)
 
     head_cfg = dict(
         name="NanoDetHead",
@@ -53,7 +52,6 @@ def test_gfl_head_loss():
     cfg = CfgNode(head_cfg)
     head = build_head(cfg)
 
-    cls_preds, reg_preds = head.forward(feat)
-    for cls, reg, stride in zip(cls_preds, reg_preds, [8, 16, 32]):
-        assert cls.shape == (1, 20, 320 // stride, 320 // stride)
-        assert reg.shape == (1, (5 + 1) * 4, 320 // stride, 320 // stride)
+    preds = head.forward(feat)
+    num_points = sum([(320 // stride) ** 2 for stride in [8, 16, 32]])
+    assert preds.shape == (1, num_points, 20 + (5 + 1) * 4)
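A guess at the intent of the `.copy()` in `load_model_weight`: it keeps the caller's `checkpoint["state_dict"]` mapping untouched while the local name is rebound during prefix stripping. A self-contained illustration:

```python
checkpoint = {"state_dict": {"module.backbone.conv.weight": "w"}}
state_dict = checkpoint["state_dict"].copy()
if list(state_dict.keys())[0].startswith("module."):
    state_dict = {k[7:]: v for k, v in state_dict.items()}
print(list(state_dict.keys()))         # ['backbone.conv.weight']
print(list(checkpoint["state_dict"]))  # original keys are preserved
```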
diff --git a/tests/test_models/test_head/test_nanodet_plus_head.py b/tests/test_models/test_head/test_nanodet_plus_head.py
new file mode 100644
index 000000000..eb529f575
--- /dev/null
+++ b/tests/test_models/test_head/test_nanodet_plus_head.py
@@ -0,0 +1,116 @@
+import numpy as np
+import torch
+
+from nanodet.model.head import build_head
+from nanodet.util.yacs import CfgNode
+
+
+def test_nanodet_plus_head_loss():
+    head_cfg = dict(
+        name="NanoDetPlusHead",
+        num_classes=80,
+        input_channel=1,
+        feat_channels=96,
+        stacked_convs=2,
+        conv_type="DWConv",
+        reg_max=8,
+        strides=[8, 16, 32],
+        loss=dict(
+            loss_qfl=dict(
+                name="QualityFocalLoss", use_sigmoid=True, beta=2.0, loss_weight=1.0
+            ),
+            loss_dfl=dict(name="DistributionFocalLoss", loss_weight=0.25),
+            loss_bbox=dict(name="GIoULoss", loss_weight=2.0),
+        ),
+    )
+    cfg = CfgNode(head_cfg)
+
+    head = build_head(cfg)
+    feat = [torch.rand(1, 1, 320 // stride, 320 // stride) for stride in [8, 16, 32]]
+
+    preds = head.forward(feat)
+    num_points = sum([(320 // stride) ** 2 for stride in [8, 16, 32]])
+    assert preds.shape == (1, num_points, 80 + (8 + 1) * 4)
+
+    head_cfg = dict(
+        name="NanoDetPlusHead",
+        num_classes=20,
+        input_channel=1,
+        feat_channels=96,
+        stacked_convs=2,
+        conv_type="Conv",
+        reg_max=5,
+        strides=[8, 16, 32],
+        loss=dict(
+            loss_qfl=dict(
+                name="QualityFocalLoss", use_sigmoid=True, beta=2.0, loss_weight=1.0
+            ),
+            loss_dfl=dict(name="DistributionFocalLoss", loss_weight=0.25),
+            loss_bbox=dict(name="GIoULoss", loss_weight=2.0),
+        ),
+    )
+    cfg = CfgNode(head_cfg)
+    head = build_head(cfg)
+
+    preds = head.forward(feat)
+    num_points = sum([(320 // stride) ** 2 for stride in [8, 16, 32]])
+    assert preds.shape == (1, num_points, 20 + (5 + 1) * 4)
+
+    # Test that empty ground truth encourages the network to predict background
+    meta = dict(
+        img=torch.rand((1, 3, 320, 320)),
+        gt_bboxes=[np.random.random((0, 4))],
+        gt_labels=[np.array([])],
+    )
+    loss, empty_gt_losses = head.loss(preds, meta)
+    # When there is no truth, the cls loss should be nonzero but there should
+    # be no box loss.
+    empty_qfl_loss = empty_gt_losses["loss_qfl"]
+    empty_box_loss = empty_gt_losses["loss_bbox"]
+    empty_dfl_loss = empty_gt_losses["loss_dfl"]
+    assert empty_qfl_loss.item() > 0
+    assert (
+        empty_box_loss.item() == 0
+    ), "there should be no box loss when there are no true boxes"
+    assert (
+        empty_dfl_loss.item() == 0
+    ), "there should be no dfl loss when there are no true boxes"
+
+    # When truth is non-empty then both cls and box loss should be nonzero for
+    # random inputs
+    gt_bboxes = [
+        np.array([[23.6667, 23.8757, 238.6326, 151.8874]], dtype=np.float32),
+    ]
+    gt_labels = [np.array([2])]
+    meta = dict(
+        img=torch.rand((1, 3, 320, 320)), gt_bboxes=gt_bboxes, gt_labels=gt_labels
+    )
+    loss, one_gt_losses = head.loss(preds, meta)
+    onegt_qfl_loss = one_gt_losses["loss_qfl"]
+    onegt_box_loss = one_gt_losses["loss_bbox"]
+    onegt_dfl_loss = one_gt_losses["loss_dfl"]
+    assert onegt_qfl_loss.item() > 0, "qfl loss should be non-zero"
+    assert onegt_box_loss.item() > 0, "box loss should be non-zero"
+    assert onegt_dfl_loss.item() > 0, "dfl loss should be non-zero"
+
+    # test aux input
+    gt_bboxes = [
+        np.array([[23.6667, 23.8757, 238.6326, 151.8874]], dtype=np.float32),
+    ]
+    gt_labels = [np.array([2])]
+    meta = dict(
+        img=torch.rand((1, 3, 320, 320)), gt_bboxes=gt_bboxes, gt_labels=gt_labels
+    )
+    loss, one_gt_losses = head.loss(preds, meta, aux_preds=preds)
+    onegt_qfl_loss = one_gt_losses["loss_qfl"]
+    onegt_box_loss = one_gt_losses["loss_bbox"]
+    onegt_dfl_loss = one_gt_losses["loss_dfl"]
+    onegt_aux_qfl_loss = one_gt_losses["aux_loss_qfl"]
+    onegt_aux_box_loss = one_gt_losses["aux_loss_bbox"]
+    onegt_aux_dfl_loss = one_gt_losses["aux_loss_dfl"]
+    assert onegt_qfl_loss.item() > 0, "qfl loss should be non-zero"
+    assert onegt_box_loss.item() > 0, "box loss should be non-zero"
+    assert onegt_dfl_loss.item() > 0, "dfl loss should be non-zero"
+    assert onegt_aux_qfl_loss.item() > 0, "aux_qfl loss should be non-zero"
+    assert onegt_aux_box_loss.item() > 0, "aux_box loss should be non-zero"
+    assert onegt_aux_dfl_loss.item() > 0, "aux_dfl loss should be non-zero"