This repo is a collection of attention mechanisms in vision Transformers. Besides the re-implementations, it provides a benchmark of model parameters, FLOPs, and CPU/GPU throughput.
- PyTorch 1.8+
- timm
- ninja
- einops
- fvcore
- matplotlib
- NVIDIA RTX 3090
- Intel® Core™ i9-10900X CPU @ 3.70GHz
- Memory 32GB
- Ubuntu 22.04
- PyTorch 1.8.1 + CUDA 11.1
- input: 14 x 14 = 196 tokens (1/16 scale feature maps in common ImageNet-1K training)
- batch size for speed testing (images/s): 64
- embedding dimension: 768
- number of heads: 12
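These settings fix the shape of the tensor that each attention layer receives. A minimal sketch of the arithmetic, assuming a 224 x 224 input image (the usual ImageNet-1K resolution, which is not stated in the list above); the variable names are illustrative only:

```python
# Assumption: 224 x 224 input image, the common ImageNet-1K training resolution.
image_size = 224
downsample = 16          # 1/16 scale feature map
embed_dim = 768
num_heads = 12
batch_size = 64

side = image_size // downsample      # 14
num_tokens = side * side             # 196
head_dim = embed_dim // num_heads    # 64

# Shape of the token tensor used for speed testing: (batch, tokens, channels)
input_shape = (batch_size, num_tokens, embed_dim)
print(input_shape)  # (64, 196, 768)
```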
For example, to test HiLo attention:
cd attentions/
python hilo.py
By default, the script tests models on both CPU and GPU. FLOPs are measured by fvcore. You may want to edit the source file as needed.
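The throughput figures are reported in images/s, averaged over repeated runs. The sketch below shows one generic way to compute such a number with the standard library; it is not necessarily how the scripts in this repo do it. The function name `measure_throughput`, the `warmup` count, and the injectable `timer` are all assumptions made for illustration.

```python
import time

def measure_throughput(step, batch_size, repeats=30, warmup=10, timer=time.perf_counter):
    """Return images/s averaged over `repeats` timed calls of `step()`.

    `step` is any zero-argument callable that processes one batch.
    """
    # Untimed warm-up calls so one-off setup costs do not skew the result.
    for _ in range(warmup):
        step()
    start = timer()
    for _ in range(repeats):
        step()
    elapsed = timer() - start
    return repeats * batch_size / elapsed

# Example with a dummy workload standing in for a model forward pass.
throughput = measure_throughput(lambda: sum(range(10_000)), batch_size=64)
print(f"batch_size 64 throughput {throughput:.0f}")
```

When timing a GPU model this way, the callable should wait for the device to finish (for example by calling `torch.cuda.synchronize()`), because CUDA kernels are launched asynchronously and the host timer would otherwise stop too early.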
Outputs:
Number of Params: 2.2 M
FLOPs = 298.3 M
throughput averaged with 30 times
batch_size 64 throughput on CPU 1029
throughput averaged with 30 times
batch_size 64 throughput on GPU 5104
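To collect results from several scripts, the printed output can be parsed with a few regular expressions. This is a small sketch that assumes every script prints in exactly the format shown above; `parse_benchmark_output` and the dictionary keys are hypothetical names, not part of the repo.

```python
import re

def parse_benchmark_output(text):
    """Extract params (M), FLOPs (M), and CPU/GPU throughput from script output."""
    result = {}
    m = re.search(r"Number of Params:\s*([\d.]+)\s*M", text)
    if m:
        result["params_m"] = float(m.group(1))
    m = re.search(r"FLOPs\s*=\s*([\d.]+)\s*M", text)
    if m:
        result["flops_m"] = float(m.group(1))
    # One "throughput on <device> <value>" line is expected per device.
    for device, value in re.findall(r"throughput on (CPU|GPU)\s+([\d.]+)", text):
        result[f"{device.lower()}_throughput"] = float(value)
    return result

sample = """Number of Params: 2.2 M
FLOPs = 298.3 M
throughput averaged with 30 times
batch_size 64 throughput on CPU 1029
throughput averaged with 30 times
batch_size 64 throughput on GPU 5104"""

print(parse_benchmark_output(sample))
```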
- MSA: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. [Paper] [Code]
- Cross Window: CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows. [Paper] [Code]
- DAT: Vision Transformer with Deformable Attention. [Paper] [Code]
- Performer: Rethinking Attention with Performers. [Paper] [Code]
- Linformer: Linformer: Self-Attention with Linear Complexity. [Paper] [Code]
- SRA: Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions. [Paper] [Code]
- Local/Shifted Window: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. [Paper] [Code]
- Focal: Focal Self-attention for Local-Global Interactions in Vision Transformers. [Paper] [Code]
- XCA: XCiT: Cross-Covariance Image Transformers. [Paper] [Code]
- QuadTree: QuadTree Attention for Vision Transformers. [Paper] [Code]
- VAN: Visual Attention Network. [Paper] [Code]
- HorNet: HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions. [Paper] [Code]
- HiLo: Fast Vision Transformers with HiLo Attention. [Paper] [Code]
Name | Params (M) | FLOPs (M) | CPU Speed (images/s) | GPU Speed (images/s) | Demo |
---|---|---|---|---|---|
MSA | 2.36 | 521.43 | 505 | 4403 | msa.py |
Cross Window | 2.37 | 493.28 | 325 | 4334 | cross_window.py |
DAT | 2.38 | 528.69 | 223 | 3074 | dat.py |
Performer | 2.36 | 617.24 | 181 | 3180 | performer.py |
Linformer | 2.46 | 616.56 | 518 | 4578 | linformer.py |
SRA | 4.72 | 419.56 | 710 | 4810 | sra.py |
Local Window | 2.36 | 477.17 | 631 | 4537 | shifted_window.py |
Shifted Window | 2.36 | 477.17 | 374 | 4351 | shifted_window.py |
Focal | 2.44 | 526.85 | 146 | 2842 | focal.py |
XCA | 2.36 | 481.69 | 583 | 4659 | xca.py |
QuadTree | 5.33 | 613.25 | 72 | 3978 | quadtree.py |
VAN | 1.83 | 357.96 | 59 | 4213 | van.py |
HorNet | 2.23 | 436.51 | 132 | 3996 | hornet.py |
HiLo | 2.20 | 298.30 | 1029 | 5104 | hilo.py |
Note: Each method has its own hyperparameters. For a fair comparison on 1/16 scale feature maps, all methods in the above table adopt their default 1/16 scale settings, as shown in their released code repositories. For example, when dealing with 1/16 scale feature maps, HiLo in LITv2 adopts a window size of 2 and an alpha of 0.9. Future work will consider more scales and memory benchmarking.
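As a sanity check on the table, the MSA row can be reproduced by hand from the benchmark settings (196 tokens, embedding dimension 768). The sketch below rests on assumptions that this README does not state: a standard MSA layer with a fused QKV projection and an output projection, both with bias, and FLOPs counted for a single image with one multiply-accumulate counted as one FLOP. The result agrees with the MSA row, which is consistent with these assumptions but does not confirm them.

```python
# Assumptions (not stated in this README): a standard MSA layer with a fused
# QKV linear projection and an output projection, both with bias; FLOPs counted
# for a single image, with one multiply-accumulate counted as one FLOP.
num_tokens = 196
dim = 768

# Parameters: weights plus biases of the two linear layers.
qkv_params = dim * 3 * dim + 3 * dim
proj_params = dim * dim + dim
params = qkv_params + proj_params

# FLOPs: two linear layers plus the two attention matrix products
# (Q @ K^T and attn @ V). Summed over all heads, each matrix product
# costs num_tokens^2 * dim, so the head count cancels out.
linear_flops = num_tokens * dim * 3 * dim + num_tokens * dim * dim
attn_flops = 2 * num_tokens * num_tokens * dim
flops = linear_flops + attn_flops

print(f"Params: {params / 1e6:.2f} M")  # Params: 2.36 M
print(f"FLOPs:  {flops / 1e6:.2f} M")   # FLOPs:  521.43 M
```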
This repository is released under the Apache 2.0 license as found in the LICENSE file.