This repository contains the training code of ViT introduced in our work: "Oscillation-free Quantization for Low-bit Vision Transformers" which has been accepted for ICML 2023. Please consider starring the repo if you find our work useful, thanks!
In this work, we discusses the issue of weight oscillation in quantization-aware training and how it negatively affects model performance. The learnable scaling factor, commonly used in quantization, was found to worsen weight oscillation. The study proposes three techniques to address this issue: statistical weight quantization (StatsQ), confidence-guided annealing (CGA), and query-key reparameterization (QKR). These techniques were tested on the ViT model and were found to improve quantization robustness and accuracy. The proposed 2-bit DeiT-T/DeiT-S algorithms outperform the previous state-of-the-art by 9.8% and 7.7%, respectively.
- numpy==1.22.3
- torch==2.0.0
- torchvision==0.15.1
- timm=0.5.4
- pyyaml
Please replace "/your/miniconda3/envs/ofq/lib/python3.8/site-packages/timm/data/dataset_factory.py" with "timm_fix_imagenet_loading_bugs/dataset_factory.py" as with the original code there is a "TypeError: init() got an unexpected keyword argument 'download'" error.
- Pretrained models will be automatically downloaded for you if set args.pretrained to True.
-
Examples of training scripts, finetuning scripts (CGA) are provided under "train_scripts/" and evaluation scripts are under "eval_scripts/" (please use the exact same batch size (batch_size * world_size) as provided in the evaluation scripts to reproduce the results reported in the paper).
-
Please modified the data path to your own dataset address
Models | #Bits | Top-1 Accuracy (Model Link) | eval script |
---|---|---|---|
DeiT-T | 32-32 | 72.02 | ------- |
OFQ DeiT-T | 2-2 | 64.33 | eval_scripts/deit_t/w2a2.sh |
OFQ DeiT-T | 3-3 | 72.72 | eval_scripts/deit_t/w3a3.sh |
OFQ DeiT-T | 4-4 | 75.46 | eval_scripts/deit_t/w4a4.sh |
DeiT-S | 32-32 | 79.9 | ------- |
OFQ DeiT-S | 2-2 | 75.72 | eval_scripts/deit_s/w2a2.sh |
OFQ DeiT-S | 3-3 | 79.57 | eval_scripts/deit_s/w3a3.sh |
OFQ DeiT-S | 4-4 | 81.10 | eval_scripts/deit_s/w4a4.sh |
Swin-T | 32-32 | 81.2 | ------- |
OFQ Swin-T | 2-2 | 78.52 | eval_scripts/swin_t/w2a2.sh |
OFQ Swin-T | 3-3 | 81.09 | eval_scripts/swin_t/w3a3.sh |
OFQ Swin-T | 4-4 | 81.88 | eval_scripts/swin_t/w4a4.sh |
The original code is borrowed from DeiT.
If you find our code useful for your research, please consider citing:
@InProceedings{pmlr-v202-liu23w,
title = {Oscillation-free Quantization for Low-bit Vision Transformers},
author = {Liu, Shih-Yang and Liu, Zechun and Cheng, Kwang-Ting},
booktitle = {Proceedings of the 40th International Conference on Machine Learning},
pages = {21813--21824},
year = {2023},
editor = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
volume = {202},
series = {Proceedings of Machine Learning Research},
month = {23--29 Jul},
publisher = {PMLR},
pdf = {https://proceedings.mlr.press/v202/liu23w/liu23w.pdf},
url = {https://proceedings.mlr.press/v202/liu23w.html},
abstract = {Weight oscillation is a by-product of quantization-aware training, in which quantized weights frequently jump between two quantized levels, resulting in training instability and a sub-optimal final model. We discover that the learnable scaling factor, a widely-used $\textit{de facto}$ setting in quantization aggravates weight oscillation. In this work, we investigate the connection between learnable scaling factor and quantized weight oscillation using ViT, and we additionally find that the interdependence between quantized weights in $\textit{query}$ and $\textit{key}$ of a self-attention layer also makes ViT vulnerable to oscillation. We propose three techniques correspondingly: statistical weight quantization ($\rm StatsQ$) to improve quantization robustness compared to the prevalent learnable-scale-based method; confidence-guided annealing ($\rm CGA$) that freezes the weights with $\textit{high confidence}$ and calms the oscillating weights; and $\textit{query}$-$\textit{key}$ reparameterization ($\rm QKR$) to resolve the query-key intertwined oscillation and mitigate the resulting gradient misestimation. Extensive experiments demonstrate that our algorithms successfully abate weight oscillation and consistently achieve substantial accuracy improvement on ImageNet. Specifically, our 2-bit DeiT-T/DeiT-S surpass the previous state-of-the-art by 9.8% and 7.7%, respectively. The code is included in the supplementary material and will be released.}
}
Shih-Yang Liu, HKUST (sliuau at connect.ust.hk)