- Propose a Balanced BatchNorm module to alleviate the estimation bias caused by locally and globally class-imbalanced test data. The standard BN module suffers from estimation bias when the batch is small or when the batch data is class-biased.
- Propose a Tri-Net architecture to reduce the risk of over-adapting to local distributions and to provide a stable and robust TTA procedure.
- Combining these two modules, TRIBE can handle different kinds of test data streams, either an i.i.d. test stream or a locally and globally class-imbalanced one, and adapts the model to the target domain stably and safely.
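To illustrate the estimation bias described above, the sketch below compares a plain batch mean with a class-balanced mean on a single 1-D feature. It is a simplified stand-in, not the repository's Balanced BN. It assumes a (pseudo) label is available for each sample and weights every class present in the batch equally. It also omits variances and running-statistic updates. The helper names `batch_mean` and `class_balanced_mean` are hypothetical.

```python
from collections import defaultdict
from statistics import fmean


def batch_mean(features):
    """Plain BN-style estimate: every sample in the batch is weighted equally."""
    return fmean(features)


def class_balanced_mean(features, labels):
    """Mean of the per-class means, so each class present in the batch counts equally.

    `labels` stand in for (pseudo) labels; this is an illustrative re-weighting,
    not the repository's Balanced BN implementation.
    """
    per_class = defaultdict(list)
    for x, y in zip(features, labels):
        per_class[y].append(x)
    return fmean(fmean(xs) for xs in per_class.values())


# A class-imbalanced batch: nine class-0 samples (feature 0.0), one class-1 sample (10.0).
features = [0.0] * 9 + [10.0]
labels = [0] * 9 + [1]
print(batch_mean(features))                   # 1.0
print(class_balanced_mean(features, labels))  # 5.0
```

The batch contains nine class-0 samples and one class-1 sample. The plain mean (1.0) is pulled toward the majority class, while the balanced estimate (5.0) treats both classes equally.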
```bash
conda create -n tribe python=3.9.0
conda activate tribe
# install pip and dependencies for the fresh python
conda install -y ipython pip
# install required packages
pip install -r .
# install robustbench
cd robustbench
pip install .
cd -
# build Balanced BN (optional for faster inference speed)
cd cpp_wrapper/balanced_bn
pip install .
cd -
```
Note: if you cannot build the C++ version of Balanced BN (for example, because you lack root privileges), a Python implementation is available at core/utils/balanced_bn_pyv.py. Because it relies on parallel tensor operations such as torch.scatter_add_, the Python version also runs fast, so you do not need to worry if cpp_wrapper/balanced_bn fails to compile. Sorry for the delay.
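The Python version is fast because it accumulates per-class statistics for the whole batch in one indexed operation instead of looping over classes. The pure-Python function below only mimics the semantics of `torch.Tensor.scatter_add_` along dimension 0 for 1-D inputs (`target[index[i]] += src[i]`). The names `scatter_add_1d` and `class_sums_and_counts` are hypothetical, and the code does not reproduce core/utils/balanced_bn_pyv.py.

```python
def scatter_add_1d(target, index, src):
    """Mimic torch.Tensor.scatter_add_(0, index, src) on 1-D lists:
    target[index[i]] += src[i] for every i. Modifies `target` in place and returns it."""
    for i, v in zip(index, src):
        target[i] += v
    return target


def class_sums_and_counts(features, labels, num_classes):
    """Per-class feature sums and sample counts, each gathered in a single scatter pass."""
    sums = scatter_add_1d([0.0] * num_classes, labels, features)
    counts = scatter_add_1d([0] * num_classes, labels, [1] * len(labels))
    return sums, counts


features = [0.5, 1.5, 2.0, 4.0]
labels = [0, 0, 1, 2]
sums, counts = class_sums_and_counts(features, labels, num_classes=4)
print(sums)    # [2.0, 2.0, 4.0, 0.0]
print(counts)  # [2, 1, 1, 0]
```

In PyTorch, the same accumulation is a single vectorized call such as `sums.scatter_add_(0, labels, features)`, with `labels` as an int64 tensor. This is what avoids per-class Python loops.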
Download CIFAR-10-C, CIFAR-100-C and ImageNet-C. (Running the code directly also works, since it downloads the datasets automatically on the first run, but this is very slow and requires a stable internet connection.)
Symlink the datasets into datasets/:
```bash
ln -s path_to_cifar10_c datasets/CIFAR-10-C
ln -s path_to_cifar100_c datasets/CIFAR-100-C
ln -s path_to_imagenet_c datasets/ImageNet-C
```
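After creating the links, a quick check that each one resolves can prevent a failed run later. This is a hypothetical helper (`missing_datasets` is not part of the repo). It assumes it is run from the repository root, and it only uses the directory names from the commands above.

```python
import os

# Directory names taken from the symlink commands above.
EXPECTED = ["CIFAR-10-C", "CIFAR-100-C", "ImageNet-C"]


def missing_datasets(root="datasets", names=EXPECTED):
    """Return the entries under `root` that do not resolve to a directory.

    os.path.isdir follows symlinks, so a dangling link is reported as missing."""
    return [n for n in names if not os.path.isdir(os.path.join(root, n))]


if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("missing or broken:", ", ".join(missing))
    else:
        print("all dataset links resolve")
```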
Then run TRIBE under the GLI-TTA protocol on CIFAR-10-C, CIFAR-100-C and ImageNet-C, respectively:

```bash
python GLI_TTA.py \
    -acfg configs/adapter/TRIBE.yaml \
    -dcfg configs/dataset/cifar10.yaml \
    -pcfg configs/protocol/gli_tta.yaml \
    OUTPUT_DIR TRIBE/cifar10

python GLI_TTA.py \
    -acfg configs/adapter/TRIBE.yaml \
    -dcfg configs/dataset/cifar100.yaml \
    -pcfg configs/protocol/gli_tta.yaml \
    OUTPUT_DIR TRIBE/cifar100

python GLI_TTA.py \
    -acfg configs/adapter/TRIBE.yaml \
    -dcfg configs/dataset/imagenet.yaml \
    -pcfg configs/protocol/gli_tta.yaml \
    OUTPUT_DIR TRIBE/imagenet
```
Hint: the hyper-parameters can be changed in ./configs/adapter/TRIBE.yaml; please adjust them following the suggestions written in that file.
Apart from TRIBE, this repo also implements several mainstream TTA algorithms and TTA protocols, so their results can be reproduced simply by changing the run command. Algorithms include BN, PL, TENT, LAME, EATA, NOTE, TTAC (without queue), COTTA, PETAL and ROTTA. TTA protocols include Single Domain TTA, Continual TTA, Gradual Changing Continual TTA, PTTA (proposed in ROTTA) and GLI TTA (proposed in this paper).
For example, to run ROTTA under the Continual TTA protocol:
```bash
python GLI_TTA.py \
    -acfg configs/adapter/rotta.yaml \
    -dcfg configs/dataset/cifar10.yaml \
    -pcfg configs/protocol/continual_tta.yaml \
    OUTPUT_DIR ROTTA/cifar10
```
Or run TRIBE under the Gradual Changing Continual TTA protocol:
```bash
python GLI_TTA.py \
    -acfg configs/adapter/TRIBE.yaml \
    -dcfg configs/dataset/gradualCifar10.yaml \
    -pcfg configs/protocol/continual_tta.yaml \
    OUTPUT_DIR TRIBE/cifar10
```
Or, under the Single Domain TTA protocol, first set CORRUPTION.TYPE in ./configs/dataset/cifar10.yaml to one specific domain, then run:
```bash
python GLI_TTA.py \
    -acfg configs/adapter/TRIBE.yaml \
    -dcfg configs/dataset/cifar10.yaml \
    -pcfg configs/protocol/continual_tta.yaml \
    OUTPUT_DIR TRIBE/cifar10
```
Or under the PTTA protocol:
```bash
python GLI_TTA.py \
    -acfg configs/adapter/TRIBE.yaml \
    -dcfg configs/dataset/cifar10.yaml \
    -pcfg configs/protocol/ptta.yaml \
    OUTPUT_DIR TRIBE/cifar10
```
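All of the commands above follow the same pattern: an adapter config (-acfg), a dataset config (-dcfg), a protocol config (-pcfg) and an OUTPUT_DIR override. When sweeping several settings, the command can be assembled programmatically. The helper `gli_tta_command` and the output directory names below are hypothetical, and only config paths that appear in this README are used.

```python
import shlex


def gli_tta_command(adapter_cfg, dataset_cfg, protocol_cfg, output_dir):
    """Build the argv for GLI_TTA.py from the -acfg/-dcfg/-pcfg options and the
    trailing OUTPUT_DIR override used in the example commands."""
    return [
        "python", "GLI_TTA.py",
        "-acfg", adapter_cfg,
        "-dcfg", dataset_cfg,
        "-pcfg", protocol_cfg,
        "OUTPUT_DIR", output_dir,
    ]


# Hypothetical sweep: TRIBE on CIFAR-10 under two protocols, each with its own output dir.
runs = [
    ("configs/protocol/ptta.yaml", "TRIBE/cifar10_ptta"),
    ("configs/protocol/continual_tta.yaml", "TRIBE/cifar10_continual"),
]
for protocol_cfg, out_dir in runs:
    argv = gli_tta_command("configs/adapter/TRIBE.yaml",
                           "configs/dataset/cifar10.yaml", protocol_cfg, out_dir)
    # Print a shell-safe command line; pass `argv` to subprocess.run(argv, check=True)
    # instead to actually launch the run.
    print(shlex.join(argv))
```

Giving each run its own OUTPUT_DIR keeps results from different protocols apart, since the examples above reuse TRIBE/cifar10 for several of them.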
Beyond switching configuration files, finer settings can be adjusted inside each config file, for example the class-imbalance ratios of the GLI-TTA protocol:
```yaml
LOADER:
  SAMPLER:
    TYPE: "gli_tta"
    IMB_FACTOR: 10           # global imbalance factor: 10, 100, 200
    CLASS_RATIO: "constant"  # "constant" for GLI-TTA-F or "random" for GLI-TTA-V
    GAMMA: 0.1               # local imbalance factor: 10, 1.0, 0.1, 0.01, 0.001
```
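To show what the two factors control, the sketch below assumes two common constructions, which may differ from the repository's gli_tta sampler. The first is a global long-tail profile in which class c keeps n_max * (1/IMB_FACTOR)^(c/(C-1)) samples, the usual CIFAR-10-LT recipe. The second draws local class proportions from a symmetric Dirichlet distribution whose concentration plays the role of GAMMA. The function names are hypothetical.

```python
import random


def long_tail_counts(n_max, num_classes, imb_factor):
    """Per-class counts under an exponential long-tail profile (num_classes >= 2):
    n_c = n_max * (1 / imb_factor) ** (c / (num_classes - 1)),
    so class 0 keeps n_max samples and the last class keeps n_max / imb_factor."""
    return [int(n_max * (1.0 / imb_factor) ** (c / (num_classes - 1)))
            for c in range(num_classes)]


def dirichlet_proportions(gamma, num_classes, rng):
    """Class proportions for one local segment of the stream, drawn from a symmetric
    Dirichlet(gamma) by normalizing independent Gamma(gamma, 1) draws.
    Smaller gamma concentrates the mass on fewer classes."""
    draws = [rng.gammavariate(gamma, 1.0) for _ in range(num_classes)]
    total = sum(draws)
    if total == 0.0:
        # With very small gamma every draw can underflow to 0.0; the Dirichlet is then
        # effectively one-hot, so put all mass on a single random class.
        k = rng.randrange(num_classes)
        return [1.0 if i == k else 0.0 for i in range(num_classes)]
    return [d / total for d in draws]


counts = long_tail_counts(n_max=1000, num_classes=10, imb_factor=10)
print(counts[0], counts[-1])  # 1000 100
rng = random.Random(0)
print([round(p, 3) for p in dirichlet_proportions(0.1, 10, rng)])
```

Under these assumptions, a larger IMB_FACTOR makes the global tail steeper, and a smaller GAMMA makes each local segment concentrate on fewer classes.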
This project is based on the open-source project rotta. We thank its authors for making the source code publicly available.
If you find our work useful in your research, please consider citing:
```bibtex
@misc{su2023realworld,
  title={Towards Real-World Test-Time Adaptation: Tri-Net Self-Training with Balanced Normalization},
  author={Yongyi Su and Xun Xu and Kui Jia},
  year={2023},
  eprint={2309.14949},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```