This is the code for the Conditional Mutual Information-Debiasing (CMID) method proposed in the paper "Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness" by Bhavya Vasudeva, Kameron Shahabi, and Vatsal Sharan. (The base code comes from the group_DRO implementation.)
The code uses Python 3.6.8. Dependencies can be installed using:

`pip install -r requirements.txt`
Change the `root_dir` variable in `data/data.py`. Datasets will be stored in the location specified by `root_dir`. (Check this link for more details.)
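For illustration, the change amounts to a one-line assignment (a hypothetical excerpt; the path below is a placeholder, not a real location):

```python
# Hypothetical excerpt of data/data.py: point root_dir at the folder
# where the datasets should be stored (placeholder path).
root_dir = "/path/to/datasets"
```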
Experiments can be run on the Waterbirds, CelebA, MultiNLI, and CivilComments datasets. The data are set up as follows:
- Waterbirds: The code expects the following files/folders in the `[root_dir]/cub` directory:
  - `data/waterbird_complete95_forest2water2/`

  A tarball of this dataset can be downloaded from this link.
- CelebA: The code expects the following files/folders in the `[root_dir]/celebA` directory:
  - `data/list_eval_partition.csv`
  - `data/list_attr_celeba.csv`
  - `data/img_align_celeba/`

  These dataset files can be downloaded from this Kaggle link.
- MultiNLI: The code expects the following files/folders in the `[root_dir]/multinli` directory:
  - `data/metadata_random.csv`
  - `glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli`
  - `glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm`
  - `glue_data/MNLI/cached_train_bert-base-uncased_128_mnli`

  The metadata file is included in the `dataset_metadata/multinli` folder of this repository. The `glue_data/MNLI` files are generated by the Hugging Face Transformers library and can be downloaded here.
- CivilComments: The code expects the following files/folders in the `[root_dir]/civcom` directory:
  - `all_data_with_grouped_identities.csv`
  - `all_data_with_identities.csv`

  A tarball of this dataset can be downloaded from this link.
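Taken together, the four layouts above form the directory skeleton sketched below (a sketch only: `ROOT_DIR` stands in for your `root_dir` value, and the dataset files themselves still need to be downloaded as described above):

```shell
# Create the expected directory skeleton under root_dir.
ROOT_DIR=./datasets  # substitute your root_dir value
mkdir -p "$ROOT_DIR/cub/data" \
         "$ROOT_DIR/celebA/data/img_align_celeba" \
         "$ROOT_DIR/multinli/data" \
         "$ROOT_DIR/multinli/glue_data/MNLI" \
         "$ROOT_DIR/civcom"
# List what was created, for a quick sanity check.
find "$ROOT_DIR" -type d | sort
```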
The main files to run the experiments and parse the results are `run_expt.py` and `parse_log_file.py`, respectively. The specific commands are listed below:
- Waterbirds:

  `python run_expt.py --log_dir /CMID/log-wb -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.0005 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 100 --cmi_reg --log_every 20 --reg_st 20.0 --cmistinc --scale 4`

  `python parse_log_file.py --log_dir /CMID/log-wb --num_groups 4`
- CelebA:

  `python run_expt.py --log_dir /CMID/log-cel -s confounder -d CelebA -t Blond_Hair -c Male --lr 0.0003 --batch_size 128 --weight_decay 0.001 --model resnet50 --n_epochs 50 --cmi_reg --log_every 20 --reg_st 10.0 --cmistinc --scale 5`

  `python parse_log_file.py --log_dir /CMID/log-cel --num_groups 4`
- MultiNLI:

  `python run_expt.py --log_dir /CMID/log-mnli -s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --lr 5e-05 --batch_size 32 --weight_decay 0 --model bert --n_epochs 5 --cmi_reg --reg_st 75.0 --cmistinc --lr1 0.005`

  `python parse_log_file.py --log_dir /CMID/log-mnli --num_groups 6`
- CivilComments:

  `python run_expt.py --log_dir /CMID/log-ccom -s confounder -d CivComMod -t toxicity -c identity_any --lr 0.00001 --batch_size 32 --weight_decay 0.001 --model bert-base-uncased --n_epochs 10 --cmi_reg --reg_st 25.0 --cmistinc --lr1 0.0001`

  `python parse_log_file.py --log_dir /CMID/log-ccom --num_groups 16`
For Camelyon17, the code expects the following files/folders in the `./camelyon` directory:
- `data/camelyon17_v1.0/metadata.csv`
- `data/camelyon17_v1.0/patches/` (including all the patch data)

If these files do not exist, the code will download them at runtime.
We use a separate script for Camelyon17 in order to use the WILDS data loading. To run it, go into the `./camelyon` directory and run the following sample command, which writes the results to `camelyon.txt` in the same directory:

`python camelyon.py --cmi_reg --epochs 5 --epochs2 10 --lr 0.0001 --lr1 0.0001 --weight_decay 0.01 --reg_st 0.5 --batch_size 32 &> camelyon.txt`
If you find our research useful, please cite our work:
@misc{vasudeva2023mitigating,
title={Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness},
author={Bhavya Vasudeva and Kameron Shahabi and Vatsal Sharan},
year={2023},
eprint={2310.06161},
archivePrefix={arXiv},
primaryClass={cs.LG}
}