The official implementation of "Relay Diffusion: Unifying diffusion process across resolutions for image synthesis" [ICLR 2024 Spotlight]

Relay Diffusion: Unifying diffusion process across resolutions for image synthesis
Official PyTorch Implementation 🌐[WiseModel] 🌐[Model Scope]

🎉News! The paper of RelayDiffusion has been accepted by ICLR 2024 (Spotlight)!

We propose Relay Diffusion Model (RDM) as a better framework for diffusion generation. RDM transfers a low-resolution image or noise into an equivalent high-resolution one via blurring diffusion and block noise. Therefore, the diffusion process can continue seamlessly in any new resolution or model without restarting from pure noise or low-resolution conditioning.

RDM achieved state-of-the-art FID on CelebA-HQ and sFID on ImageNet-256 (FID=1.87)!

For a formal introduction, read our paper: Relay Diffusion: Unifying diffusion process across resolutions for image synthesis.

Setup

Environment

Download the repo and set up the environment with:

git clone https://github.com/THUDM/RelayDiffusion.git
cd RelayDiffusion
conda env create -f environment.yml
conda activate rdm

We enable xformers.ops.memory_efficient_attention, which reduces training cost by about 15%. If you do not need it, you can remove xformers from environment.yml.

Linux servers with Nvidia A100s are recommended. However, by setting a smaller --batch-gpu (the batch size on a single GPU), you can still run the inference and training scripts on less powerful GPUs.
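
To see what a smaller --batch-gpu costs, the sketch below computes how many gradient-accumulation rounds one optimizer step would need. It assumes this repository keeps the behavior of the EDM codebase it is based on, where the total batch is split evenly across GPUs and each GPU processes its share in chunks of --batch-gpu; the function name is hypothetical.

```python
def accumulation_rounds(batch: int, num_gpus: int, batch_gpu: int) -> int:
    """Gradient-accumulation rounds per optimizer step (assumed EDM-style split)."""
    if batch % num_gpus != 0:
        raise ValueError("--batch must be divisible by the number of GPUs")
    per_gpu_total = batch // num_gpus  # samples each GPU handles per step
    if per_gpu_total % batch_gpu != 0:
        raise ValueError("per-GPU share must be divisible by --batch-gpu")
    return per_gpu_total // batch_gpu


if __name__ == "__main__":
    # Values taken from the ImageNet training commands in this README.
    print(accumulation_rounds(batch=4096, num_gpus=8, batch_gpu=32))
    print(accumulation_rounds(batch=4096, num_gpus=8, batch_gpu=8))
```

Under this assumption, halving --batch-gpu doubles the number of rounds per step while leaving the effective batch size unchanged.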

Dataset

We preprocess and implement datasets with the same format as EDM. For CelebA-HQ, follow Progressive Growing of GANs for Improved Quality, Stability, and Variation to construct the high-quality subset of CelebA. For ImageNet, download data from the official site.

To convert the original data into a training-ready archive at $64\times 64$ or $256\times 256$ resolution, run:

python dataset_tool.py \
    --source=/path/to/original/data \
    --dest=/path/to/output/data.zip \
    --transform=center-crop \
    --resolution=64x64 # or --resolution=256x256
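
As a rough picture of what a center crop does before resizing, the sketch below computes the largest centered square inside an image. It is an illustration only, not the code in dataset_tool.py, and the function name is hypothetical.

```python
def center_crop_box(width: int, height: int) -> tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square."""
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    crop = min(width, height)      # side length of the square
    left = (width - crop) // 2     # horizontal offset of the square
    top = (height - crop) // 2     # vertical offset of the square
    return left, top, left + crop, top + crop


if __name__ == "__main__":
    # A 500x375 landscape image keeps its full height and loses the side margins.
    print(center_crop_box(500, 375))
```

The cropped square would then be resized to the target resolution, such as 64x64 or 256x256.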

Inference & Evaluation

Sample Generation

To generate samples from RDM models, run command:

torchrun --standalone --nproc_per_node=1 generate.py --sampler_stages=both --outdir=/path/to/output/dir/ \
    --network_first=/path/to/1st/ckpt --network_second=/path/to/2nd/ckpt

To generate $N$ images, set --seed=[K]-[K+N-1] for a randomly picked $K$. To parallelize generation across multiple GPUs, set --nproc_per_node to the number of GPUs.
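
The seed range can be built with a small helper like the one below. This is a convenience sketch; the function name is hypothetical and is not part of the repository.

```python
def seed_argument(first_seed: int, num_images: int) -> str:
    """Build the --seed flag covering num_images consecutive seeds."""
    if num_images < 1:
        raise ValueError("num_images must be at least 1")
    last_seed = first_seed + num_images - 1  # the range is inclusive
    return f"--seed={first_seed}-{last_seed}"


if __name__ == "__main__":
    # 50,000 images starting from seed 1000.
    print(seed_argument(1000, 50000))
```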

To generate final samples from existing first-stage results (using only the second-stage model), set --sampler_stages=second and pass the directory of first-stage results with --indir.

The arguments that configure the first stage are:

  • num_steps_first: number of sampling steps.
  • sigma_min_first & sigma_max_first: lowest & highest noise level.
  • rho_first: time step exponent.
  • cfg_scale_first: scale of classifier-free guidance.
  • S_churn: stochasticity strength.
  • S_min & S_max: min & max noise level.
  • S_noise: noise inflation.
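
To show how the step count, noise bounds and exponent interact, the sketch below builds the noise-level schedule from the EDM paper, $\sigma_i = \left(\sigma_{\max}^{1/\rho} + \frac{i}{N-1}\left(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\right)\right)^{\rho}$. It assumes the first stage keeps EDM's schedule unchanged; the function name is hypothetical, and the example values are the EDM paper's defaults, not necessarily the settings used for the checkpoints here.

```python
def edm_sigma_steps(num_steps: int, sigma_min: float, sigma_max: float,
                    rho: float) -> list[float]:
    """Noise levels from sigma_max down to sigma_min (EDM schedule)."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    inv_max = sigma_max ** (1.0 / rho)
    inv_min = sigma_min ** (1.0 / rho)
    return [
        (inv_max + i / (num_steps - 1) * (inv_min - inv_max)) ** rho
        for i in range(num_steps)
    ]


if __name__ == "__main__":
    # EDM paper defaults, used here purely as an example.
    for sigma in edm_sigma_steps(num_steps=5, sigma_min=0.002,
                                 sigma_max=80.0, rho=7.0):
        print(f"{sigma:.4f}")
```

A larger rho concentrates more of the steps at low noise levels.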

Arguments for configurations of the second stage are:

  • num_steps_second: number of sampling steps.
  • sigma_min_second & sigma_max_second: lowest & highest noise level.
  • blur_sigma_max_second: maximum sigma of blurring schedule.
  • rho_second: time step exponent.
  • cfg_scale_second: scale of classifier-free guidance.
  • up_scale_second: scale of upsampling.
  • truncation_sigma_second & truncation_t_second: truncation point of noise & time schedule.
  • s_block_second: strength of block noise addition.
  • s_noise_second: strength of stochasticity.
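
As a rough illustration of the block noise mentioned for the second stage, the sketch below draws Gaussian noise on a low-resolution grid and repeats each value over a square block, so every block of the high-resolution grid shares one noise value. This assumes block noise means noise that is constant within each block, with the block size matching the upsampling scale. How the repository scales this noise or mixes it with independent noise (s_block_second) is not shown, and the function name is hypothetical.

```python
import numpy as np


def block_noise(height: int, width: int, block_size: int,
                rng: np.random.Generator) -> np.ndarray:
    """Gaussian noise that is constant within each block_size x block_size block."""
    if height % block_size or width % block_size:
        raise ValueError("height and width must be multiples of block_size")
    # One independent sample per block, drawn on the low-resolution grid.
    low_res = rng.standard_normal((height // block_size, width // block_size))
    # Nearest-neighbor upsampling: repeat each sample along both axes.
    return np.repeat(np.repeat(low_res, block_size, axis=0), block_size, axis=1)


if __name__ == "__main__":
    noise = block_noise(8, 8, block_size=4, rng=np.random.default_rng(0))
    print(noise.shape)
    print(np.unique(noise).size)  # one distinct value per block
```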

Evaluation Metrics

We quantitatively measure sample quality with Fréchet inception distance (FID), spatial FID (sFID), Inception Score (IS), Precision and Recall. For sFID, IS, Precision and Recall, we reimplement the calculation pipeline following the TensorFlow formulation from ADM.
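
For orientation, the Fréchet distance behind FID can be computed from the mean and covariance of two sets of activations as $\lVert\mu_1-\mu_2\rVert^2 + \mathrm{Tr}\left(\Sigma_1 + \Sigma_2 - 2(\Sigma_1\Sigma_2)^{1/2}\right)$. The sketch below is a minimal NumPy version of this standard formula, not the code in evaluate.py; the function name is hypothetical, and the trace of the matrix square root is obtained from eigenvalues, which assumes both covariances are positive semi-definite.

```python
import numpy as np


def frechet_distance(mu1: np.ndarray, sigma1: np.ndarray,
                     mu2: np.ndarray, sigma2: np.ndarray) -> float:
    """Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # For PSD covariances, the eigenvalues of sigma1 @ sigma2 are real and
    # non-negative, so Tr(sqrt(sigma1 @ sigma2)) is the sum of their square roots.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * trace_sqrt)


if __name__ == "__main__":
    mu_a, mu_b = np.zeros(3), np.ones(3)
    cov = np.eye(3)
    print(frechet_distance(mu_a, cov, mu_a, cov))  # identical statistics
    print(frechet_distance(mu_a, cov, mu_b, cov))  # shifted mean only
```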

First, run the following commands to generate activation files from the samples and from the dataset:

torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/sample/dir/ --dest=eval-refs/activations_sample.npz --batch=64 # build sample activations
torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/path/to/dataset.zip --dest=eval-refs/activations_ref.npz --batch=64 # build reference activations

Then calculate the metrics from the pre-built activations with:

torchrun --standalone --nproc_per_node=1 evaluate.py calc --batch=64 \
    --activations_sample=eval-refs/activations_sample.npz \
    --activations_ref=eval-refs/activations_ref.npz \
    [-m fid] [-m sfid] [-m is] [-m pr]  # assign metrics to be calculated

Performance Reproduction

RDM achieves competitive results in comparison with previous SoTA models:

| Dataset | Resolution | Training Samples | FID | sFID | IS | Precision | Recall |
|---|---|---|---|---|---|---|---|
| CelebA-HQ | 256x256 | 47M | 3.15 | - | - | 0.77 | 0.55 |
| ImageNet | 256x256 | 1250M | 1.87 | 3.97 | 278.75 | 0.81 | 0.59 |

We provide the best pre-trained checkpoints of RDM and their sampler settings for reproducing these results:

  • CelebA-HQ $256\times 256$:

    Download the first-stage and second-stage checkpoints, place them in ckpts/, then generate samples and their activations with:

    torchrun --standalone --nproc_per_node=8 generate_celebahq.py --outdir=generations/celebahq_samples/ \
        --network_first=ckpts/celebahq_first_stage.pt \
        --network_second=ckpts/celebahq_second_stage.pt
    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=generations/celebahq_samples/ --dest=eval-refs/celebahq_act_sample.npz 

    Generate activation data from CelebA-HQ zip or download our version from here:

    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=datasets/celebahq-256x256.zip --dest=eval-refs/celebahq_act_ref.npz 

    Calculate metrics by command:

    python evaluate.py calc -m fid -m pr \
        --activations_sample=eval-refs/celebahq_act_sample.npz \
        --activations_ref=eval-refs/celebahq_act_ref.npz
  • ImageNet $256\times 256$:

    Download the first-stage and second-stage checkpoints, place them in ckpts/, then generate samples and their activations with:

    torchrun --standalone --nproc_per_node=8 generate_imagenet.py --outdir=generations/imagenet_samples/ \
        --network_first=ckpts/imagenet_first_stage.pkl \
        --network_second=ckpts/imagenet_second_stage.pt
    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=generations/imagenet_samples/ --dest=eval-refs/imagenet_act_sample.npz 

    Generate activation data from the ImageNet zip with the command below. The activation data of all 1.28M images takes up to 40GB, which is too large to upload, so we only provide the mu and sigma of the reference set and the samples for calculating FID here. For other metrics, you need to generate the activations yourself.

    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=datasets/imagenet-256x256.zip --dest=eval-refs/imagenet_act_ref.npz 

    Calculate FID, sFID and IS by command:

    python evaluate.py calc -m fid -m sfid -m is \
        --activations_sample=eval-refs/imagenet_act_sample.npz \
        --activations_ref=eval-refs/imagenet_act_ref.npz

    To calculate Precision and Recall on ImageNet, we follow ADM and use 10k reference samples. You can download the activation data we produced from here. Then run the following command:

    python evaluate.py calc -m pr \
        --activations_sample=eval-refs/imagenet_act_sample.npz \
        --activations_ref=eval-refs/imagenet_act_1w_ref.npz

Training

You can follow the instructions of EDM to train a new first-stage model (standard diffusion). Using ImageNet as an example, run:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/imagenet-64x64.zip --eff-attn=True \
	--cond=1 --batch=4096  --batch-gpu=32 --lr=1e-4 --ema=50 --dropout=0.1 --fp16=1 --ls=25 \
	--arch=adm --precond=edm

To train a second-stage model (blurring diffusion), set --precond=blur together with the other arguments that configure blurring diffusion:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/imagenet-256x256.zip --eff-attn=True \
	--cond=1 --batch=4096  --batch-gpu=8 --lr=1e-4 --dropout=0.1 --fp16=1 --ls=1 \
	--arch=adm --precond=blur --up-scale=4 --block-scale=0.15 --prob-length=0.93 --blur-sigma-max=3.0

As for CelebA-HQ, train a first stage model with:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/CelebA-HQ-64x64.zip --eff-attn=True \
	--cond=0 --batch=1024  --batch-gpu=32 --lr=1e-4 --dropout=0.15 --augment=0.2 --ls=1 \
	--arch=adm --precond=edm

And for training a second stage model:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/CelebA-HQ-256x256.zip --eff-attn=True \
	--cond=0 --batch=1024  --batch-gpu=8 --lr=1e-4 --dropout=0.2 --augment=0.2 --fp16=1 --ls=1 \
	--arch=adm --precond=blur --up-scale=4 --block-scale=0.15 --prob-length=0.89 --blur-sigma-max=2.0

Citation

@article{teng2023relay,
  title={Relay Diffusion: Unifying diffusion process across resolutions for image synthesis},
  author={Teng, Jiayan and Zheng, Wendi and Ding, Ming and Hong, Wenyi and Wangni, Jianqiao and Yang, Zhuoyi and Tang, Jie},
  journal={arXiv preprint arXiv:2309.03350},
  year={2023}
}

Acknowledgements

This implementation is based on https://github.com/NVlabs/edm (codebase of EDM). Thanks a lot!
