Data-Efficient Multi-Scale Fusion Vision Transformer

This repository includes the official implementation and model weights of the Data-Efficient Multi-Scale Fusion Vision Transformer.

1. Abstract

Vision transformers (ViTs) demonstrate significant potential in image classification with massive data, but struggle on small-scale datasets. This paper addresses this data inefficiency by introducing multi-scale tokens, which provide an image prior at multiple scales and enable learning scale-invariant features. Our model generates tokens of varying scales from images using different patch sizes, where each token of the larger scale is linked to a set of tokens of smaller scales based on spatial correspondences. Through a regional cross-scale interaction module, tokens of different scales are fused regionally to enhance the learning of local structures. Additionally, we implement a data augmentation schedule to refine training. Extensive experiments on image classification demonstrate that our approach surpasses DeiT and other multi-scale transformer methods on small-scale datasets.

  • Multi-Scale Tokenization
  • Regional Cross-Scale Interaction
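
As a concrete illustration of the two components above, here is a minimal PyTorch sketch of multi-scale tokenization, assuming two scales with patch sizes 16 and 8 and an embedding dimension of 384. Each coarse token then spatially corresponds to a 2×2 block of fine tokens, which is the correspondence the regional cross-scale interaction module fuses. Class and variable names are illustrative, not the repository's actual API.

```python
import torch
import torch.nn as nn

class MultiScaleTokenizer(nn.Module):
    """Illustrative two-scale tokenizer: one patch embedding per patch size.

    With patch sizes (16, 8) on a 224x224 image, the coarse grid is 14x14
    and the fine grid is 28x28, so each coarse token corresponds to a 2x2
    block of fine tokens. Names and defaults here are assumptions, not the
    repository's actual classes.
    """

    def __init__(self, dim=384, patch_sizes=(16, 8)):
        super().__init__()
        self.embeds = nn.ModuleList(
            nn.Conv2d(3, dim, kernel_size=p, stride=p) for p in patch_sizes
        )

    def forward(self, x):
        # One token sequence per scale, shape (B, (H/p)*(W/p), dim).
        return [e(x).flatten(2).transpose(1, 2) for e in self.embeds]

coarse, fine = MultiScaleTokenizer()(torch.randn(1, 3, 224, 224))
print(coarse.shape, fine.shape)  # torch.Size([1, 196, 384]) torch.Size([1, 784, 384])
```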

2. Requirements

To install requirements:

conda create -n dems python=3.8
conda activate dems
pip install -r requirements.txt

3. Datasets

The root paths of the datasets are set to /path/to/dataset; please set them accordingly. CIFAR10, CIFAR100, FashionMNIST, and EMNIST are provided by torchvision. Download and extract the Caltech101 train and val images from https://www.vision.caltech.edu/datasets/. The directory structure is the standard layout expected by torchvision's datasets.ImageFolder, with the training and validation data in the train/ and val/ folders respectively.
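
As a hedged illustration of the layout above, the datasets can be loaded with torchvision as follows; the paths and transform are placeholders, and the repository's actual data pipeline may differ.

```python
from torchvision import datasets, transforms

# Placeholder transform; the repo's training pipeline applies its own augmentations.
tf = transforms.Compose([transforms.ToTensor()])

# CIFAR10/100, FashionMNIST, and EMNIST come directly from torchvision.
cifar100_train = datasets.CIFAR100(
    root="/path/to/CIFAR100", train=True, download=True, transform=tf
)

# Caltech101 follows the standard ImageFolder layout: train/ and val/
# directories with one subfolder per class.
caltech_train = datasets.ImageFolder("/path/to/caltech101/train", transform=tf)
caltech_val = datasets.ImageFolder("/path/to/caltech101/val", transform=tf)
```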

4. Training

Set hyperparameters and GPU IDs in ./config/pretrain/dems_small_pretrain.py. Run the following command to train DEMS-ViT-S on CIFAR100 for 800 epochs from random initialization on a single node with multiple GPUs:

python main_pretrain.py --model dems_small --batch_size 256 --epochs 800 --dataset CIFAR100 --data_path /path/to/CIFAR100
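
The config file referenced above is not reproduced here; the following is a hypothetical illustration of the kind of settings it holds. Every field name is a guess for illustration only, not the repository's actual configuration schema.

```python
# Hypothetical contents of ./config/pretrain/dems_small_pretrain.py;
# all field names below are assumptions for illustration only.
gpu_ids = [0, 1, 2, 3]           # GPUs for the single-node multi-GPU run
model = "dems_small"
batch_size = 256                 # matches the command above
epochs = 800
dataset = "CIFAR100"
data_path = "/path/to/CIFAR100"
```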

5. Fine-tuning

Set hyperparameters and GPU IDs in ./config/pretrain/dems_small_finetune.py. Run the following command to fine-tune DEMS-ViT-S on CIFAR100 for 100 epochs:

python main_finetune.py --model dems_small --batch_size 256 --epochs 100 --dataset CIFAR100 --data_path /path/to/CIFAR100 --pretrained_weight /path/to/pretrained
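
For context on what --pretrained_weight implies, here is a minimal sketch of the usual checkpoint-loading step before fine-tuning: load the pretrained state dict and skip any weights whose shapes no longer match (e.g. the classifier head when the class count changes). The "model" key and the shape filtering are common conventions, not verified against this repository.

```python
import torch
import torch.nn as nn

def load_pretrained(model: nn.Module, path: str) -> None:
    """Load a pretrained checkpoint, skipping shape-mismatched weights.

    The nested "model" key and non-strict loading are assumptions based
    on common PyTorch practice, not this repository's actual code.
    """
    ckpt = torch.load(path, map_location="cpu")
    state = ckpt.get("model", ckpt)  # weights are often nested under "model"
    own = model.state_dict()
    state = {k: v for k, v in state.items() if k in own and v.shape == own[k].shape}
    model.load_state_dict(state, strict=False)
```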

6. Main Results and Model Weights

6.1 Pretrained weights

We provide models trained on CIFAR, EMNIST, FashionMNIST, and Caltech101 here. In particular, the Caltech101 models are trained with an input size of 256×256 and a patch size of 16.

| Name | #FLOPs (G) | #Params | Dataset | Acc@1 (%) | URL |
|------|------------|---------|---------|-----------|-----|
| DEMS-ViT-Ti | 1.6 | 5.6M | CIFAR10 | 96.03 | model |
| DEMS-ViT-Ti | 1.6 | 5.6M | CIFAR100 | 80.60 | model |
| DEMS-ViT-Ti | 1.6 | 5.6M | FashionMNIST | 95.59 | model |
| DEMS-ViT-Ti | 1.6 | 5.6M | EMNIST | 99.56 | model |
| DEMS-ViT-Ti | 1.6 | 5.6M | Caltech101 | 86.56 | model |
| DEMS-ViT-S | 5.8 | 22.3M | CIFAR10 | 96.20 | model |
| DEMS-ViT-S | 5.8 | 22.3M | CIFAR100 | 83.30 | model |
| DEMS-ViT-S | 5.8 | 22.3M | FashionMNIST | 95.99 | model |
| DEMS-ViT-S | 5.8 | 22.3M | EMNIST | 99.58 | model |
| DEMS-ViT-S | 5.8 | 22.3M | Caltech101 | 86.88 | model |

6.2 Fine-tuned weights

We provide fine-tuned models on CIFAR, which can be found here.

| Name | Dataset | Acc@1 (%) | URL |
|------|---------|-----------|-----|
| DEMS-ViT-Ti | CIFAR10 | 96.74 | model |
| DEMS-ViT-Ti | CIFAR100 | 83.50 | model |
| DEMS-ViT-S | CIFAR10 | 97.76 | model |
| DEMS-ViT-S | CIFAR100 | 85.16 | model |

7. License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.
