[Feature] Support Objects365 Dataset #7525

Merged
merged 18 commits into from
Jan 10, 2023
Changes from 11 commits
4 changes: 3 additions & 1 deletion .dev_scripts/gather_models.py
@@ -141,7 +141,9 @@ def get_dataset_name(config):
         VOCDataset='Pascal VOC',
         WIDERFaceDataset='WIDER Face',
         OpenImagesDataset='OpenImagesDataset',
-        OpenImagesChallengeDataset='OpenImagesChallengeDataset')
+        OpenImagesChallengeDataset='OpenImagesChallengeDataset',
+        Objects365V1Dataset='Objects365 v1',
+        Objects365V2Dataset='Objects365 v2')
     cfg = mmcv.Config.fromfile('./configs/' + config)
     return name_map[cfg.dataset_type]
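
The lookup at the end of `get_dataset_name` is a plain dictionary access, so a config whose `dataset_type` is missing from the map raises `KeyError`; that is why the two Objects365 entries are needed. A minimal standalone sketch of that behaviour follows. The names `DATASET_NAMES` and `dataset_display_name` are hypothetical, and only the entries visible in this hunk are listed.

```python
# Hypothetical standalone copy of the name map; the real map in
# gather_models.py contains more entries than this hunk shows.
DATASET_NAMES = dict(
    VOCDataset='Pascal VOC',
    WIDERFaceDataset='WIDER Face',
    OpenImagesDataset='OpenImagesDataset',
    OpenImagesChallengeDataset='OpenImagesChallengeDataset',
    Objects365V1Dataset='Objects365 v1',
    Objects365V2Dataset='Objects365 v2')


def dataset_display_name(dataset_type):
    """Return the display name; raises KeyError for unregistered types."""
    return DATASET_NAMES[dataset_type]
```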

49 changes: 49 additions & 0 deletions configs/_base_/datasets/objects365v1_detection.py
@@ -0,0 +1,49 @@
# dataset settings
dataset_type = 'Objects365V1Dataset'
data_root = 'data/Objects365/Obj365_v1/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/objects365_train.json',
        img_prefix=data_root + 'train/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/objects365_val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/objects365_val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
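
The `Resize` step in the pipelines above combines `img_scale=(1333, 800)` with `keep_ratio=True`. A rough sketch of what that implies for the output size is shown below, assuming the image is scaled by the largest factor that keeps the long edge within 1333 and the short edge within 800. The function name `rescale_keep_ratio` is hypothetical and this is not the library's own code.

```python
def rescale_keep_ratio(width, height, img_scale=(1333, 800)):
    """Scale (width, height) so the long edge fits max(img_scale) and the
    short edge fits min(img_scale), preserving the aspect ratio."""
    max_long, max_short = max(img_scale), min(img_scale)
    factor = min(max_long / max(width, height),
                 max_short / min(width, height))
    # Round to the nearest integer pixel count.
    return int(width * factor + 0.5), int(height * factor + 0.5)
```

For a 640x480 input the short edge is the binding constraint, while for a very wide 2000x500 input the long edge is.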
49 changes: 49 additions & 0 deletions configs/_base_/datasets/objects365v2_detection.py
@@ -0,0 +1,49 @@
# dataset settings
dataset_type = 'Objects365V2Dataset'
data_root = 'data/Objects365/Obj365_v2/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/zhiyuan_objv2_train.json',
        img_prefix=data_root + 'train/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/zhiyuan_objv2_val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/zhiyuan_objv2_val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
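
The `Normalize` and `Pad` steps in the config above can be illustrated with plain arithmetic. The sketch below assumes `Normalize` computes `(value - mean) / std` per channel and that `Pad` with `size_divisor=32` rounds each side up to the next multiple of 32. The helper names are hypothetical, and the mean and std values are copied from `img_norm_cfg`.

```python
import math

# Values copied from img_norm_cfg in the config above.
MEAN = (123.675, 116.28, 103.53)
STD = (58.395, 57.12, 57.375)


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


def padded_shape(height, width, size_divisor=32):
    """Round each side up to the next multiple of size_divisor."""
    return (math.ceil(height / size_divisor) * size_divisor,
            math.ceil(width / size_divisor) * size_divisor)
```

Padding to a multiple of 32 keeps the feature-map sizes integral across the downsampling stages of an FPN-style detector.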
91 changes: 91 additions & 0 deletions configs/objects365/README.md
@@ -0,0 +1,91 @@
# Objects365 Dataset

> [Objects365 Dataset](https://openaccess.thecvf.com/content_ICCV_2019/papers/Shao_Objects365_A_Large-Scale_High-Quality_Dataset_for_Object_Detection_ICCV_2019_paper.pdf)

<!-- [DATASET] -->

## Abstract

<!-- [ABSTRACT] -->

#### Objects365 Dataset V1

[Objects365 Dataset V1](https://www.objects365.org/overview.html) is a dataset designed to spur object detection research, with a focus on diverse objects in the wild.
It has 365 object categories over 600K training images. More than 10 million high-quality bounding boxes were manually labeled through a carefully designed three-step annotation pipeline. At its release it was the largest fully annotated object detection dataset, and it establishes a more challenging benchmark for the community. Objects365 can also serve as a better feature-learning dataset for localization-sensitive tasks such as object detection and semantic segmentation.

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/48282753/208368046-b7573022-06c9-4a99-af17-a6ac7407e3d8.png" height="400"/>
</div>

#### Objects365 Dataset V2

[Objects365 Dataset V2](https://www.objects365.org/overview.html) is based on the V1 release of the Objects365 dataset.
Objects365 V2 annotates 365 object classes on more than 1,800K images, with more than 29 million bounding boxes in the training set, surpassing the PASCAL VOC, ImageNet, and COCO datasets in scale.
It covers 11 supercategories (people, clothing, living room, bathroom, kitchen, office/medical, electrical appliances, transportation, food, animals, and sports/musical instruments), each with dozens of subcategories.

## Citation

```
@inproceedings{shao2019objects365,
title={Objects365: A large-scale, high-quality dataset for object detection},
author={Shao, Shuai and Li, Zeming and Zhang, Tianyuan and Peng, Chao and Yu, Gang and Zhang, Xiangyu and Li, Jing and Sun, Jian},
booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
pages={8430--8439},
year={2019}
}
```

## Prepare Dataset

1. Download and extract the Objects365 dataset.

2. Organize the directory as follows:

```none
mmdetection
├── mmdet
├── tools
├── configs
├── data
│ ├── Objects365
│ │ ├── Obj365_v1
│ │ │ ├── annotations
│ │ │ │ ├── objects365_train.json
│ │ │ │ ├── objects365_val.json
│ │ │ ├── train # training images
│ │ │ ├── val # validation images
│ │ ├── Obj365_v2
│ │ │ ├── annotations
│ │ │ │ ├── zhiyuan_objv2_train.json
│ │ │ │ ├── zhiyuan_objv2_val.json
│ │ │ ├── train # training images
│ │ │ │ ├── patch0
│ │ │ │ ├── patch1
│ │ │ │ ├── ...
│ │ │ ├── val # validation images
│ │ │ │ ├── patch0
│ │ │ │ ├── patch1
│ │ │ │ ├── ...
```
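
Before launching a training run, it can be worth checking that the layout above is in place. The sketch below is one way to do that with the standard library; the `EXPECTED` table and the `missing_paths` helper are hypothetical names, and the relative paths are taken from the directory tree above.

```python
from pathlib import Path

# Relative paths taken from the directory tree above.
EXPECTED = {
    'Obj365_v1': [
        'annotations/objects365_train.json',
        'annotations/objects365_val.json',
        'train',
        'val',
    ],
    'Obj365_v2': [
        'annotations/zhiyuan_objv2_train.json',
        'annotations/zhiyuan_objv2_val.json',
        'train',
        'val',
    ],
}


def missing_paths(root, version):
    """List expected files or folders absent under root/version."""
    base = Path(root) / version
    return [rel for rel in EXPECTED[version] if not (base / rel).exists()]
```

Calling `missing_paths('data/Objects365', 'Obj365_v2')` from the repository root returns an empty list when everything is where the configs expect it.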

## Results and Models

### Objects365 V1

| Architecture | Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
| :----------: | :------: | :-----: | :-----: | :------: | :----: | :-----------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Faster R-CNN | R-50 | pytorch | 1x | - | 19.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1/faster_rcnn_r50_fpn_16x4_1x_obj365v1_20221219_181226-9ff10f95.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1/faster_rcnn_r50_fpn_16x4_1x_obj365v1_20221219_181226.log.json) |
| Faster R-CNN | R-50 | pytorch | 1350K | - | 22.3 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1_20220510_142457-337d8965.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1_20220510_142457.log.json) |
| RetinaNet | R-50 | pytorch | 1x | - | 14.8 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/objects365/retinanet_r50_fpn_1x_obj365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v1/retinanet_r50_fpn_1x_obj365v1_20221219_181859-ba3e3dd5.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v1/retinanet_r50_fpn_1x_obj365v1_20221219_181859.log.json) |
| RetinaNet | R-50 | pytorch | 1350K | - | 18.0 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/objects365/retinanet_r50_fpn_syncbn_1350k_obj365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_syncbn_1350k_obj365v1/retinanet_r50_fpn_syncbn_1350k_obj365v1_20220513_111237-7517c576.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_syncbn_1350k_obj365v1/retinanet_r50_fpn_syncbn_1350k_obj365v1_20220513_111237.log.json) |

### Objects365 V2

| Architecture | Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
| :----------: | :------: | :-----: | :-----: | :------: | :----: | :------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Faster R-CNN | R-50 | pytorch | 1x | - | 19.8 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2/faster_rcnn_r50_fpn_16x4_1x_obj365v2_20221220_175040-5910b015.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2/faster_rcnn_r50_fpn_16x4_1x_obj365v2_20221220_175040.log.json) |
| RetinaNet | R-50 | pytorch | 1x | - | 16.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/objects365/retinanet_r50_fpn_1x_obj365v2.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v2/retinanet_r50_fpn_1x_obj365v2_20221223_122105-d9b191f1.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v2/retinanet_r50_fpn_1x_obj365v2_20221223_122105.log.json) |
25 changes: 25 additions & 0 deletions configs/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1.py
@@ -0,0 +1,25 @@
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',
    '../_base_/datasets/objects365v1_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(roi_head=dict(bbox_head=dict(num_classes=365)))

data = dict(samples_per_gpu=4)

# Using 16 GPUs while training (16 GPUs x 4 samples per GPU = batch size 64)
optimizer = dict(type='SGD', lr=0.08, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1.0 / 1000,
    step=[8, 11])

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (16 GPUs) x (4 samples per GPU)
auto_scale_lr = dict(base_batch_size=64)
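
The `base_batch_size` value records the total batch size that `lr=0.08` was tuned for. The sketch below shows how a learning rate could be adjusted for a different hardware setup, assuming the linear scaling rule (learning rate proportional to total batch size) is what automatic scaling applies when it is enabled. The function name `scale_lr` is hypothetical.

```python
def scale_lr(base_lr, base_batch_size, num_gpus, samples_per_gpu):
    """Linear scaling rule: LR grows in proportion to the total batch size."""
    total_batch_size = num_gpus * samples_per_gpu
    return base_lr * total_batch_size / base_batch_size
```

With the values in this config, 16 GPUs x 4 samples per GPU reproduces the base rate of 0.08, and halving the number of GPUs halves the rate.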
25 changes: 25 additions & 0 deletions configs/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2.py
@@ -0,0 +1,25 @@
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',
    '../_base_/datasets/objects365v2_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(roi_head=dict(bbox_head=dict(num_classes=365)))

data = dict(samples_per_gpu=4)

# Using 16 GPUs while training (16 GPUs x 4 samples per GPU = batch size 64)
optimizer = dict(type='SGD', lr=0.08, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1.0 / 1000,
    step=[8, 11])

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (16 GPUs) x (4 samples per GPU)
auto_scale_lr = dict(base_batch_size=64)
31 changes: 31 additions & 0 deletions configs/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1.py
@@ -0,0 +1,31 @@
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',
    '../_base_/datasets/objects365v1_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    backbone=dict(norm_cfg=norm_cfg),
    roi_head=dict(bbox_head=dict(num_classes=365)))

# Using 8 GPUs while training
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

runner = dict(
    _delete_=True, type='IterBasedRunner', max_iters=1350000)  # 36 epochs
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1.0 / 1000,
    step=[900000, 1200000])

checkpoint_config = dict(interval=150000)
evaluation = dict(interval=150000, metric='bbox')

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
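
The `# 36 epochs` comment and the step schedule in this config can be checked with a little arithmetic. The sketch below makes three assumptions: roughly 600K training images (the figure quoted in the README), a decay factor of 0.1 at each step, and a linear warmup that ramps from `warmup_ratio` times the base rate up to the base rate. The function names are hypothetical.

```python
def iters_to_epochs(max_iters, batch_size, num_images):
    """Convert an iteration budget into the equivalent number of epochs."""
    return max_iters * batch_size / num_images


def step_lr(iteration, base_lr=0.02, steps=(900000, 1200000), gamma=0.1,
            warmup_iters=1000, warmup_ratio=1.0 / 1000):
    """Step decay with linear warmup; gamma=0.1 is an assumed default."""
    decayed = base_lr * gamma ** sum(iteration >= s for s in steps)
    if iteration < warmup_iters:
        # Linear ramp from warmup_ratio * lr at iteration 0 up to lr.
        k = (1 - iteration / warmup_iters) * (1 - warmup_ratio)
        return decayed * (1 - k)
    return decayed
```

With a total batch size of 16 (8 GPUs x 2 samples per GPU), 1,350,000 iterations over about 600K images works out to 36 epochs, which matches the comment on the `runner` line.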
101 changes: 101 additions & 0 deletions configs/objects365/metafile.yml
@@ -0,0 +1,101 @@
- Name: retinanet_r50_fpn_1x_obj365v1
  In Collection: RetinaNet
  Config: configs/objects365/retinanet_r50_fpn_1x_obj365v1.py
  Metadata:
    Training Memory (GB): 7.4
    Epochs: 12
    Training Data: Objects365 v1
    Training Techniques:
      - SGD with Momentum
      - Weight Decay
  Results:
    - Task: Object Detection
      Dataset: Objects365 v1
      Metrics:
        box AP: 14.8
  Weights: https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v1/retinanet_r50_fpn_1x_obj365v1_20221219_181859-ba3e3dd5.pth

- Name: retinanet_r50_fpn_syncbn_1350k_obj365v1
  In Collection: RetinaNet
  Config: configs/objects365/retinanet_r50_fpn_syncbn_1350k_obj365v1.py
  Metadata:
    Training Memory (GB): 7.6
    Iterations: 1350000
    Training Data: Objects365 v1
    Training Techniques:
      - SGD with Momentum
      - Weight Decay
  Results:
    - Task: Object Detection
      Dataset: Objects365 v1
      Metrics:
        box AP: 18.0
  Weights: https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_syncbn_1350k_obj365v1/retinanet_r50_fpn_syncbn_1350k_obj365v1_20220513_111237-7517c576.pth

- Name: retinanet_r50_fpn_1x_obj365v2
  In Collection: RetinaNet
  Config: configs/objects365/retinanet_r50_fpn_1x_obj365v2.py
  Metadata:
    Training Memory (GB): 7.2
    Epochs: 12
    Training Data: Objects365 v2
    Training Techniques:
      - SGD with Momentum
      - Weight Decay
  Results:
    - Task: Object Detection
      Dataset: Objects365 v2
      Metrics:
        box AP: 16.7
  Weights: https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v2/retinanet_r50_fpn_1x_obj365v2_20221223_122105-d9b191f1.pth

- Name: faster_rcnn_r50_fpn_16x4_1x_obj365v1
  In Collection: Faster R-CNN
  Config: configs/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1.py
  Metadata:
    Training Memory (GB): 11.4
    Epochs: 12
    Training Data: Objects365 v1
    Training Techniques:
      - SGD with Momentum
      - Weight Decay
  Results:
    - Task: Object Detection
      Dataset: Objects365 v1
      Metrics:
        box AP: 19.6
  Weights: https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1/faster_rcnn_r50_fpn_16x4_1x_obj365v1_20221219_181226-9ff10f95.pth

- Name: faster_rcnn_r50_fpn_syncbn_1350k_obj365v1
  In Collection: Faster R-CNN
  Config: configs/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1.py
  Metadata:
    Training Memory (GB): 8.6
    Iterations: 1350000
    Training Data: Objects365 v1
    Training Techniques:
      - SGD with Momentum
      - Weight Decay
  Results:
    - Task: Object Detection
      Dataset: Objects365 v1
      Metrics:
        box AP: 22.3
  Weights: https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1_20220510_142457-337d8965.pth

- Name: faster_rcnn_r50_fpn_16x4_1x_obj365v2
  In Collection: Faster R-CNN
  Config: configs/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2.py
  Metadata:
    Training Memory (GB): 10.8
    Epochs: 12
    Training Data: Objects365 v2
    Training Techniques:
      - SGD with Momentum
      - Weight Decay
  Results:
    - Task: Object Detection
      Dataset: Objects365 v2
      Metrics:
        box AP: 19.8
  Weights: https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2/faster_rcnn_r50_fpn_16x4_1x_obj365v2_20221220_175040-5910b015.pth
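
A metafile like this is easy to get subtly wrong when entries are copied from one another, so a small consistency check can help during review. The sketch below assumes the entries have already been loaded into Python dictionaries with the same keys as the metafile, and that each model here is evaluated on the dataset it was trained on. The function name `mismatched_entries` and the sample entries are hypothetical.

```python
def mismatched_entries(entries):
    """Return names whose training data differs from a result dataset."""
    bad = []
    for entry in entries:
        trained_on = entry['Metadata']['Training Data']
        if any(r['Dataset'] != trained_on for r in entry['Results']):
            bad.append(entry['Name'])
    return bad


# Hypothetical entries that mimic the structure of the metafile.
sample = [
    {'Name': 'model_a',
     'Metadata': {'Training Data': 'Objects365 v1'},
     'Results': [{'Dataset': 'Objects365 v1'}]},
    {'Name': 'model_b',
     'Metadata': {'Training Data': 'Objects365 v1'},
     'Results': [{'Dataset': 'Objects365 v2'}]},
]
flagged = mismatched_entries(sample)
```

Here only `model_b` is flagged, because its training data and evaluation dataset disagree.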