DABS: A Domain-Agnostic Benchmark for Self-Supervised Learning

This repository contains the code for DABS, a benchmark for domain-agnostic self-supervised learning algorithms. The basic components of the benchmark can be found in datasets, encoders, and algorithms. Training is implemented with the PyTorch Lightning framework, logging with Weights and Biases, and configuration management with Hydra.
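As a rough orientation only (this is not the repository's actual entry point, and the project name below is a placeholder), the stack fits together like this: a PyTorch Lightning trainer runs the encoder and algorithm, a Weights & Biases logger records metrics, and Hydra supplies the configuration:

    import pytorch_lightning as pl
    from pytorch_lightning.loggers import WandbLogger

    # Illustrative wiring only; the real entry point is pretrain.py, driven by Hydra configs.
    logger = WandbLogger(project="dabs")            # log metrics to Weights & Biases
    trainer = pl.Trainer(logger=logger, max_steps=100_000)
    # trainer.fit(model, datamodule)                # model/datamodule are built from the configs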

Updates

March 2023

  • Updated with DABS 2.0 additions:
      • 5 new domains: Proteins, Semiconductors, Particle Physics, Multispectral Satellite, Bacterial Genomics
      • 2 new algorithms: Masked Autoencoding (mae) and Capri / Contrastive Prediction (contpred)
      • Variable corruption rates via the --corruption_rate flag, i.e. the fraction of input tokens/patches that are masked out for MAE or shuffled for ShED (see the sketch below)
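How the corruption rate is applied lives in each algorithm's code; as a hedged illustration only (the helper below is hypothetical, not part of the repo), masking a fraction of token/patch embeddings for an MAE-style objective might look like:

    import torch

    def mask_tokens(embeddings: torch.Tensor, corruption_rate: float) -> torch.Tensor:
        """Zero out a random fraction of token/patch embeddings (hypothetical helper).

        embeddings: (batch, num_tokens, dim) tokenized inputs.
        corruption_rate: fraction of tokens to corrupt, e.g. 0.5.
        """
        batch, num_tokens, _ = embeddings.shape
        # Boolean mask marking the tokens to corrupt.
        mask = torch.rand(batch, num_tokens, device=embeddings.device) < corruption_rate
        corrupted = embeddings.clone()
        corrupted[mask] = 0.0  # MAE-style: masked positions carry no content into the encoder
        return corrupted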

Jan 2023

  • Fixed a bug in the Captioned Images domain (MSCOCO and VQA). Please rerun pretraining and transfer for this domain if you've used it previously. Numbers for this domain have been updated in the repo and paper.
  • Incorporated a patch for CVE-2007-4559, a vulnerability in Python's tarfile package.

Usage

We provide support for Python >= 3.8. Install requirements with

python -m pip install -r requirements.txt

For instructions on installing a PyTorch build compatible with your CUDA version, see pytorch.org. We support PyTorch 1.6.0, but later versions may work as well.
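A quick way to confirm the installed build and GPU visibility (plain PyTorch, nothing repo-specific):

    import torch

    # Print the installed PyTorch version and whether a CUDA device is visible.
    print(torch.__version__)
    print(torch.cuda.is_available())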

Datasets

We provide a set of dataset implementations (in src/datasets) from the image, text, speech, sensor, medical imaging, and image-text domains. Preprocessing on these datasets is minimal and hard-coded: simple resizing (e.g. of images) and truncation (e.g. of text and audio). These operations should not be changed, so that comparisons remain fair across users of the benchmark.

See conf/datasets/*.yaml for all dataset configs, including the loss, metrics, and batch size used for each dataset.
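To inspect one of these configs programmatically, something like the following should work (the filename is only an example; browse conf/datasets/ for the full list):

    from omegaconf import OmegaConf

    # Load a single dataset config to see its loss, metrics, and batch size.
    cfg = OmegaConf.load("conf/datasets/chexpert.yaml")  # example filename
    print(OmegaConf.to_yaml(cfg))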

Almost all datasets will download automatically when the dataset class is instantiated. The exceptions are the CheXpert, ImageNet, and CU Birds datasets, where manual registration or download is required. See the respective dataset files for specific instructions.

Pretraining Dataset (unlabeled) | Transfer Dataset (labeled)
ImageNet | Aircraft, CIFAR10, CU Birds, DTD, Traffic Sign, VGG Flower
PAMAP2 | PAMAP2
MSCOCO | MSCOCO (mismatched detection), VQA (Binary classification)
Wikitext-103 | GLUE (10 Tasks)
mC4 | PAWS-X (7 Tasks)
CheXpert | CheXpert (atelectasis, cardiomegaly, consolidation, edema, and pleural effusion), ChestX-ray8 (atelectasis, cardiomegaly, effusion, infiltration, mass, nodule, pneumonia, pneumothorax)
LibriSpeech | Audio MNIST, Fluent Speech (Action, Object, Location), Google Speech Commands, LibriSpeech, VoxCeleb1

Pretraining

During the pretraining phase, self-supervised encoders are trained to learn good representations from unlabeled data. We currently support seven datasets for pretraining, one for each domain: MSCOCO, ImageNet, CheXpert, PAMAP2, mC4, WikiText-103, and LibriSpeech. If the pretraining dataset has associated labels, an online linear evaluator is trained jointly with the encoder to provide a rough proxy for transfer performance.
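Conceptually, the online evaluator is just a linear classifier trained on detached encoder features alongside the self-supervised loss; a minimal sketch (not the repository's implementation; the class name and shapes are assumed for illustration):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class OnlineLinearProbe(nn.Module):
        """Illustrative linear evaluator trained alongside the encoder."""

        def __init__(self, feature_dim: int, num_classes: int):
            super().__init__()
            self.linear = nn.Linear(feature_dim, num_classes)

        def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
            # Detach so the probe's gradients never flow back into the encoder.
            logits = self.linear(features.detach())
            return F.cross_entropy(logits, labels)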

Run pretraining with commands like

python pretrain.py exp.name=<experiment-name> dataset=<dataset> algorithm=<algorithm>

Each dataset and encoder has its own config file, so to train a Transformer on the CheXpert dataset with the e-Mix algorithm, run

python pretrain.py exp.name=emix-chexpert encoder=transformer dataset=chexpert algorithm=emix

See conf/pretrain.yaml for all pretraining configuration fields.
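To see the fully composed configuration for a run without launching training, Hydra's compose API can be used as a quick check (a sketch; the exact initialize signature depends on your Hydra version):

    from hydra import compose, initialize
    from omegaconf import OmegaConf

    # Compose the pretraining config with the same overrides used on the command line.
    with initialize(config_path="conf"):
        cfg = compose(
            config_name="pretrain",
            overrides=["dataset=chexpert", "encoder=transformer", "algorithm=emix"],
        )
    print(OmegaConf.to_yaml(cfg))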

For more information on the datasets, encoders, and algorithms, see the following section.

Pretraining Dataset | Modality | Label type (unused) | Input Type
CIFAR10 | Natural images | Single label | 2d
PAMAP2 | Sensor | Single label