[Feature] Support VPD Depth Estimator (#3321)
Thanks for your contribution; we appreciate it a lot. The following
instructions will help make your pull request healthier and easier to
review. If you do not understand some items, don't worry; just open the
pull request and ask the maintainers for help.

## Motivation


Support the depth estimation algorithm [VPD](https://github.com/wl-zhao/VPD) (Visual Perception with a pre-trained Diffusion model).

## Modification

1. add the VPD backbone
2. add a VPD decoder head for depth estimation
3. add a new segmentor `DepthEstimator`, based on `EncoderDecoder`, for
   depth estimation
4. add an integrated `DepthMetric` that calculates the common depth
   estimation metrics
5. add the SiLog loss for depth estimation (see the sketch after this list)
6. add a config for VPD
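
A minimal sketch of the scale-invariant logarithmic (SiLog) loss used for depth estimation, written in plain PyTorch for illustration; the `lambd` weight and `eps` clamp below are assumed defaults, not necessarily those of the `SiLogLoss` module added in this PR.

```python
import torch


def silog_loss(pred: torch.Tensor, target: torch.Tensor,
               lambd: float = 0.5, eps: float = 1e-6) -> torch.Tensor:
    """Scale-invariant log loss (illustrative sketch).

    ``pred`` and ``target`` are positive depth maps of the same shape;
    pixels with invalid ground truth (``target <= eps``) are masked out.
    """
    valid = target > eps
    diff = torch.log(pred[valid].clamp(min=eps)) - torch.log(target[valid])
    # variance-like term minus a weighted squared-mean term; non-negative
    return torch.sqrt((diff**2).mean() - lambd * diff.mean()**2)
```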

## BC-breaking (Optional)

Does the modification introduce changes that break the backward
compatibility of downstream repos?
If so, please describe how it breaks compatibility and how downstream
projects should modify their code to stay compatible with this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases
here, and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
   issues.
2. The modification is covered by complete unit tests. If not, please
   add more unit tests to ensure the correctness.
3. If the modification has potential influence on downstream projects,
   this PR should be tested with downstream projects, like MMDet or
   MMDet3D.
4. The documentation has been modified accordingly, like docstring or
   example tutorials.
Ben-Louis committed Sep 13, 2023
1 parent ebd5695 commit c46cc85
Showing 33 changed files with 2,216 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .circleci/test.yml
@@ -110,6 +110,8 @@ jobs:
docker exec mmseg mim install mmcv>=2.0.0
docker exec mmseg pip install mmpretrain>=1.0.0rc7
docker exec mmseg mim install mmdet>=3.0.0
docker exec mmseg apt-get update
docker exec mmseg apt-get install -y git
docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt
docker exec mmseg python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
- run:
20 changes: 11 additions & 9 deletions .pre-commit-config.yaml
@@ -42,15 +42,17 @@ repos:
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: local
hooks:
- id: update-model-index
name: update-model-index
description: Collect model information and update model-index.yml
entry: .dev_scripts/update_model_index.py
additional_dependencies: [pyyaml]
language: python
require_serial: true
# temporarily remove update-model-index to avoid conflict raised
# by depth estimator models
# - repo: local
# hooks:
# - id: update-model-index
# name: update-model-index
# description: Collect model information and update model-index.yml
# entry: .dev_scripts/update_model_index.py
# additional_dependencies: [pyyaml]
# language: python
# require_serial: true
- repo: https://github.com/asottile/pyupgrade
rev: v3.0.0
hooks:
66 changes: 66 additions & 0 deletions configs/_base_/datasets/nyu.py
@@ -0,0 +1,66 @@
# dataset settings
dataset_type = 'NYUDataset'
data_root = 'data/nyu'

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
    dict(type='RandomDepthMix', prob=0.25),
    dict(type='RandomFlip', prob=0.5),
    dict(type='RandomCrop', crop_size=(480, 480)),
    dict(
        type='Albu',
        transforms=[
            dict(type='RandomBrightnessContrast'),
            dict(type='RandomGamma'),
            dict(type='HueSaturationValue'),
        ]),
    dict(
        type='PackSegInputs',
        meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'category_id')),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
    dict(
        type='PackSegInputs',
        meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'category_id'))
]

train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/train', depth_map_path='annotations/train'),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(
            img_path='images/test', depth_map_path='annotations/test'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='DepthMetric',
    min_depth_eval=0.001,
    max_depth_eval=10.0,
    crop_type='nyu_crop')
test_evaluator = val_evaluator
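
A quick way to sanity-check the dataset config above is to build the training set from it and inspect one packed sample. This is a sketch that assumes mmseg 1.x with `albumentations` installed and the NYU data prepared under `data/nyu`; the `gt_depth_map` attribute on the packed data sample is an assumed field name.

```python
from mmengine.config import Config
from mmengine.registry import init_default_scope

from mmseg.registry import DATASETS

cfg = Config.fromfile('configs/_base_/datasets/nyu.py')
init_default_scope('mmseg')  # register mmseg datasets and transforms

train_set = DATASETS.build(cfg.train_dataloader.dataset)
sample = train_set[0]
print(sample['inputs'].shape)                     # packed image tensor (C, H, W)
print(sample['data_samples'].gt_depth_map.shape)  # depth annotation (assumed field name)
```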
86 changes: 86 additions & 0 deletions configs/_base_/models/vpd_sd.py
@@ -0,0 +1,86 @@
# model settings
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[127.5, 127.5, 127.5],
    std=[127.5, 127.5, 127.5],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)

# adapted from stable-diffusion/configs/stable-diffusion/v1-inference.yaml
stable_diffusion_cfg = dict(
    base_learning_rate=0.0001,
    target='ldm.models.diffusion.ddpm.LatentDiffusion',
    checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/'
    'vpd/stable_diffusion_v1-5_pretrain_third_party.pth',
    params=dict(
        linear_start=0.00085,
        linear_end=0.012,
        num_timesteps_cond=1,
        log_every_t=200,
        timesteps=1000,
        first_stage_key='jpg',
        cond_stage_key='txt',
        image_size=64,
        channels=4,
        cond_stage_trainable=False,
        conditioning_key='crossattn',
        monitor='val/loss_simple_ema',
        scale_factor=0.18215,
        use_ema=False,
        scheduler_config=dict(
            target='ldm.lr_scheduler.LambdaLinearScheduler',
            params=dict(
                warm_up_steps=[10000],
                cycle_lengths=[10000000000000],
                f_start=[1e-06],
                f_max=[1.0],
                f_min=[1.0])),
        unet_config=dict(
            target='ldm.modules.diffusionmodules.openaimodel.UNetModel',
            params=dict(
                image_size=32,
                in_channels=4,
                out_channels=4,
                model_channels=320,
                attention_resolutions=[4, 2, 1],
                num_res_blocks=2,
                channel_mult=[1, 2, 4, 4],
                num_heads=8,
                use_spatial_transformer=True,
                transformer_depth=1,
                context_dim=768,
                use_checkpoint=True,
                legacy=False)),
        first_stage_config=dict(
            target='ldm.models.autoencoder.AutoencoderKL',
            params=dict(
                embed_dim=4,
                monitor='val/rec_loss',
                ddconfig=dict(
                    double_z=True,
                    z_channels=4,
                    resolution=256,
                    in_channels=3,
                    out_ch=3,
                    ch=128,
                    ch_mult=[1, 2, 4, 4],
                    num_res_blocks=2,
                    attn_resolutions=[],
                    dropout=0.0),
                lossconfig=dict(target='torch.nn.Identity'))),
        cond_stage_config=dict(
            target='ldm.modules.encoders.modules.AbstractEncoder')))

model = dict(
    type='DepthEstimator',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='VPD',
        diffusion_cfg=stable_diffusion_cfg,
    ),
)

# some of the parameters in stable-diffusion model will not be updated
# during training
find_unused_parameters = True
28 changes: 28 additions & 0 deletions configs/_base_/schedules/schedule_25k.py
@@ -0,0 +1,28 @@
# optimizer
optimizer = dict(type='AdamW', lr=0.001, weight_decay=0.1)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='LinearLR', start_factor=3e-2, begin=0, end=12000,
        by_epoch=False),
    dict(
        type='PolyLRRatio',
        eta_min_ratio=3e-2,
        power=0.9,
        begin=12000,
        end=24000,
        by_epoch=False),
    dict(type='ConstantLR', by_epoch=False, factor=1, begin=24000, end=25000)
]
# training schedule for 25k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=25000, val_interval=1000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))
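
For intuition, the three schedulers above combine into roughly the following per-iteration multiplier on the base learning rate (0.001): a linear warm-up over the first 12k iterations, a polynomial decay until 24k, and a constant tail for the last 1k. The function below is an illustrative approximation of that behaviour, not the mmengine implementation.

```python
def lr_factor(iteration: int,
              warmup_end: int = 12000,
              decay_end: int = 24000,
              start_factor: float = 3e-2,
              eta_min_ratio: float = 3e-2,
              power: float = 0.9) -> float:
    """Approximate lr multiplier at a given iteration for the 25k schedule."""
    if iteration < warmup_end:
        # LinearLR: ramp from start_factor up to 1.0
        return start_factor + (1.0 - start_factor) * iteration / warmup_end
    if iteration < decay_end:
        # PolyLRRatio: polynomial decay from 1.0 down towards eta_min_ratio
        progress = (iteration - warmup_end) / (decay_end - warmup_end)
        return (1.0 - eta_min_ratio) * (1.0 - progress)**power + eta_min_ratio
    # ConstantLR with factor=1: hold the final value for the last 1k iterations
    return eta_min_ratio
```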
49 changes: 49 additions & 0 deletions configs/vpd/README.md
@@ -0,0 +1,49 @@
# VPD

> [Unleashing Text-to-Image Diffusion Models for Visual Perception](https://arxiv.org/abs/2303.02153)

## Introduction

<!-- [BACKBONE] -->

<a href="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/wl-zhao/VPD">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

Diffusion models (DMs) have become the new trend of generative models and have demonstrated a powerful ability of conditional synthesis. Among those, text-to-image diffusion models pre-trained on large-scale image-text pairs are highly controllable by customizable prompts. Unlike the unconditional generative models that focus on low-level attributes and details, text-to-image diffusion models contain more high-level knowledge thanks to the vision-language pre-training. In this paper, we propose VPD (Visual Perception with a pre-trained Diffusion model), a new framework that exploits the semantic information of a pre-trained text-to-image diffusion model in visual perception tasks. Instead of using the pre-trained denoising autoencoder in a diffusion-based pipeline, we simply use it as a backbone and aim to study how to take full advantage of the learned knowledge. Specifically, we prompt the denoising decoder with proper textual inputs and refine the text features with an adapter, leading to a better alignment to the pre-trained stage and making the visual contents interact with the text prompts. We also propose to utilize the cross-attention maps between the visual features and the text features to provide explicit guidance. Compared with other pre-training methods, we show that vision-language pre-trained diffusion models can be faster adapted to downstream visual perception tasks using the proposed VPD. Extensive experiments on semantic segmentation, referring image segmentation and depth estimation demonstrates the effectiveness of our method. Notably, VPD attains 0.254 RMSE on NYUv2 depth estimation and 73.3% oIoU on RefCOCO-val referring image segmentation, establishing new records on these two benchmarks.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/open-mmlab/mmsegmentation/assets/26127467/88f5752d-7fe2-4cb0-a284-8ee0680e29cd" width="80%"/>
</div>

## Usage

To run training or inference with the VPD model, please install the required packages via

```sh
pip install -r requirements/albu.txt
pip install -r requirements/optional.txt
```
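
Once the extra requirements are installed, inference can be run through the usual mmseg Python APIs. A minimal sketch, using the config and checkpoint listed in the table below; reading the prediction via `pred_depth_map` is an assumption about how `DepthEstimator` packs its output.

```python
from mmseg.apis import inference_model, init_model

config = 'configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py'
checkpoint = ('https://download.openmmlab.com/mmsegmentation/v0.5/vpd/'
              'vpd_sd_4xb8-25k_nyu-480x480_20230908-66144bc4.pth')

model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, 'demo/demo.png')  # any RGB image path works here

# depth in metres, clipped by the model to max_depth=10 (attribute name assumed)
depth = result.pred_depth_map.data.squeeze().cpu().numpy()
print(depth.shape, depth.min(), depth.max())
```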

## Results and models

### NYU

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | RMSE | d1 | d2 | d3 | REL | log_10 | config | download |
| ------ | --------------------- | --------- | ------- | -------- | -------------- | ------ | ----- | ----- | ----- | ----- | ----- | ------ | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| VPD | Stable-Diffusion-v1-5 | 480x480 | 25000 | - | - | A100 | 0.253 | 0.964 | 0.995 | 0.999 | 0.069 | 0.030 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908-66144bc4.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908.json) |
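
The metric columns follow the standard depth-estimation definitions (threshold accuracies d1/d2/d3, absolute relative error, RMSE, and mean log10 error). The sketch below shows how they are typically computed from prediction and ground-truth arrays; it is not the exact `DepthMetric` implementation added in this PR.

```python
import numpy as np


def depth_metrics(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6) -> dict:
    """Standard depth metrics computed over the valid (gt > 0) pixels."""
    valid = gt > eps
    pred, gt = np.clip(pred[valid], eps, None), gt[valid]
    thresh = np.maximum(pred / gt, gt / pred)
    return dict(
        rmse=float(np.sqrt(np.mean((pred - gt)**2))),
        rel=float(np.mean(np.abs(pred - gt) / gt)),
        log_10=float(np.mean(np.abs(np.log10(pred) - np.log10(gt)))),
        d1=float(np.mean(thresh < 1.25)),
        d2=float(np.mean(thresh < 1.25**2)),
        d3=float(np.mean(thresh < 1.25**3)),
    )
```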

## Citation

```bibtex
@article{zhao2023unleashing,
title={Unleashing Text-to-Image Diffusion Models for Visual Perception},
author={Zhao, Wenliang and Rao, Yongming and Liu, Zuyan and Liu, Benlin and Zhou, Jie and Lu, Jiwen},
journal={ICCV},
year={2023}
}
```
34 changes: 34 additions & 0 deletions configs/vpd/metafile.yaml
@@ -0,0 +1,34 @@
Collections:
- Name: VPD
  License: Apache License 2.0
  Metadata:
    Training Data:
    - NYU
  Paper:
    Title: Unleashing Text-to-Image Diffusion Models for Visual Perception
    URL: https://arxiv.org/abs/2303.02153
  README: configs/vpd/README.md
  Frameworks:
  - PyTorch
Models:
- Name: vpd_sd_4xb8-25k_nyu-480x480
  In Collection: VPD
  Results:
    Task: Depth Estimation
    Dataset: NYU
    Metrics:
      RMSE: 0.253
  Config: configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py
  Metadata:
    Training Data: NYU
    Batch Size: 32
    Architecture:
    - Stable-Diffusion
    Training Resources: 8x A100 GPUS
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908-66144bc4.pth
  Training log: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908.json
  Paper:
    Title: 'High-Resolution Image Synthesis with Latent Diffusion Models'
    URL: https://arxiv.org/abs/2112.10752
  Code: https://github.com/open-mmlab/mmsegmentation/tree/main/mmseg/models/backbones/vpd.py#L333
  Framework: PyTorch
37 changes: 37 additions & 0 deletions configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py
@@ -0,0 +1,37 @@
_base_ = [
    '../_base_/models/vpd_sd.py', '../_base_/datasets/nyu.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_25k.py'
]

crop_size = (480, 480)

model = dict(
    type='DepthEstimator',
    data_preprocessor=dict(size=crop_size),
    backbone=dict(
        class_embed_path='https://download.openmmlab.com/mmsegmentation/'
        'v0.5/vpd/nyu_class_embeddings.pth',
        class_embed_select=True,
        pad_shape=512,
        unet_cfg=dict(use_attn=False),
    ),
    decode_head=dict(
        type='VPDDepthHead',
        in_channels=[320, 640, 1280, 1280],
        max_depth=10,
        fmap_border=(1, 1),
    ),
    test_cfg=dict(mode='slide_flip', crop_size=crop_size, stride=(160, 160)))

default_hooks = dict(checkpoint=dict(save_best='rmse', rule='less'))

# custom optimizer
optim_wrapper = dict(
    type='ForceDefaultOptimWrapperConstructor',
    paramwise_cfg=dict(
        bias_decay_mult=0,
        force_default_settings=True,
        custom_keys={
            'backbone.encoder_vq': dict(lr_mult=0),
            'backbone.unet': dict(lr_mult=0.01),
        }))
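
Training with this config goes through the standard mmengine entry point. A sketch for a single-process run, assuming the NYU data and the extra VPD requirements from the README are in place; the `work_dir` path is arbitrary, and a distributed 4-GPU run (as in the config name `4xb8`) would use `launcher='pytorch'` via the usual mmseg launch tools instead.

```python
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py')
cfg.launcher = 'none'  # single-process; use 'pytorch' for distributed training
cfg.work_dir = './work_dirs/vpd_sd_4xb8-25k_nyu-480x480'

runner = Runner.from_cfg(cfg)
runner.train()
```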
6 changes: 4 additions & 2 deletions mmseg/datasets/transforms/__init__.py
@@ -10,7 +10,8 @@
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
RandomDepthMix, RandomFlip, RandomMosaic,
RandomRotate, RandomRotFlip, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)

@@ -24,5 +25,6 @@
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation'
'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
'RandomFlip'
]
3 changes: 3 additions & 0 deletions mmseg/datasets/transforms/loading.py
@@ -647,6 +647,8 @@ class LoadDepthAnnotation(BaseTransform):
- gt_depth_map (np.ndarray): Depth map with shape (Y, X) by
default, and data type is float32 if set to_float32 = True.
- depth_rescale_factor (float): The rescale factor of depth map, which
can be used to recover the original value of depth map.
Args:
decode_backend (str): The data decoding backend type. Options are
@@ -691,6 +693,7 @@ def transform(self, results: Dict) -> Dict:
gt_depth_map *= self.depth_rescale_factor
results['gt_depth_map'] = gt_depth_map
results['seg_fields'].append('gt_depth_map')
results['depth_rescale_factor'] = self.depth_rescale_factor
return results

def __repr__(self):