Skip to content

Commit

Permalink
[Feature] Support ignore boxes in nanodet head (#480)
Browse files Browse the repository at this point in the history
* [Feature] Support ignore boxes in nanodet head

* convert gt_bboxes_ignore to torch tensor

* add bboxes_ignore to nanodet plus head

* switch https://gitlab.com/PyCQA/flake8 for https://github.com/PyCQA/flake8

* modify unittest

* Reformat code

* add docstring and set default value to None
  • Loading branch information
zero0kiriyu committed Dec 22, 2022
1 parent ad410c2 commit d8ba391
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
hooks:
- id: black

- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
3 changes: 3 additions & 0 deletions docs/config_file_detail.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ head:
scales_per_octave: 1
strides: [8, 16, 32]
reg_max: 7
ignore_iof_thr: -1
norm_cfg:
type: BN
loss:
Expand All @@ -92,6 +93,8 @@ head:

`reg_max`: max value of per-level l-r-t-b distance

`ignore_iof_thr`: thresh of iof for ignore box, default value -1

`norm_cfg`: normalization layer setting

`loss`: adjust loss functions and weights
Expand Down
10 changes: 6 additions & 4 deletions nanodet/data/dataset/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,13 @@ def get_img_annotation(self, idx):
if self.use_keypoint:
gt_keypoints = []
for ann in anns:
if ann.get("ignore", False):
continue
x1, y1, w, h = ann["bbox"]
if ann["area"] <= 0 or w < 1 or h < 1:
continue
if ann["category_id"] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get("iscrowd", False):
if ann.get("iscrowd", False) or ann.get("ignore", False):
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
Expand Down Expand Up @@ -131,7 +129,11 @@ def get_train_data(self, idx):
raise FileNotFoundError("Cant load image! Please check image path!")
ann = self.get_img_annotation(idx)
meta = dict(
img=img, img_info=img_info, gt_bboxes=ann["bboxes"], gt_labels=ann["labels"]
img=img,
img_info=img_info,
gt_bboxes=ann["bboxes"],
gt_labels=ann["labels"],
gt_bboxes_ignore=ann["bboxes_ignore"],
)
if self.use_instance_mask:
meta["gt_masks"] = ann["masks"]
Expand Down
10 changes: 10 additions & 0 deletions nanodet/data/transform/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def warp_and_resize(
if "gt_bboxes" in meta:
boxes = meta["gt_bboxes"]
meta["gt_bboxes"] = warp_boxes(boxes, M, dst_shape[0], dst_shape[1])
if "gt_bboxes_ignore" in meta:
bboxes_ignore = meta["gt_bboxes_ignore"]
meta["gt_bboxes_ignore"] = warp_boxes(
bboxes_ignore, M, dst_shape[0], dst_shape[1]
)
if "gt_masks" in meta:
for i, mask in enumerate(meta["gt_masks"]):
meta["gt_masks"][i] = cv2.warpPerspective(mask, M, dsize=tuple(dst_shape))
Expand Down Expand Up @@ -343,6 +348,11 @@ def __call__(self, meta_data, dst_shape):
if "gt_bboxes" in meta_data:
boxes = meta_data["gt_bboxes"]
meta_data["gt_bboxes"] = warp_boxes(boxes, M, dst_shape[0], dst_shape[1])
if "gt_bboxes_ignore" in meta_data:
bboxes_ignore = meta_data["gt_bboxes_ignore"]
meta_data["gt_bboxes_ignore"] = warp_boxes(
bboxes_ignore, M, dst_shape[0], dst_shape[1]
)
if "gt_masks" in meta_data:
for i, mask in enumerate(meta_data["gt_masks"]):
meta_data["gt_masks"][i] = cv2.warpPerspective(
Expand Down
21 changes: 18 additions & 3 deletions nanodet/model/head/assigner/atss_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,21 @@
class ATSSAssigner(BaseAssigner):
"""Assign a corresponding gt bbox or background to each bbox.
Each proposals will be assigned with `0` or a positive integer
Each proposals will be assigned with `-1`, `0` or a positive integer
indicating the ground truth index.
- -1: ignore sample, will be masked in loss calculation
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
Args:
topk (float): number of bbox selected in each level
ignore_iof_thr (float): whether ignore max overlaps or not.
Default -1 ([0,1] or -1).
"""

def __init__(self, topk):
def __init__(self, topk, ignore_iof_thr=-1):
self.topk = topk
self.ignore_iof_thr = ignore_iof_thr

# https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py

Expand Down Expand Up @@ -105,6 +108,18 @@ def assign(
(bboxes_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()
)

if (
self.ignore_iof_thr > 0
and gt_bboxes_ignore is not None
and gt_bboxes_ignore.numel() > 0
and bboxes.numel() > 0
):
ignore_overlaps = bbox_overlaps(bboxes, gt_bboxes_ignore, mode="iof")
ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
distances[ignore_idxs, :] = INF
assigned_gt_inds[ignore_idxs] = -1

# Selecting candidates based on the center distance
candidate_idxs = []
start_idx = 0
Expand Down
22 changes: 21 additions & 1 deletion nanodet/model/head/assigner/dsl_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ class DynamicSoftLabelAssigner(BaseAssigner):
topk (int): Select top-k predictions to calculate dynamic k
best matchs for each gt. Default 13.
iou_factor (float): The scale factor of iou cost. Default 3.0.
ignore_iof_thr (int): whether ignore max overlaps or not.
Default -1 (1 or -1).
"""

def __init__(self, topk=13, iou_factor=3.0):
def __init__(self, topk=13, iou_factor=3.0, ignore_iof_thr=-1):
self.topk = topk
self.iou_factor = iou_factor
self.ignore_iof_thr = ignore_iof_thr

def assign(
self,
Expand All @@ -27,6 +30,7 @@ def assign(
decoded_bboxes,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
):
"""Assign gt to priors with dynamic soft label assignment.
Args:
Expand All @@ -38,6 +42,8 @@ def assign(
[num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`, e.g., crowd boxes in COCO.
gt_labels (Tensor): Ground truth labels of one image, a Tensor
with shape [num_gts].
Expand Down Expand Up @@ -113,6 +119,20 @@ def assign(
(num_bboxes,), -INF, dtype=torch.float32
)
max_overlaps[valid_mask] = matched_pred_ious

if (
self.ignore_iof_thr > 0
and gt_bboxes_ignore is not None
and gt_bboxes_ignore.numel() > 0
and num_bboxes > 0
):
ignore_overlaps = bbox_overlaps(
valid_decoded_bbox, gt_bboxes_ignore, mode="iof"
)
ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
assigned_gt_inds[ignore_idxs] = -1

return AssignResult(
num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
)
Expand Down
9 changes: 7 additions & 2 deletions nanodet/model/head/gfl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
conv_cfg=None,
norm_cfg=dict(type="GN", num_groups=32, requires_grad=True),
reg_max=16,
ignore_iof_thr=-1,
**kwargs
):
super(GFLHead, self).__init__()
Expand All @@ -120,12 +121,13 @@ def __init__(
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_sigmoid = self.loss_cfg.loss_qfl.use_sigmoid
self.ignore_iof_thr = ignore_iof_thr
if self.use_sigmoid:
self.cls_out_channels = num_classes
else:
self.cls_out_channels = num_classes + 1

self.assigner = ATSSAssigner(topk=9)
self.assigner = ATSSAssigner(topk=9, ignore_iof_thr=ignore_iof_thr)
self.distribution_project = Integral(self.reg_max)

self.loss_qfl = QualityFocalLoss(
Expand Down Expand Up @@ -209,9 +211,9 @@ def loss(self, preds, gt_meta):
)
device = cls_scores.device
gt_bboxes = gt_meta["gt_bboxes"]
gt_bboxes_ignore = gt_meta["gt_bboxes_ignore"]
gt_labels = gt_meta["gt_labels"]
input_height, input_width = gt_meta["img"].shape[2:]
gt_bboxes_ignore = None

featmap_sizes = [
(math.ceil(input_height / stride), math.ceil(input_width) / stride)
Expand Down Expand Up @@ -465,6 +467,9 @@ def target_assign_single_img(
gt_bboxes = torch.from_numpy(gt_bboxes).to(device)
gt_labels = torch.from_numpy(gt_labels).to(device)

if gt_bboxes_ignore is not None:
gt_bboxes_ignore = torch.from_numpy(gt_bboxes_ignore).to(device)

assign_result = self.assigner.assign(
grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore, gt_labels
)
Expand Down
69 changes: 53 additions & 16 deletions nanodet/model/head/nanodet_plus_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,15 @@ def loss(self, preds, gt_meta, aux_preds=None):
loss (Tensor): Loss tensor.
loss_states (dict): State dict of each loss.
"""
gt_bboxes = gt_meta["gt_bboxes"]
gt_labels = gt_meta["gt_labels"]
device = preds.device
batch_size = preds.shape[0]
gt_bboxes = gt_meta["gt_bboxes"]
gt_labels = gt_meta["gt_labels"]

gt_bboxes_ignore = gt_meta["gt_bboxes_ignore"]
if gt_bboxes_ignore is None:
gt_bboxes_ignore = [None for _ in range(batch_size)]

input_height, input_width = gt_meta["img"].shape[2:]
featmap_sizes = [
(math.ceil(input_height / stride), math.ceil(input_width) / stride)
Expand Down Expand Up @@ -202,6 +207,7 @@ def loss(self, preds, gt_meta, aux_preds=None):
aux_decoded_bboxes.detach(),
gt_bboxes,
gt_labels,
gt_bboxes_ignore,
)
else:
# use self prediction to assign
Expand All @@ -212,6 +218,7 @@ def loss(self, preds, gt_meta, aux_preds=None):
decoded_bboxes.detach(),
gt_bboxes,
gt_labels,
gt_bboxes_ignore,
)

loss, loss_states = self._get_loss_from_assign(
Expand All @@ -229,19 +236,30 @@ def loss(self, preds, gt_meta, aux_preds=None):

def _get_loss_from_assign(self, cls_preds, reg_preds, decoded_bboxes, assign):
device = cls_preds.device
labels, label_scores, bbox_targets, dist_targets, num_pos = assign
(
labels,
label_scores,
label_weights,
bbox_targets,
dist_targets,
num_pos,
) = assign
num_total_samples = max(
reduce_mean(torch.tensor(sum(num_pos)).to(device)).item(), 1.0
)

labels = torch.cat(labels, dim=0)
label_scores = torch.cat(label_scores, dim=0)
label_weights = torch.cat(label_weights, dim=0)
bbox_targets = torch.cat(bbox_targets, dim=0)
cls_preds = cls_preds.reshape(-1, self.num_classes)
reg_preds = reg_preds.reshape(-1, 4 * (self.reg_max + 1))
decoded_bboxes = decoded_bboxes.reshape(-1, 4)
loss_qfl = self.loss_qfl(
cls_preds, (labels, label_scores), avg_factor=num_total_samples
cls_preds,
(labels, label_scores),
weight=label_weights,
avg_factor=num_total_samples,
)

pos_inds = torch.nonzero(
Expand Down Expand Up @@ -276,7 +294,13 @@ def _get_loss_from_assign(self, cls_preds, reg_preds, decoded_bboxes, assign):

@torch.no_grad()
def target_assign_single_img(
self, cls_preds, center_priors, decoded_bboxes, gt_bboxes, gt_labels
self,
cls_preds,
center_priors,
decoded_bboxes,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
):
"""Compute classification, regression, and objectness targets for
priors in a single image.
Expand All @@ -292,31 +316,40 @@ def target_assign_single_img(
with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
gt_labels (Tensor): Ground truth labels of one image, a Tensor
with shape [num_gts].
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`, e.g., crowd boxes in COCO.
"""

num_priors = center_priors.size(0)
device = center_priors.device
gt_bboxes = torch.from_numpy(gt_bboxes).to(device)
gt_labels = torch.from_numpy(gt_labels).to(device)
num_gts = gt_labels.size(0)
gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype)

if gt_bboxes_ignore is not None:
gt_bboxes_ignore = torch.from_numpy(gt_bboxes_ignore).to(device)
gt_bboxes_ignore = gt_bboxes_ignore.to(decoded_bboxes.dtype)

assign_result = self.assigner.assign(
cls_preds.sigmoid(),
center_priors,
decoded_bboxes,
gt_bboxes,
gt_labels,
gt_bboxes_ignore,
)
pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample(
assign_result, gt_bboxes
)

num_priors = center_priors.size(0)
bbox_targets = torch.zeros_like(center_priors)
dist_targets = torch.zeros_like(center_priors)
labels = center_priors.new_full(
(num_priors,), self.num_classes, dtype=torch.long
)
label_weights = center_priors.new_zeros(num_priors, dtype=torch.float)
label_scores = center_priors.new_zeros(labels.shape, dtype=torch.float)
# No target
if num_gts == 0:
return labels, label_scores, bbox_targets, dist_targets, 0

assign_result = self.assigner.assign(
cls_preds.sigmoid(), center_priors, decoded_bboxes, gt_bboxes, gt_labels
)
pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample(
assign_result, gt_bboxes
)
num_pos_per_img = pos_inds.size(0)
pos_ious = assign_result.max_overlaps[pos_inds]

Expand All @@ -329,9 +362,13 @@ def target_assign_single_img(
dist_targets = dist_targets.clamp(min=0, max=self.reg_max - 0.1)
labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
label_scores[pos_inds] = pos_ious
label_weights[pos_inds] = 1.0
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
return (
labels,
label_scores,
label_weights,
bbox_targets,
dist_targets,
num_pos_per_img,
Expand Down
9 changes: 8 additions & 1 deletion tests/test_models/test_head/test_gfl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_gfl_head_loss():
meta = dict(
img=torch.rand((2, 3, 64, 64)),
gt_bboxes=[np.random.random((0, 4))],
gt_bboxes_ignore=[np.random.random((0, 4))],
gt_labels=[np.array([])],
)
loss, empty_gt_losses = head.loss(preds, meta)
Expand All @@ -52,9 +53,15 @@ def test_gfl_head_loss():
gt_bboxes = [
np.array([[23.6667, 23.8757, 238.6326, 151.8874]], dtype=np.float32),
]
gt_bboxes_ignore = [
np.array([[29.6667, 29.8757, 244.6326, 160.8874]], dtype=np.float32),
]
gt_labels = [np.array([2])]
meta = dict(
img=torch.rand((2, 3, 64, 64)), gt_bboxes=gt_bboxes, gt_labels=gt_labels
img=torch.rand((2, 3, 64, 64)),
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
gt_bboxes_ignore=gt_bboxes_ignore,
)
loss, one_gt_losses = head.loss(preds, meta)
onegt_qfl_loss = one_gt_losses["loss_qfl"]
Expand Down
Loading

0 comments on commit d8ba391

Please sign in to comment.