This repository provides a PyTorch implementation for reproducing the results of the paper:
- **White-Box Transformers via Sparse Rate Reduction** [NeurIPS 2023, paper link]. By Yaodong Yu (UC Berkeley), Sam Buchanan (TTIC), Druv Pai (UC Berkeley), Tianzhe Chu (UC Berkeley), Ziyang Wu (UC Berkeley), Shengbang Tong (UC Berkeley), Benjamin D. Haeffele (Johns Hopkins University), and Yi Ma (UC Berkeley).
We pretrain the model on ImageNet-100, which can be downloaded with the Kaggle API:

```
kaggle datasets download -d ambityga/imagenet100
```

The downloaded archive contains several fold folders, which need to be merged into `train` and `val` folders.
To train a CRATE model, run `main.py`. As an example, we use the following command for training CRATE-tiny on ImageNet-100:
```
python main.py \
  --arch {model_name} \
  --batch-size 512 \
  --epochs 200 \
  --optimizer Lion \
  --lr 0.0002 \
  --weight-decay 0.05 \
  --print-freq 25 \
  --data DATA_DIR
```
Replace `DATA_DIR` with the path to the ImageNet folder containing the `train` and `val` folders.
To fine-tune a pretrained CRATE model, run:

```
python finetune.py \
  --bs 256 \
  --net {model_name} \
  --opt adamW \
  --lr 5e-5 \
  --n_epochs 200 \
  --randomaug 1 \
  --data {cifar10/cifar100/flower/pets} \
  --ckpt_dir CKPT_DIR \
  --data_dir DATA_DIR
```
Replace `CKPT_DIR` with the path to the pretrained CRATE weights, and `DATA_DIR` with the path to the dataset. `CKPT_DIR` may be `None`. The script automatically checks whether the dataset is present in the data folder, and downloads it if it is absent.
CRATE models exhibit emergent segmentation in their self-attention maps through supervised training alone. The Colab Jupyter notebook visualizes the emergent segmentations from a supervised CRATE model; the demo provides visualizations that match the segmentation figures above.
Link: re-crate-emergence.ipynb (in Colab)
```
@article{yu2024white,
  title={White-Box Transformers via Sparse Rate Reduction},
  author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Haeffele, Benjamin and Ma, Yi},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}
```