This is the official PyTorch implementation of our BMVC 2024 paper "PatchRot: Self-Supervised Training of Vision Transformers by Rotation Prediction".
PatchRot rotates images and image patches and trains the network to predict the rotation angles.
The network learns to extract global image and patch-level features through this process.
PatchRot pretraining yields stronger features than training from scratch and improves downstream classification accuracy (see the results table below).
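To illustrate the idea of per-patch rotation targets, here is a minimal, hypothetical sketch (not the authors' exact code): each non-overlapping patch is rotated by a random multiple of 90 degrees, and the rotation indices serve as classification labels. The function name `rotate_patches` and the patch size are assumptions for illustration.

```python
import torch

def rotate_patches(img, patch_size=4):
    """Rotate each non-overlapping patch of `img` (C, H, W) by a random
    multiple of 90 degrees; return the rotated image and per-patch labels.

    Labels: 0 -> 0 deg, 1 -> 90 deg, 2 -> 180 deg, 3 -> 270 deg.
    """
    c, h, w = img.shape
    ph, pw = h // patch_size, w // patch_size
    labels = torch.randint(0, 4, (ph * pw,))  # one rotation class per patch
    out = img.clone()
    for idx, k in enumerate(labels):
        i, j = divmod(idx, pw)
        ys, xs = i * patch_size, j * patch_size
        patch = img[:, ys:ys + patch_size, xs:xs + patch_size]
        out[:, ys:ys + patch_size, xs:xs + patch_size] = torch.rot90(patch, int(k), dims=(1, 2))
    return out, labels

# Example: a CIFAR-sized image gives an 8x8 grid of 4x4 patches (64 labels).
img = torch.randn(3, 32, 32)
rotated, labels = rotate_patches(img)
```

A network pretrained this way predicts the rotation class of every patch (and of the whole image), which forces it to learn both local and global structure.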
Run commands (also available in run_cifar10.sh):
Run `main_pretrain.py` to pretrain the network with PatchRot, then run `main_finetune.py --init patchrot` to finetune it. Use `main_finetune.py --init none` to train the network without any pretraining (i.e., from random initialization).
Below is an example on CIFAR10:
| Method | Run Command |
|---|---|
| PatchRot pretraining | `python main_pretrain.py --dataset cifar10` |
| Finetuning pretrained model | `python main_finetune.py --dataset cifar10 --init patchrot` |
| Training from random init | `python main_finetune.py --dataset cifar10 --init none` |
Replace `cifar10` with the desired dataset name.
Supported datasets: CIFAR10, CIFAR100, FashionMNIST, SVHN, TinyImageNet, Animals10N, and ImageNet100.
CIFAR10, CIFAR100, FashionMNIST, and SVHN are downloaded automatically to the path given by the `data_path` argument (default: `./data`).
TinyImageNet, Animals10N, and ImageNet100 must be downloaded manually; point the `data_path` argument to their location.
| Dataset | Without PatchRot Pretraining | With PatchRot Pretraining |
|---|---|---|
| CIFAR10 | 84.4 | 91.3 |
| CIFAR100 | 56.5 | 66.7 |
| FashionMNIST | 93.4 | 94.6 |
| SVHN | 92.9 | 96.4 |
| Animals10N | 69.6 | 79.5 |
| TinyImageNet | 38.4 | 48.8 |
| ImageNet100 | 64.6 | 75.4 |