xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters


xDiT


📃 Paper | 🚀 Quick Start | 🎯 Supported DiTs | 📚 Dev Guide | 📈 Discussion


🔥 Meet xDiT

Diffusion Transformers (DiTs) are driving advancements in high-quality image and video generation. As the input context length of DiTs escalates, the computational demand of the Attention mechanism grows quadratically with it. Consequently, multi-GPU and multi-machine deployments are essential to meet the real-time requirements of online services.
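To make the quadratic growth concrete, here is a small sketch that counts tokens and attention FLOPs for square images. It assumes an 8x VAE downsampling factor and a 2x2 patch size (common in DiTs, but assumptions here), uses a hidden size of 1152 purely for illustration, and counts only the two s × s matrix products in attention (Q·K^T and scores·V), each about 2 · s² · hs FLOPs.

```python
def num_tokens(height: int, width: int, vae_factor: int = 8, patch: int = 2) -> int:
    """Tokens per image after (assumed) VAE downsampling and patchification."""
    return (height // vae_factor // patch) * (width // vae_factor // patch)


def attention_flops(seq_len: int, hidden: int) -> int:
    """FLOPs for Q @ K^T plus scores @ V in one layer: 2 matmuls x 2*s*s*hs."""
    return 2 * (2 * seq_len * seq_len * hidden)


if __name__ == "__main__":
    hidden = 1152  # illustrative hidden size, not tied to a specific model
    base = attention_flops(num_tokens(1024, 1024), hidden)
    for res in (1024, 2048, 4096):
        s = num_tokens(res, res)
        f = attention_flops(s, hidden)
        print(f"{res}x{res}: {s:>6} tokens, {f / 1e12:8.2f} TFLOPs/layer, {f / base:5.0f}x")
```

Doubling the resolution quadruples the token count and therefore multiplies the attention cost by 16, which is why a single GPU quickly becomes the bottleneck.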

To meet the real-time demands of DiT applications, parallel inference is a must. xDiT is an inference engine designed for deploying DiTs in parallel at large scale. It provides a suite of efficient parallel approaches for diffusion models, as well as GPU kernel accelerations.

  1. Sequence Parallelism: USP, a unified sequence parallel approach that combines DeepSpeed-Ulysses and Ring-Attention.

  2. PipeFusion: patch-level pipeline parallelism with displaced patches, which takes advantage of a characteristic of diffusion models: inputs change little between adjacent diffusion steps.

  3. Data Parallel: processes multiple prompts, or generates multiple images from a single prompt, in parallel across images.

  4. CFG Parallel, also known as Split Batch: activates when classifier-free guidance (CFG) is used, with a constant parallel degree of 2.
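Classifier-free guidance evaluates the denoiser twice per step, once with the prompt and once unconditionally, and then combines the two predictions. Because there are exactly two independent branches, CFG parallel has a fixed degree of 2. The sketch below simulates the two branches with threads and plain Python lists; `denoise` is a hypothetical stand-in for the DiT backbone (not an xDiT API), and the combination uses the standard CFG formula uncond + scale · (cond - uncond).

```python
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor


def denoise(latent: list[float], conditioned: bool) -> list[float]:
    """Hypothetical stand-in for one DiT forward pass."""
    shift = 1.0 if conditioned else 0.0
    return [x + shift for x in latent]


def cfg_combine(cond: list[float], uncond: list[float], scale: float) -> list[float]:
    """Standard classifier-free guidance: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


def cfg_parallel_step(latent: list[float], scale: float) -> list[float]:
    # Simulate the two CFG "ranks": each evaluates one branch concurrently.
    with ThreadPoolExecutor(max_workers=2) as pool:
        cond_future = pool.submit(denoise, latent, True)
        uncond_future = pool.submit(denoise, latent, False)
        cond, uncond = cond_future.result(), uncond_future.result()
    # On real GPUs, the two ranks would exchange their predictions at this point.
    return cfg_combine(cond, uncond, scale)


if __name__ == "__main__":
    print(cfg_parallel_step([0.0, 1.0], scale=3.0))
```

Since the two branches share no data until the final combination, splitting them across two GPUs requires only one exchange per step.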

The four parallel methods in xDiT can be configured in a hybrid manner, optimizing communication patterns to best suit the underlying network hardware.
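One way to picture a hybrid configuration is as a multi-dimensional grid of ranks, where each GPU has one coordinate per parallel method and communicates only with ranks that differ along a single dimension. The sketch below is a hypothetical layout, not xDiT's actual rank assignment: the dimension order and all names are assumptions. It places sequence parallelism innermost so that its communication-heavy groups land on adjacent ranks, typically within one node.

```python
from __future__ import annotations

from dataclasses import dataclass

# Hypothetical layout, outermost to innermost dimension.
ORDER = ("data", "cfg", "pipefusion", "ring", "ulysses")


@dataclass(frozen=True)
class ParallelDegrees:
    data: int = 1
    cfg: int = 1
    pipefusion: int = 1
    ring: int = 1
    ulysses: int = 1

    @property
    def world_size(self) -> int:
        size = 1
        for name in ORDER:
            size *= getattr(self, name)
        return size


def rank_to_coords(rank: int, deg: ParallelDegrees) -> dict[str, int]:
    """Decompose a global rank into one coordinate per parallel dimension."""
    if not 0 <= rank < deg.world_size:
        raise ValueError(f"rank {rank} out of range for world size {deg.world_size}")
    coords = {}
    for name in reversed(ORDER):  # the innermost dimension varies fastest
        size = getattr(deg, name)
        coords[name] = rank % size
        rank //= size
    return coords


def group_of(rank: int, deg: ParallelDegrees, dim: str) -> list[int]:
    """Ranks sharing every coordinate with `rank` except `dim`: one communication group."""
    base = rank_to_coords(rank, deg)
    return [
        r for r in range(deg.world_size)
        if all(c == base[k] for k, c in rank_to_coords(r, deg).items() if k != dim)
    ]


if __name__ == "__main__":
    # Same degrees as the PixArt-alpha torchrun example in the Usage section.
    deg = ParallelDegrees(cfg=2, pipefusion=2, ulysses=2)
    for dim in ("cfg", "pipefusion", "ulysses"):
        print(dim, group_of(5, deg, dim))
```

Reordering `ORDER` changes which methods communicate over the fastest links, which is the kind of trade-off the hybrid configuration lets you tune to your network hardware.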

As shown in the following picture, xDiT offers a set of APIs that adapt DiT models from huggingface/diffusers to hybrid parallel implementations through simple wrappers. If the model you require is not available in the model zoo, developing it yourself is straightforward; please refer to our Dev Guide.

We have also implemented the following parallel strategies for reference:

  1. Tensor Parallelism
  2. DistriFusion

Optimizations orthogonal to parallelization focus on accelerating single-GPU performance. In addition to using well-known Attention optimization libraries, we leverage compilation acceleration technologies such as torch.compile and onediff.

The overview of xDiT is shown as follows.

(Figure: xDiT overview)

📢 Updates

🎯 Supported DiTs

| Model Name | CFG | SP | PipeFusion |
|---|---|---|---|
| 🎬 CogVideoX | ❎ | ❎ | ❎ |
| 🎬 Latte | ❎ | ✔️ | ❎ |
| 🔵 HunyuanDiT-v1.2-Diffusers | ✔️ | ✔️ | ✔️ |
| 🟠 Flux | NA | ✔️ | ❎ |
| 🔴 PixArt-Sigma | ✔️ | ✔️ | ✔️ |
| 🟢 PixArt-alpha | ✔️ | ✔️ | ✔️ |
| 🟠 Stable Diffusion 3 | ✔️ | ✔️ | ✔️ |

Supported by the legacy version only, including DistriFusion and Tensor Parallel as standalone parallel strategies:

📈 Performance

Flux.1

  1. Flux Performance Report

HunyuanDiT

  1. HunyuanDiT Performance Report

SD3

  1. Stable Diffusion 3 Performance Report

Pixart

  1. Pixart-Alpha Performance Report (legacy)

Latte

  1. Latte Performance Report

🚀 QuickStart

1. Install from pip (current version)

pip install xfuser

2. Install from source

python setup.py install

Note that we use two self-maintained packages:

  1. yunchang
  2. DistVAE

The flash_attn version used by yunchang must be >= 2.6.0.
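Since an outdated flash_attn can cause confusing failures, you may want to check the installed version before launching. The sketch below uses only the Python standard library. It assumes the package is installed under the PyPI distribution name `flash-attn`, and its version comparison is deliberately simple: it compares only the leading numeric release segment and ignores pre-release tags.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, version

MIN_FLASH_ATTN = "2.6.0"  # minimum stated in this README


def parse_release(v: str) -> tuple[int, ...]:
    """Leading numeric release segment, e.g. '2.6.3.post1' -> (2, 6, 3)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(part) for part in m.group(0).split("."))


def meets_min_version(installed: str, minimum: str = MIN_FLASH_ATTN) -> bool:
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that (2, 6) compares equal to (2, 6, 0).
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_flash_attn() -> None:
    try:
        installed = version("flash-attn")  # assumed PyPI distribution name
    except PackageNotFoundError:
        print("flash-attn is not installed")
        return
    if meets_min_version(installed):
        print(f"flash-attn {installed}: OK")
    else:
        print(f"flash-attn {installed}: too old, need >= {MIN_FLASH_ATTN}")


if __name__ == "__main__":
    check_flash_attn()
```

For full PEP 440 semantics (pre-releases, epochs), the `packaging` library's `Version` class is the more robust choice.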

3. Launch an HTTP Service

Launching a Text-to-Image HTTP Service

4. Usage

We provide examples demonstrating how to run models with xDiT in the ./examples/ directory. To run one of the supported DiT models, edit the model type, model directory, and parallel options in examples/run.sh.

bash examples/run.sh

To inspect the available options for the PixArt-alpha example, use the following command:

python ./examples/pixartalpha_example.py -h

...

xFuser Arguments

options:
  -h, --help            show this help message and exit

Model Options:
  --model MODEL         Name or path of the huggingface model to use.
  --download-dir DOWNLOAD_DIR
                        Directory to download and load the weights, default to the default cache dir of huggingface.
  --trust-remote-code   Trust remote code from huggingface.

Runtime Options:
  --warmup_steps WARMUP_STEPS
                        Warmup steps in generation.
  --use_parallel_vae
  --use_torch_compile   Enable torch.compile to accelerate inference in a single card
  --seed SEED           Random seed for operations.
  --output_type OUTPUT_TYPE
                        Output type of the pipeline.
  --enable_sequential_cpu_offload
                        Offloading the weights to the CPU.

Parallel Processing Options:
  --use_cfg_parallel    Use split batch in classifier_free_guidance. cfg_degree will be 2 if set
  --data_parallel_degree DATA_PARALLEL_DEGREE
                        Data parallel degree.
  --ulysses_degree ULYSSES_DEGREE
                        Ulysses sequence parallel degree. Used in attention layer.
  --ring_degree RING_DEGREE
                        Ring sequence parallel degree. Used in attention layer.
  --pipefusion_parallel_degree PIPEFUSION_PARALLEL_DEGREE
                        Pipefusion parallel degree. Indicates the number of pipeline stages.
  --num_pipeline_patch NUM_PIPELINE_PATCH
                        Number of patches the feature map should be segmented in pipefusion parallel.
  --attn_layer_num_for_pp [ATTN_LAYER_NUM_FOR_PP ...]
                        List representing the number of layers per stage of the pipeline in pipefusion parallel
  --tensor_parallel_degree TENSOR_PARALLEL_DEGREE
                        Tensor parallel degree.
  --split_scheme SPLIT_SCHEME
                        Split scheme for tensor parallel.

Input Options:
  --height HEIGHT       The height of image
  --width WIDTH         The width of image
  --prompt [PROMPT ...]
                        Prompt for the model.
  --no_use_resolution_binning
  --negative_prompt [NEGATIVE_PROMPT ...]
                        Negative prompt for the model.
  --num_inference_steps NUM_INFERENCE_STEPS
                        Number of inference steps.
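The --pipefusion_parallel_degree, --attn_layer_num_for_pp, and --num_pipeline_patch options above describe how PipeFusion divides work: transformer layers across pipeline stages, and the feature map into patches. The sketch below shows one plausible way to compute such splits. The even-split default (remainder going to the earliest stages) and cutting the latent along its height are assumptions for illustration, and the function names are hypothetical rather than xFuser internals.

```python
from __future__ import annotations


def split_layers(num_layers: int, num_stages: int, custom: list[int] | None = None) -> list[int]:
    """Layers per pipeline stage, in the spirit of --attn_layer_num_for_pp."""
    if custom is not None:
        # A user-provided list must cover every layer exactly once.
        if len(custom) != num_stages or sum(custom) != num_layers or min(custom) < 1:
            raise ValueError(f"{custom} is not a valid split of {num_layers} layers into {num_stages} stages")
        return list(custom)
    base, rem = divmod(num_layers, num_stages)
    if base == 0:
        raise ValueError("more pipeline stages than layers")
    # Assumed default: spread the remainder over the earliest stages.
    return [base + (1 if i < rem else 0) for i in range(num_stages)]


def split_patches(num_rows: int, num_patches: int) -> list[tuple[int, int]]:
    """[start, end) row ranges when the latent is cut into patches along its height."""
    base, rem = divmod(num_rows, num_patches)
    bounds, start = [], 0
    for i in range(num_patches):
        end = start + base + (1 if i < rem else 0)
        bounds.append((start, end))
        start = end
    return bounds


if __name__ == "__main__":
    print(split_layers(28, 2))
    print(split_patches(64, 4))
```

With more patches than stages, each stage can work on a different patch at the same time, which keeps the pipeline full.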

Combining multiple parallelism techniques is essential for scaling efficiently. The product of all parallel degrees must equal the number of devices. For instance, the command below combines CFG, PipeFusion, and sequence parallelism to generate an image of a small dog through hybrid parallelism. Here, ulysses_degree * pipefusion_parallel_degree * cfg_degree (2 when --use_cfg_parallel is set) = 2 * 2 * 2 = 8 = number of devices.

torchrun --nproc_per_node=8 \
examples/pixartalpha_example.py \
--model models/PixArt-XL-2-1024-MS \
--pipefusion_parallel_degree 2 \
--ulysses_degree 2 \
--num_inference_steps 20 \
--warmup_steps 0 \
--prompt "A small dog" \
--use_cfg_parallel
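Before launching, it can help to check that a set of degrees tiles the devices exactly, or to list which hybrid configurations are possible at all. The sketch below does both. The function names are hypothetical; the parameters mirror the CLI flags above (cfg corresponds to --use_cfg_parallel, ring to --ring_degree), and data parallelism is fixed at 1 in the enumeration for brevity. xDiT may impose further constraints that are not modeled here, such as divisibility of attention heads by the Ulysses degree.

```python
from __future__ import annotations

import os


def divisors(n: int) -> list[int]:
    return [d for d in range(1, n + 1) if n % d == 0]


def check_degrees(world_size: int, *, data: int = 1, cfg: int = 1,
                  pipefusion: int = 1, ulysses: int = 1, ring: int = 1) -> None:
    """Raise if the parallel degrees do not multiply to exactly `world_size`."""
    if cfg not in (1, 2):
        raise ValueError("cfg degree is 2 with --use_cfg_parallel, otherwise 1")
    product = data * cfg * pipefusion * ulysses * ring
    if product != world_size:
        raise ValueError(f"degrees multiply to {product}, but there are {world_size} devices")


def hybrid_configs(world_size: int, use_cfg_parallel: bool) -> list[dict[str, int]]:
    """All (pipefusion, ulysses, ring) splits that fill the devices, with data parallel = 1."""
    cfg = 2 if use_cfg_parallel else 1
    if world_size % cfg != 0:
        return []
    rest = world_size // cfg
    configs = []
    for pp in divisors(rest):
        for ul in divisors(rest // pp):
            configs.append({"cfg": cfg, "pipefusion": pp, "ulysses": ul, "ring": rest // pp // ul})
    return configs


if __name__ == "__main__":
    # The PixArt-alpha command above: cfg 2 x pipefusion 2 x ulysses 2 on 8 GPUs.
    check_degrees(8, cfg=2, pipefusion=2, ulysses=2)
    world = int(os.environ.get("WORLD_SIZE", "8"))  # torchrun exports WORLD_SIZE
    for config in hybrid_configs(world, use_cfg_parallel=True):
        print(config)
```

On 8 GPUs with CFG parallel enabled, the remaining factor of 4 can be split across PipeFusion, Ulysses, and Ring in six ways, and the best choice depends on the interconnect.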

⚠️ PipeFusion, like DistriFusion, requires setting warmup_steps, typically to a small number relative to num_inference_steps. Warmup steps reduce PipeFusion's efficiency because they cannot be executed in parallel and degrade to serial execution. We observed that a warmup of 0 had no adverse effect on the PixArt model. Users can tune this value for their specific tasks.
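The cost of warmup can be estimated with a simple Amdahl's-law-style model: warmup steps run serially, and the remaining steps are split evenly over the pipeline. The sketch below is an idealized estimate under that assumption. It ignores communication, pipeline bubbles, and the VAE, so real speedups will be lower.

```python
def pipefusion_ideal_speedup(num_steps: int, warmup_steps: int, num_devices: int) -> float:
    """Idealized speedup: warmup steps run serially, the rest scale perfectly."""
    if not 0 <= warmup_steps <= num_steps:
        raise ValueError("warmup_steps must be between 0 and num_steps")
    if num_devices < 1:
        raise ValueError("num_devices must be positive")
    parallel_time = warmup_steps + (num_steps - warmup_steps) / num_devices
    return num_steps / parallel_time


if __name__ == "__main__":
    for warmup in (0, 1, 2, 4):
        speedup = pipefusion_ideal_speedup(20, warmup, 4)
        print(f"warmup_steps={warmup}: ~{speedup:.2f}x on 4 devices")
```

Even in this idealized model, 4 warmup steps out of 20 cut the best-case speedup on 4 devices from 4x to 2.5x, which is why warmup_steps should stay small.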

✨ The xDiT's Arsenal

The performance of xDiT rests on two key facets. First, it leverages parallelization techniques, including its own innovations USP, PipeFusion, and hybrid parallelism, to scale DiT inference to large multi-GPU deployments.

Second, it employs compilation technologies to improve execution on GPUs, integrating established solutions such as torch.compile and onediff.

1. Parallel Methods

As illustrated in the accompanying images, xDiT offers a comprehensive set of parallelization techniques. For the DiT backbone, the foundational methods (Data, USP, PipeFusion, and CFG parallel) operate in a hybrid fashion. In contrast, Tensor and DistriFusion parallel are standalone methods that function independently. For the VAE module, xDiT offers a parallel implementation, DistVAE, designed to prevent out-of-memory (OOM) issues. Methods marked (xDiT) were first proposed by us.

(Figure: overview of xDiT parallel methods)

The table below details the communication and memory costs of the intra-image parallel methods in DiTs. CFG parallel and data parallel are excluded because they parallelize across images. An asterisk (*) denotes that communication can be overlapped with computation.

As the table shows, PipeFusion and Sequence Parallelism achieve the lowest communication costs across different scales and hardware configurations, making them suitable foundational components for a hybrid approach.

$p$: number of pixels; $hs$: model hidden size; $L$: number of model layers; $P$: total model parameters; $N$: number of parallel devices; $M$: number of patch splits; $QO$: Query and Output parameter count; $KV$: KV activation parameter count; $A = Q = O = K = V$: equal parameters for Attention, Query, Output, Key, and Value.

| Method | attn-KV | Communication cost | Param memory | Activation memory | Extra buffer memory |
|---|---|---|---|---|---|
| Tensor Parallel | fresh | $4O(p \times hs)L$ | $\frac{1}{N}P$ | $\frac{2}{N}A = \frac{1}{N}QO$ | $\frac{2}{N}A = \frac{1}{N}KV$ |
| DistriFusion* | stale | $2O(p \times hs)L$ | $P$ | $\frac{2}{N}A = \frac{1}{N}QO$ | $2AL = (KV)L$ |
| Ring Sequence Parallel* | fresh | $2O(p \times hs)L$ | $P$ | $\frac{2}{N}A = \frac{1}{N}QO$ | $\frac{2}{N}A = \frac{1}{N}KV$ |
| Ulysses Sequence Parallel | fresh | $\frac{4}{N}O(p \times hs)L$ | $P$ | $\frac{2}{N}A = \frac{1}{N}QO$ | $\frac{2}{N}A = \frac{1}{N}KV$ |
| PipeFusion* | stale- | $2O(p \times hs)$ | $\frac{1}{N}P$ | $\frac{2}{M}A = \frac{1}{M}QO$ | $\frac{2L}{N}A = \frac{1}{N}(KV)L$ |
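Plugging numbers into the communication-cost column makes the comparison concrete. The sketch below treats each $O(p \times hs)$ term as one unit and drops the constants hidden inside the $O(\cdot)$, so the results are only relative. The values $L = 28$ and $N = 8$ are illustrative choices.

```python
from __future__ import annotations

from fractions import Fraction


def comm_cost_units(num_layers: int, num_devices: int) -> dict[str, Fraction]:
    """Communication cost per method from the table, in units of O(p x hs)."""
    L, N = num_layers, num_devices
    return {
        "Tensor Parallel": Fraction(4 * L),
        "DistriFusion": Fraction(2 * L),
        "Ring Sequence Parallel": Fraction(2 * L),
        "Ulysses Sequence Parallel": Fraction(4, N) * L,
        "PipeFusion": Fraction(2),  # independent of the number of layers
    }


if __name__ == "__main__":
    costs = comm_cost_units(num_layers=28, num_devices=8)
    for name, cost in sorted(costs.items(), key=lambda kv: kv[1]):
        print(f"{name:<26} {float(cost):7.1f}")
```

PipeFusion's cost does not grow with $L$, and Ulysses' cost shrinks as $N$ grows, which is why these two serve as the foundation for hybrid configurations.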

1.1. PipeFusion

PipeFusion: Displaced Patch Pipeline Parallelism for Diffusion Models

1.2. USP: Unified Sequence Parallelism

USP: A Unified Sequence Parallelism Approach for Long Context Generative AI
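USP builds on DeepSpeed-Ulysses, whose core step is an all-to-all exchange. Before the exchange, each rank holds a slice of the sequence for all attention heads; afterwards, it holds the full sequence for a slice of the heads, so ordinary attention can run locally. The pure-Python simulation below shows only that data movement. Toy numbers replace real tensors, and `ulysses_all_to_all` is a hypothetical name, not part of xDiT or yunchang.

```python
from __future__ import annotations


def ulysses_all_to_all(shards: list[list[list[float]]]) -> list[list[list[float]]]:
    """Turn sequence-sharded [local_tokens][heads] data into head-sharded [all_tokens][local_heads].

    shards[r] is rank r's slice of the sequence, covering all heads.
    """
    n = len(shards)
    num_heads = len(shards[0][0])
    if num_heads % n != 0:
        raise ValueError("number of heads must be divisible by the Ulysses degree")
    heads_per_rank = num_heads // n
    return [
        # Destination rank `dst` collects every token from every source rank,
        # keeping only its own contiguous slice of heads.
        [token[dst * heads_per_rank:(dst + 1) * heads_per_rank]
         for src in range(n) for token in shards[src]]
        for dst in range(n)
    ]


if __name__ == "__main__":
    # 4 tokens x 4 heads, where value 10*t + h encodes (token t, head h).
    full = [[10 * t + h for h in range(4)] for t in range(4)]
    shards = [full[0:2], full[2:4]]  # Ulysses degree 2: split along the sequence
    for rank, data in enumerate(ulysses_all_to_all(shards)):
        print(rank, data)
```

Ring-Attention instead keeps the sequence split and passes K/V blocks around a ring; USP arranges ranks so that both schemes can be combined.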

1.3. Hybrid Parallel

Hybrid Parallelism

1.4. CFG Parallel

CFG Parallel

1.5. Parallel VAE

Patch Parallel VAE

2. Compilation Acceleration

We utilize two compilation acceleration techniques, torch.compile and onediff, to enhance runtime speed on GPUs. These compilation accelerations are used in conjunction with parallelization methods.

We employ the nexfort backend of onediff. Please install it before use:

pip install onediff
pip install -U nexfort

For usage instructions, refer to examples/run.sh. Append --use_torch_compile or --use_onediff to your command. Note that these options are mutually exclusive, and their performance varies across scenarios.

📚 Development Guide

The implementation and design of the xDiT framework

Manual for adding new models

🚧 History and Looking for Contributions

We conducted a major upgrade of this project in August 2024.

The latest APIs are located in the xfuser/ directory and support hybrid parallelism. They offer clearer and more structured code but currently support fewer models.

The legacy APIs are in the legacy/ directory and are limited to a single parallel method at a time. They support a richer set of parallel methods, including PipeFusion, Sequence Parallel, DistriFusion, and Tensor Parallel. CFG Parallel can be combined with PipeFusion but not with other parallel methods.

For models not yet supported by the latest APIs, you can run the examples in the legacy/scripts/ directory. If you wish to develop new features on a model or require hybrid parallelism, stay tuned for further project updates.

We also welcome developers to join and contribute more features and models to the project. Tell us in Discussions which models you need in xDiT.

📝 Cite Us

@article{wang2024pipefusion,
      title={PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models}, 
      author={Jiannan Wang and Jiarui Fang and Jinzhe Pan and Aoyu Li and PengCheng Yang},
      year={2024},
      eprint={2405.14430},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@article{fang2024unified,
      title={USP: a Unified Sequence Parallelism Approach for Long Context Generative AI},
      author={Fang, Jiarui and Zhao, Shangchun},
      journal={arXiv preprint arXiv:2405.07719},
      year={2024}
}