The official code repository for "COAT: Measuring Object Compositionality in Emergent Representations"
Sirui Xie,
Ari Morcos,
Song-Chun Zhu,
Ramakrishna Vedantam
Presented at ICML 2022.
- Python >= 3.8
- PyTorch >= 1.7.1
- PyTorch Lightning == 1.1.4
- hydra-core == 1.2.0
- tqdm
- CUDA-enabled computing device
This repository contains:
- The generation code for the COAT testing corpus, adapted from the CLEVR generation code
- The generation code for Correlated CLEVR with colorful backgrounds, adapted from the CLEVR generation code
- PyTorch implementations of Slot Attention and beta-TC-VAE, adapted from the repositories of Untitled-AI and AntixK respectively. The modifications to Slot Attention mainly concern post-processing: deduplicating slots (controlled by dup_threshold) and removing invisible slots, i.e. slots with close-to-zero mask weight (controlled by rm_invisible).
- The method validation_epoch_end in method.py, which applies the COAT metric to slot-based and slot-free representations.
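The slot post-processing described above can be illustrated with a minimal, dependency-free sketch. It is not the repository's implementation: it assumes that two slots count as duplicates when their cosine similarity exceeds dup_threshold, and that a slot is invisible when its total mask weight falls below a small eps. The function names, the eps parameter, and the default values are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two vectors; 0.0 if either has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def postprocess_slots(
    slots: Sequence[Sequence[float]],
    mask_weights: Sequence[float],
    dup_threshold: float = 0.95,  # assumed meaning: similarity above this is a duplicate
    rm_invisible: bool = True,
    eps: float = 1e-3,            # hypothetical cutoff for "close-to-zero" mask weight
) -> List[int]:
    """Return the indices of the slots that survive post-processing."""
    kept: List[int] = []
    for i, slot in enumerate(slots):
        # Drop slots whose mask contributes almost nothing to the image.
        if rm_invisible and mask_weights[i] < eps:
            continue
        # Drop slots that nearly coincide with a slot that was already kept.
        if any(cosine_similarity(slot, slots[j]) > dup_threshold for j in kept):
            continue
        kept.append(i)
    return kept
```

With four slots where the second nearly copies the first and the fourth has zero mask weight, only the first and third survive. The similarity criterion actually used in the repository may differ, so check the Slot Attention code before relying on this behaviour.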
The generated training and test data are available here. You will probably need to change the following data paths in the configuration files in ./hydra_cfg/:
data_mix_idx: 1 # the index of data mixture, check data_mix.csv for details
data_mix_csv: /your_data_root/data_mix.csv # the file for different composition of the training set
data_root: /your_data_root/clevr_corr/ # training data for both iid and correlated CLEVR with colorful background
val_root: /your_data_root/clevr_with_masks/ # evaluation data from original CLEVR for mask ARI metric
test_root: /your_data_root/coat_test/ # testing data for our COAT metric
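Before launching a run, it can help to confirm that the configured paths actually exist. The following is a small sketch under stated assumptions: the configuration is represented as a plain dictionary rather than a Hydra config object, and the helper find_missing_paths and the constant PATH_KEYS are hypothetical names that are not part of this repository.

```python
import os
from typing import Dict, List

# The path-valued keys listed in the configuration above.
PATH_KEYS = ("data_mix_csv", "data_root", "val_root", "test_root")


def find_missing_paths(cfg: Dict[str, object]) -> List[str]:
    """Return the path keys that are unset or point to a non-existent location."""
    missing: List[str] = []
    for key in PATH_KEYS:
        path = cfg.get(key)
        if not isinstance(path, str) or not os.path.exists(path):
            missing.append(key)
    return missing


if __name__ == "__main__":
    example_cfg = {
        "data_mix_csv": "/your_data_root/data_mix.csv",
        "data_root": "/your_data_root/clevr_corr/",
        "val_root": "/your_data_root/clevr_with_masks/",
        "test_root": "/your_data_root/coat_test/",
    }
    # Prints the keys that still need to be pointed at real data.
    print(find_missing_paths(example_cfg))
```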
In the training data we provide, data_mix.csv is a meta file describing the different compositions of the training set with different correlations. Set data_mix_idx to 1, 2, 3, 4, or 5 for the i.i.d. dataset; set data_mix_idx to 13, 14, 15, 16, or 17 for the correlated dataset in our paper.
To generate the test or the training data, check ./coat_generation/.
Our COAT measure can be extended to domains other than CLEVR: the Dataset class CLEVRAlgebraTestset in data.py is reusable. It takes a list of tuples of images as input, test_cases: List[List[Optional[str]]]. In train_hydra.py, such a list is loaded from /test_root/obj_test_final/CLEVR_test_cases_hard.csv, which is included in our released data, with the following code:
# Imports needed at the top of the file
import os
from csv import reader

# Load the released COAT test cases if the CSV file is present
if os.path.exists(os.path.join(cfg.test_root, "obj_test_final", "CLEVR_test_cases_hard.csv")):
    with open(os.path.join(cfg.test_root, "obj_test_final", "CLEVR_test_cases_hard.csv"), "r") as f:
        csv_reader = reader(f)
        self.obj_algebra_test_cases = list(csv_reader)
else:
    self.obj_algebra_test_cases = None
    print(os.path.join(cfg.test_root, "obj_test_final", "CLEVR_test_cases_hard.csv")+" does not exist.")
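For a new domain, a test-case list of the same List[List[Optional[str]]] shape could be built from your own CSV file. The sketch below is only an illustration: the helper parse_test_cases and the file names are hypothetical, and it assumes that an empty cell stands for a missing image (None), which the released CSV format may or may not do.

```python
import csv
import io
from typing import List, Optional


def parse_test_cases(csv_text: str) -> List[List[Optional[str]]]:
    """Parse CSV text into test cases, one row per case.

    Assumption: an empty cell means "no image" and is mapped to None.
    """
    rows = csv.reader(io.StringIO(csv_text))
    # Skip blank lines; strip whitespace around each image path.
    return [[cell.strip() or None for cell in row] for row in rows if row]


if __name__ == "__main__":
    # Hypothetical image names, each row listing the images of one test case.
    example = "a.png,b.png,c.png,d.png\ne.png,,g.png,h.png\n"
    for case in parse_test_cases(example):
        print(case)
```

The resulting list can then be passed as test_cases, provided the image paths and ordering follow whatever convention CLEVRAlgebraTestset expects in data.py.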
Configuration files for models and training can be found in ./hydra_cfg/, and should be linked to hydra_train.py with
@hydra.main(config_path='hydra_cfg', config_name='cfg_file')
""" 'cfg_file' can be either of
- 'bvae' for beta-tc-vae
- 'slot-attn' for original slot attention model
- 'slot-attn-no-dup' for slot attention model with duplicated slots removed
- 'slot-attn-no-dup-no-inv' for slot attention model with duplicated slots and invisible slots removed.
Details of the different post-processing steps applied to the representations can be found in the paper.
"""
To train models from scratch with an epoch-wise COAT test, run
python hydra_train.py
We use wandb for logging. Logs should contain COAT metrics and test visualizations.
The COAT metrics include the COAT-l2 and COAT-acos scores, which are normalized and corrected for chance, as well as the empirical probability P(Loss(A, B, C, D) < Loss(A, B, C, D')), where D' is the hard negative. Here are some example training curves.
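The empirical probability described above can be sketched in a few lines. This is a simplified illustration, not the code in method.py: it assumes the loss is the l2 norm of the residual z_A - z_B + z_C - z_D over flat representation vectors, and it omits the normalization and chance correction applied to the COAT-l2 and COAT-acos scores. The function names are hypothetical; refer to the paper for the actual definitions.

```python
import math
from typing import Sequence, Tuple

Vector = Sequence[float]
# One test case: representations of A, B, C, the target D, and the hard negative D'.
Case = Tuple[Vector, Vector, Vector, Vector, Vector]


def algebra_loss_l2(a: Vector, b: Vector, c: Vector, d: Vector) -> float:
    """Assumed loss: l2 norm of the residual a - b + c - d."""
    return math.sqrt(sum((ai - bi + ci - di) ** 2 for ai, bi, ci, di in zip(a, b, c, d)))


def hard_negative_accuracy(cases: Sequence[Case]) -> float:
    """Fraction of cases where the target D has a lower loss than the hard negative D'."""
    if not cases:
        raise ValueError("at least one test case is required")
    wins = sum(
        1
        for a, b, c, d, d_neg in cases
        if algebra_loss_l2(a, b, c, d) < algebra_loss_l2(a, b, c, d_neg)
    )
    return wins / len(cases)
```

A value near 1.0 means the representation almost always prefers the correct target over the hard negative under this assumed loss.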
The visualization shows how well the models reconstruct the images, as well as how well the slots are matched for Slot Attention.
Here are some examples.
@inproceedings{xie2022coat,
title={COAT: Measuring Object Compositionality in Emergent Representations},
author={Xie, Sirui and Morcos, Ari S and Zhu, Song-Chun and Vedantam, Ramakrishna},
booktitle={International Conference on Machine Learning},
pages={24388--24413},
year={2022},
organization={PMLR}
}