-
Notifications
You must be signed in to change notification settings - Fork 9.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CodeCamp2023-503] Add the DDQ algorithm to mmdetection (#10772)
- Loading branch information
1 parent
60b29b3
commit 2dbf307
Showing
24 changed files
with
2,774 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# DDQ | ||
|
||
[Dense Distinct Query for End-to-End Object Detection](https://arxiv.org/abs/2303.12776) | ||
|
||
## Abstract | ||
|
||
One-to-one label assignment in object detection has successfully obviated the need for non-maximum suppression (NMS) as postprocessing and makes the pipeline end-to-end. However, it triggers a new dilemma as the widely used sparse queries cannot guarantee a high recall, while dense queries inevitably bring more similar queries and encounter optimization difficulties. As both sparse and dense queries are problematic, then what are the expected queries in end-to-end object detection? This paper shows that the solution should be Dense Distinct Queries (DDQ). Concretely, we first lay dense queries like traditional detectors and then select distinct ones for one-to-one assignments. DDQ blends the advantages of traditional and recent end-to-end detectors and significantly improves the performance of various detectors including FCN, R-CNN, and DETRs. Most impressively, DDQ-DETR achieves 52.1 AP on MS-COCO dataset within 12 epochs using a ResNet-50 backbone, outperforming all existing detectors in the same setting. DDQ also shares the benefit of end-to-end detectors in crowded scenes and achieves 93.8 AP on CrowdHuman. We hope DDQ can inspire researchers to consider the complementarity between traditional methods and end-to-end detectors. | ||
|
||
![ddq_arch](https://github.com/open-mmlab/mmdetection/assets/33146359/5ca9f11b-b6f3-454f-a2d1-3009ee337bbc) | ||
|
||
## Results and Models | ||
|
||
| Model | Backbone | Lr schd | Augmentation | box AP(val) | Config | Download | | ||
| :-------------: | :------: | :-----: | :----------: | :---------: | :------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | ||
| DDQ DETR-4scale | R-50 | 12e | DETR | 51.4 | [config](./ddq-detr-4scale_r50_8xb2-12e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq-detr-4scale_r50_8xb2-12e_coco/ddq-detr-4scale_r50_8xb2-12e_coco_20230809_170711-42528127.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq-detr-4scale_r50_8xb2-12e_coco/ddq-detr-4scale_r50_8xb2-12e_coco_20230809_170711.log.json) | | ||
| DDQ DETR-5scale | R-50 | 12e | DETR | 52.1 | [config](./ddq-detr-5scale_r50_8xb2-12e_coco.py) | [model\*](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_5scale_coco_1x.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_5scale_coco_1x_20230319_103307.log) | | ||
| DDQ DETR-4scale | Swin-L | 30e | DETR | 58.7 | [config](./ddq-detr-4scale_swinl_8xb2-30e_coco.py) | [model\*](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_swinl_30e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_swinl_30e_20230316_221721_20230318_143554.log) | | ||
|
||
**Note:** Models labeled "\*" are not trained by us, but from [DDQ official website](https://github.com/jshilong/DDQ). | ||
|
||
## Citation | ||
|
||
We provide the config files for DDQ: [Dense Distinct Query for End-to-End Object Detection](https://arxiv.org/abs/2303.12776). | ||
|
||
```latex | ||
@InProceedings{Zhang_2023_CVPR, | ||
author = {Zhang, Shilong and Wang, Xinjiang and Wang, Jiaqi and Pang, Jiangmiao and Lyu, Chengqi and Zhang, Wenwei and Luo, Ping and Chen, Kai}, | ||
title = {Dense Distinct Query for End-to-End Object Detection}, | ||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, | ||
month = {June}, | ||
year = {2023}, | ||
pages = {7329-7338} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
_base_ = [ | ||
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py' | ||
] | ||
model = dict( | ||
type='DDQDETR', | ||
num_queries=900, # num_matching_queries | ||
# ratio of num_dense queries to num_queries | ||
dense_topk_ratio=1.5, | ||
with_box_refine=True, | ||
as_two_stage=True, | ||
data_preprocessor=dict( | ||
type='DetDataPreprocessor', | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
bgr_to_rgb=True, | ||
pad_size_divisor=1), | ||
backbone=dict( | ||
type='ResNet', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(1, 2, 3), | ||
frozen_stages=1, | ||
norm_cfg=dict(type='BN', requires_grad=False), | ||
norm_eval=True, | ||
style='pytorch', | ||
init_cfg=dict(type='Pretrained', checkpoint='torchvision:https://resnet50')), | ||
neck=dict( | ||
type='ChannelMapper', | ||
in_channels=[512, 1024, 2048], | ||
kernel_size=1, | ||
out_channels=256, | ||
act_cfg=None, | ||
norm_cfg=dict(type='GN', num_groups=32), | ||
num_outs=4), | ||
# encoder class name: DeformableDetrTransformerEncoder | ||
encoder=dict( | ||
num_layers=6, | ||
layer_cfg=dict( | ||
self_attn_cfg=dict(embed_dims=256, num_levels=4, | ||
dropout=0.0), # 0.1 for DeformDETR | ||
ffn_cfg=dict( | ||
embed_dims=256, | ||
feedforward_channels=2048, # 1024 for DeformDETR | ||
ffn_drop=0.0))), # 0.1 for DeformDETR | ||
# decoder class name: DDQTransformerDecoder | ||
decoder=dict( | ||
# `num_layers` >= 2, because attention masks of the last | ||
# `num_layers` - 1 layers are used for distinct query selection | ||
num_layers=6, | ||
return_intermediate=True, | ||
layer_cfg=dict( | ||
self_attn_cfg=dict(embed_dims=256, num_heads=8, | ||
dropout=0.0), # 0.1 for DeformDETR | ||
cross_attn_cfg=dict(embed_dims=256, num_levels=4, | ||
dropout=0.0), # 0.1 for DeformDETR | ||
ffn_cfg=dict( | ||
embed_dims=256, | ||
feedforward_channels=2048, # 1024 for DeformDETR | ||
ffn_drop=0.0)), # 0.1 for DeformDETR | ||
post_norm_cfg=None), | ||
positional_encoding=dict( | ||
num_feats=128, | ||
normalize=True, | ||
offset=0.0, # -0.5 for DeformDETR | ||
temperature=20), # 10000 for DeformDETR | ||
bbox_head=dict( | ||
type='DDQDETRHead', | ||
num_classes=80, | ||
sync_cls_avg_factor=True, | ||
loss_cls=dict( | ||
type='FocalLoss', | ||
use_sigmoid=True, | ||
gamma=2.0, | ||
alpha=0.25, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='L1Loss', loss_weight=5.0), | ||
loss_iou=dict(type='GIoULoss', loss_weight=2.0)), | ||
dn_cfg=dict( | ||
label_noise_scale=0.5, | ||
box_noise_scale=1.0, | ||
group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)), | ||
dqs_cfg=dict(type='nms', iou_threshold=0.8), | ||
# training and testing settings | ||
train_cfg=dict( | ||
assigner=dict( | ||
type='HungarianAssigner', | ||
match_costs=[ | ||
dict(type='FocalLossCost', weight=2.0), | ||
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'), | ||
dict(type='IoUCost', iou_mode='giou', weight=2.0) | ||
])), | ||
test_cfg=dict(max_per_img=300)) | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=_base_.backend_args), | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict( | ||
type='RandomChoice', | ||
transforms=[ | ||
[ | ||
dict( | ||
type='RandomChoiceResize', | ||
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), | ||
(608, 1333), (640, 1333), (672, 1333), (704, 1333), | ||
(736, 1333), (768, 1333), (800, 1333)], | ||
keep_ratio=True) | ||
], | ||
[ | ||
dict( | ||
type='RandomChoiceResize', | ||
# The radio of all image in train dataset < 7 | ||
# follow the original implement | ||
scales=[(400, 4200), (500, 4200), (600, 4200)], | ||
keep_ratio=True), | ||
dict( | ||
type='RandomCrop', | ||
crop_type='absolute_range', | ||
crop_size=(384, 600), | ||
allow_negative_crop=True), | ||
dict( | ||
type='RandomChoiceResize', | ||
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), | ||
(608, 1333), (640, 1333), (672, 1333), (704, 1333), | ||
(736, 1333), (768, 1333), (800, 1333)], | ||
keep_ratio=True) | ||
] | ||
]), | ||
dict(type='PackDetInputs') | ||
] | ||
|
||
train_dataloader = dict( | ||
dataset=dict( | ||
filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline)) | ||
|
||
# optimizer | ||
optim_wrapper = dict( | ||
type='OptimWrapper', | ||
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.05), | ||
clip_grad=dict(max_norm=0.1, norm_type=2), | ||
paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})) | ||
|
||
# learning policy | ||
max_epochs = 12 | ||
train_cfg = dict( | ||
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1) | ||
|
||
val_cfg = dict(type='ValLoop') | ||
test_cfg = dict(type='TestLoop') | ||
|
||
param_scheduler = [ | ||
dict( | ||
type='LinearLR', | ||
start_factor=0.0001, | ||
by_epoch=False, | ||
begin=0, | ||
end=2000), | ||
dict( | ||
type='MultiStepLR', | ||
begin=0, | ||
end=max_epochs, | ||
by_epoch=True, | ||
milestones=[11], | ||
gamma=0.1) | ||
] | ||
|
||
# NOTE: `auto_scale_lr` is for automatically scaling LR, | ||
# USER SHOULD NOT CHANGE ITS VALUES. | ||
# base_batch_size = (8 GPUs) x (2 samples per GPU) | ||
auto_scale_lr = dict(base_batch_size=16) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
_base_ = [ | ||
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py' | ||
] | ||
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth' # noqa: E501 | ||
model = dict( | ||
type='DDQDETR', | ||
num_queries=900, # num_matching_queries | ||
# ratio of num_dense queries to num_queries | ||
dense_topk_ratio=1.5, | ||
with_box_refine=True, | ||
as_two_stage=True, | ||
data_preprocessor=dict( | ||
type='DetDataPreprocessor', | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
bgr_to_rgb=True, | ||
pad_size_divisor=1), | ||
backbone=dict( | ||
type='SwinTransformer', | ||
pretrain_img_size=384, | ||
embed_dims=192, | ||
depths=[2, 2, 18, 2], | ||
num_heads=[6, 12, 24, 48], | ||
window_size=12, | ||
mlp_ratio=4, | ||
qkv_bias=True, | ||
qk_scale=None, | ||
drop_rate=0., | ||
attn_drop_rate=0., | ||
drop_path_rate=0.2, | ||
patch_norm=True, | ||
out_indices=(1, 2, 3), | ||
with_cp=False, | ||
convert_weights=True, | ||
init_cfg=dict(type='Pretrained', checkpoint=pretrained)), | ||
neck=dict( | ||
type='ChannelMapper', | ||
in_channels=[384, 768, 1536], | ||
kernel_size=1, | ||
out_channels=256, | ||
act_cfg=None, | ||
norm_cfg=dict(type='GN', num_groups=32), | ||
num_outs=4), | ||
# encoder class name: DeformableDetrTransformerEncoder | ||
encoder=dict( | ||
num_layers=6, | ||
layer_cfg=dict( | ||
self_attn_cfg=dict(embed_dims=256, num_levels=4, | ||
dropout=0.0), # 0.1 for DeformDETR | ||
ffn_cfg=dict( | ||
embed_dims=256, | ||
feedforward_channels=2048, # 1024 for DeformDETR | ||
ffn_drop=0.0))), # 0.1 for DeformDETR | ||
# decoder class name: DDQTransformerDecoder | ||
decoder=dict( | ||
num_layers=6, | ||
return_intermediate=True, | ||
layer_cfg=dict( | ||
self_attn_cfg=dict(embed_dims=256, num_heads=8, | ||
dropout=0.0), # 0.1 for DeformDETR | ||
cross_attn_cfg=dict(embed_dims=256, num_levels=4, | ||
dropout=0.0), # 0.1 for DeformDETR | ||
ffn_cfg=dict( | ||
embed_dims=256, | ||
feedforward_channels=2048, # 1024 for DeformDETR | ||
ffn_drop=0.0)), # 0.1 for DeformDETR | ||
post_norm_cfg=None), | ||
positional_encoding=dict( | ||
num_feats=128, | ||
normalize=True, | ||
offset=0.0, # -0.5 for DeformDETR | ||
temperature=20), # 10000 for DeformDETR | ||
bbox_head=dict( | ||
type='DDQDETRHead', | ||
num_classes=80, | ||
sync_cls_avg_factor=True, | ||
loss_cls=dict( | ||
type='FocalLoss', | ||
use_sigmoid=True, | ||
gamma=2.0, | ||
alpha=0.25, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='L1Loss', loss_weight=5.0), | ||
loss_iou=dict(type='GIoULoss', loss_weight=2.0)), | ||
dn_cfg=dict( | ||
label_noise_scale=0.5, | ||
box_noise_scale=1.0, | ||
group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)), | ||
dqs_cfg=dict(type='nms', iou_threshold=0.8), | ||
# training and testing settings | ||
train_cfg=dict( | ||
assigner=dict( | ||
type='HungarianAssigner', | ||
match_costs=[ | ||
dict(type='FocalLossCost', weight=2.0), | ||
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'), | ||
dict(type='IoUCost', iou_mode='giou', weight=2.0) | ||
])), | ||
test_cfg=dict(max_per_img=300)) | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=_base_.backend_args), | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict( | ||
type='RandomChoice', | ||
transforms=[ | ||
[ | ||
dict( | ||
type='RandomChoiceResize', | ||
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), | ||
(608, 1333), (640, 1333), (672, 1333), (704, 1333), | ||
(736, 1333), (768, 1333), (800, 1333)], | ||
keep_ratio=True) | ||
], | ||
[ | ||
dict( | ||
type='RandomChoiceResize', | ||
# The radio of all image in train dataset < 7 | ||
# follow the original implement | ||
scales=[(400, 4200), (500, 4200), (600, 4200)], | ||
keep_ratio=True), | ||
dict( | ||
type='RandomCrop', | ||
crop_type='absolute_range', | ||
crop_size=(384, 600), | ||
allow_negative_crop=True), | ||
dict( | ||
type='RandomChoiceResize', | ||
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), | ||
(608, 1333), (640, 1333), (672, 1333), (704, 1333), | ||
(736, 1333), (768, 1333), (800, 1333)], | ||
keep_ratio=True) | ||
] | ||
]), | ||
dict(type='PackDetInputs') | ||
] | ||
|
||
train_dataloader = dict( | ||
dataset=dict( | ||
filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline)) | ||
|
||
# optimizer | ||
optim_wrapper = dict( | ||
type='OptimWrapper', | ||
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.05), | ||
clip_grad=dict(max_norm=0.1, norm_type=2), | ||
paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.05)})) | ||
|
||
# learning policy | ||
max_epochs = 30 | ||
train_cfg = dict( | ||
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1) | ||
|
||
val_cfg = dict(type='ValLoop') | ||
test_cfg = dict(type='TestLoop') | ||
|
||
param_scheduler = [ | ||
dict( | ||
type='LinearLR', | ||
start_factor=0.0001, | ||
by_epoch=False, | ||
begin=0, | ||
end=2000), | ||
dict( | ||
type='MultiStepLR', | ||
begin=0, | ||
end=max_epochs, | ||
by_epoch=True, | ||
milestones=[20, 26], | ||
gamma=0.1) | ||
] | ||
|
||
# NOTE: `auto_scale_lr` is for automatically scaling LR, | ||
# USER SHOULD NOT CHANGE ITS VALUES. | ||
# base_batch_size = (8 GPUs) x (2 samples per GPU) | ||
auto_scale_lr = dict(base_batch_size=16) |
Oops, something went wrong.