We provide code for a simple weak-to-strong experiment on ImageNet. We generate the weak labels with an AlexNet model pretrained on ImageNet, and we use linear probes trained on top of frozen DINO embeddings as the strong student.
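The core idea can be illustrated on synthetic data. The sketch below is a hypothetical, dependency-free stand-in, not the repository's implementation. Gaussian clusters play the role of frozen DINO embeddings, a label-corrupting function plays the role of AlexNet's predictions, and the strong student is a linear probe (multinomial logistic regression, fit here by full-batch gradient descent) trained only on those weak labels. All function names, the 30% weak-label error rate, and the other hyperparameters are assumptions chosen for illustration.

```python
import math
import random


def make_embeddings(n, n_classes, rng, scale=4.0):
    """Hypothetical stand-in for frozen strong-model embeddings: class k is a
    unit-variance Gaussian blob centered at scale * e_k."""
    xs, ys = [], []
    for _ in range(n):
        y = rng.randrange(n_classes)
        xs.append([(scale if j == y else 0.0) + rng.gauss(0.0, 1.0) for j in range(n_classes)])
        ys.append(y)
    return xs, ys


def make_weak_labels(ys, n_classes, error_rate, rng):
    """Hypothetical stand-in for the weak model: keeps the true label, except
    that a fraction `error_rate` of points gets a uniformly random wrong label."""
    weak = []
    for y in ys:
        if rng.random() < error_rate:
            weak.append(rng.choice([k for k in range(n_classes) if k != y]))
        else:
            weak.append(y)
    return weak


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_of(W, b, x):
    return [sum(w_i * x_i for w_i, x_i in zip(w_k, x)) + b_k for w_k, b_k in zip(W, b)]


def train_linear_probe(xs, labels, n_classes, lr=0.5, n_epochs=50):
    """Linear probe: multinomial logistic regression trained by full-batch
    gradient descent on the cross-entropy against `labels` (here: weak labels)."""
    dim, n = len(xs[0]), len(xs)
    W = [[0.0] * dim for _ in range(n_classes)]
    b = [0.0] * n_classes
    for _ in range(n_epochs):
        gW = [[0.0] * dim for _ in range(n_classes)]
        gb = [0.0] * n_classes
        for x, y in zip(xs, labels):
            p = softmax(logits_of(W, b, x))
            for k in range(n_classes):
                g = p[k] - (1.0 if k == y else 0.0)  # dLoss/dlogit_k
                gb[k] += g
                for i in range(dim):
                    gW[k][i] += g * x[i]
        for k in range(n_classes):
            b[k] -= lr * gb[k] / n
            for i in range(dim):
                W[k][i] -= lr * gW[k][i] / n
    return W, b


def predict(W, b, x):
    scores = logits_of(W, b, x)
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(preds, ys):
    return sum(p == y for p, y in zip(preds, ys)) / len(ys)


def run_demo(seed=0, n=1000, n_train=800, n_classes=5, error_rate=0.3):
    rng = random.Random(seed)
    xs, ys = make_embeddings(n, n_classes, rng)
    weak = make_weak_labels(ys[:n_train], n_classes, error_rate, rng)
    # The probe never sees ground-truth labels, only the weak ones.
    W, b = train_linear_probe(xs[:n_train], weak, n_classes)
    preds = [predict(W, b, x) for x in xs[n_train:]]
    return accuracy(weak, ys[:n_train]), accuracy(preds, ys[n_train:])


if __name__ == "__main__":
    weak_acc, probe_acc = run_demo()
    print(f"weak-label accuracy: {weak_acc:.3f}, probe test accuracy: {probe_acc:.3f}")
```

Because the corrupted labels are still correct on most points of each cluster, the probe typically ends up more accurate on held-out data than the weak labels it was trained on, which is the weak-to-strong effect in miniature.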
The full training command:
```bash
python3 run_weak_strong.py \
    --data_path <DATA_PATH> \
    --weak_model_name <WEAK_MODEL> \
    --strong_model_name <STRONG_MODEL> \
    --batch_size <BATCH_SIZE> \
    --seed <SEED> \
    --n_epochs <N_EPOCHS> \
    --lr <LR> \
    --n_train <N_TRAIN>
```
Parameters:

- `DATA_PATH`: path to the base directory containing the ImageNet data (see the torchvision page for instructions); it should contain the files `ILSVRC2012_devkit_t12.tar.gz` and `ILSVRC2012_img_val.tar`
- `WEAK_MODEL`: weak model name; `"alexnet"` (default) is the only model currently implemented
- `STRONG_MODEL`: strong model name; `"resnet50_dino"` (default) or `"vitb8_dino"`
- `BATCH_SIZE`: batch size for weak label generation and embedding extraction (default: `128`)
- `SEED`: random seed for dataset shuffling (default: `0`)
- `N_EPOCHS`: number of training epochs (default: `10`)
- `LR`: initial learning rate (default: `1e-3`)
- `N_TRAIN`: number of datapoints used to train the linear probe; the remaining `50000 - N_TRAIN` datapoints of the 50,000-image ImageNet validation set are used for testing (default: `40000`)
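The role of `SEED` and `N_TRAIN` can be pictured as a seeded shuffle of the 50,000 validation images followed by a split. This is a hypothetical sketch using Python's standard `random` module: `split_indices` is not a function from the repository, and `run_weak_strong.py` may shuffle differently. It only illustrates that a fixed seed yields a reproducible, disjoint train/test partition.

```python
import random


def split_indices(n_total=50_000, n_train=40_000, seed=0):
    """Hypothetical helper: shuffle indices 0..n_total-1 with a seeded RNG and
    split them into n_train probe-training indices and n_total - n_train test indices."""
    if not 0 < n_train < n_total:
        raise ValueError("n_train must be strictly between 0 and n_total")
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation
    return indices[:n_train], indices[n_train:]


train_idx, test_idx = split_indices()
print(len(train_idx), len(test_idx))  # 40000 10000
```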
Example commands:

```bash
# AlexNet → ResNet50 (DINO):
python3 run_weak_strong.py --strong_model_name resnet50_dino --n_epochs 20
# AlexNet → ViT-B/8 (DINO):
python3 run_weak_strong.py --strong_model_name vitb8_dino --n_epochs 5
```
With the commands above, we obtain the following results (they may not reproduce exactly due to randomness):
| Model | Top-1 Accuracy (%) |
|---|---|
| AlexNet | 56.6 |
| DINO ResNet50 | 63.7 |
| DINO ViT-B/8 | 74.9 |
| AlexNet → DINO ResNet50 | 60.7 |
| AlexNet → DINO ViT-B/8 | 64.2 |
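
One way to summarize this table is the "performance gap recovered" (PGR) metric from the weak-to-strong generalization paper: the fraction of the gap between the weak model and the strong ceiling that the weak-to-strong student closes. The sketch below assumes that the plain DINO rows are the strong ceilings (probes trained on ground-truth labels); the README does not state this explicitly, and the helper name `performance_gap_recovered` is hypothetical.

```python
def performance_gap_recovered(weak, strong_ceiling, weak_to_strong):
    """PGR = (weak_to_strong - weak) / (strong_ceiling - weak)."""
    gap = strong_ceiling - weak
    if gap == 0:
        raise ValueError("strong ceiling equals weak accuracy; PGR is undefined")
    return (weak_to_strong - weak) / gap


alexnet = 56.6
# Assumption: the plain DINO rows are the strong ceilings.
results = {
    "DINO ResNet50": (63.7, 60.7),  # (strong ceiling, weak-to-strong)
    "DINO ViT-B/8": (74.9, 64.2),
}
for name, (ceiling, w2s) in results.items():
    print(f"AlexNet -> {name}: PGR = {performance_gap_recovered(alexnet, ceiling, w2s):.2f}")
# AlexNet -> DINO ResNet50: PGR = 0.58
# AlexNet -> DINO ViT-B/8: PGR = 0.42
```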
You can add new custom models to `models.py` and new datasets to `data.py`.