This repository contains the code for DABS, a benchmark for domain-agnostic self-supervised learning algorithms. The basic components of the benchmark can be found in datasets, encoders, and algorithms. Training is implemented with the PyTorch Lightning framework, logging with Weights and Biases, and configuration management with Hydra.
March 2023
- Updated with DABS 2.0 additions:
  - 5 new domains: Proteins, Semiconductors, Particle Physics, Multispectral Satellite, Bacterial Genomics
  - 2 new algorithms: Masked Autoencoding (`mae`) and Capri / Contrastive Prediction (`contpred`)
  - Variable corruption rates via the `--corruption_rate` flag (e.g. the fraction of input tokens/patches that are masked out for MAE, or shuffled for ShED)
Jan 2023
- Fixed a bug in the Captioned Images domain (MSCOCO and VQA). Please rerun pretraining and transfer for this domain if you've used it previously. Numbers for this domain have been updated in the repo and paper.
- Incorporated a patch for CVE-2007-4559, a vulnerability in Python's tarfile package.
We provide support for Python >= 3.8. Install requirements with

```shell
python -m pip install -r requirements.txt
```

For instructions on installing PyTorch versions compatible with your CUDA version, see pytorch.org. We support Torch 1.6.0, but later versions may work as well.
We provide a set of dataset implementations (in `src/datasets`) from the image, text, speech, sensor, medical imaging, and image-text domains. Preprocessing on these datasets is minimal and hard-coded: simple resizing (e.g. of images) and truncation (e.g. of text and audio). These operations should not be changed, in order to maintain fair comparisons with other users of the benchmark.
See `conf/datasets/*.yaml` for all dataset configs, including the loss, metrics, and batch size used for each dataset.
Almost all datasets will download automatically when the dataset class is instantiated. The exceptions are the CheXpert, ImageNet, and CU Birds datasets, where manual registration or download is required. See the respective dataset files for specific instructions.
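The download-on-instantiation pattern these classes follow can be sketched generically (the class name, URL, and arguments below are illustrative only; the real interfaces live in `src/datasets`):

```python
import os
import urllib.request

class AutoDownloadDataset:
    """Illustrative sketch: fetch the data archive the first time the
    dataset class is instantiated, then reuse the local copy."""
    URL = "https://example.com/data.bin"  # placeholder URL, not a real mirror

    def __init__(self, root: str, download: bool = True):
        self.path = os.path.join(root, "data.bin")
        if download and not os.path.exists(self.path):
            os.makedirs(root, exist_ok=True)
            urllib.request.urlretrieve(self.URL, self.path)
```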
Pretraining Dataset (unlabeled) | Transfer Dataset (labeled) |
---|---|
ImageNet | Aircraft, CIFAR10, CU Birds, DTD, Traffic Sign, VGG Flower |
PAMAP2 | PAMAP2 |
MSCOCO | MSCOCO (mismatched detection), VQA (Binary classification) |
Wikitext-103 | GLUE (10 Tasks) |
mC4 | PAWS-X (7 Tasks) |
CheXpert | CheXpert (atelectasis, cardiomegaly, consolidation, edema, and pleural effusion), ChestX-ray8 (atelectasis, cardiomegaly, effusion, infiltration, mass, nodule, pneumonia, pneumothorax) |
LibriSpeech | Audio MNIST, Fluent Speech (Action, Object, Location), Google Speech Commands, LibriSpeech, VoxCeleb1 |
During the pretraining phase, self-supervised encoders are trained to learn good representations from unlabeled data. We currently support seven datasets for pretraining, one for each domain: MS COCO, ImageNet, CheXpert, PAMAP2, mC4, WikiText-103, and LibriSpeech. If the pretraining dataset has associated labels, an online linear evaluator is jointly trained with the encoder to provide a heuristic of transfer performance.
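The online linear evaluator works roughly like this (a minimal sketch with illustrative names, not the repo's implementation): a linear head is trained on detached encoder features, so the probe's gradients never reach the encoder, and its accuracy serves as a cheap running estimate of representation quality.

```python
import torch
import torch.nn as nn

class OnlineLinearProbe(nn.Module):
    """Linear classifier trained jointly with the encoder, but on
    detached features, so it cannot influence the learned representation."""

    def __init__(self, feat_dim: int, num_classes: int):
        super().__init__()
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # detach() stops gradients from the probe flowing into the encoder.
        logits = self.head(features.detach())
        return nn.functional.cross_entropy(logits, labels)
```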
Run pretraining with commands like

```shell
python pretrain.py exp.name=<experiment-name> dataset=<dataset> algorithm=<algorithm>
```
Each dataset and encoder has its own config file, so to train a Transformer on the CheXpert dataset with the e-Mix algorithm, run

```shell
python pretrain.py exp.name=emix-chexpert encoder=transformer dataset=chexpert algorithm=emix
```
See `conf/pretrain.yaml` for all pretraining configuration fields.
For more information on the datasets, encoders, and algorithms, see the following section.
Pretraining Dataset | Modality | Label type (unused) | Input Type |
---|---|---|---|
CIFAR10 | Natural images | Single label | 2d |
PAMAP2 | Sensor | Single label |