Co-authored-by: Yuge Shi <[email protected]> Co-authored-by: Brooks Paige <[email protected]>
Showing 35 changed files with 121,217 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
((lua-mode . ((lua-indent-level . 2))) | ||
(python-mode . ((tab-width . 2)))) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# https://github.com/pytorch/pytorch/blob/d0db624e02951c4dd6eb6b21d051f7ccf8133707/setup.cfg | ||
[flake8] | ||
max-line-length = 120 | ||
ignore = E302,E305,E402,E721,E731,F401,F403,F405,F811,F812,F821,F841,W503 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
.* | ||
**/*~ | ||
**/_* | ||
**/auto | ||
**/*.aux | ||
**/*.bbl | ||
**/*.blg | ||
**/*.log | ||
**/*.out | ||
**/*.old | ||
**/*.run.xml | ||
**/images/ | ||
*.pyc | ||
**/__pycache__ | ||
!__init__.py | ||
!_imgs | ||
|
||
data/ | ||
experiments/**/ | ||
/.bash_history | ||
|
||
bin/*.sh | ||
bin/*.png | ||
bin/face_extract_vgg/ | ||
|
||
doc/ | ||
|
||
|
||
|
||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
<img src="_imgs/schematic.png" width="200" height="200" align="right"> | ||
|
||
# Multimodal Mixture-of-Experts VAE | ||
This repository contains the code for the framework in **Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models** (see [paper](https://arxiv.org/pdf/1911.03393.pdf)). | ||
|
||
## Requirements | ||
The packages we used, with the versions we tested the model on (see also `requirements.txt`): | ||
|
||
``` | ||
python == 3.6.8 | ||
gensim == 3.8.1 | ||
matplotlib == 3.1.1 | ||
nltk == 3.4.5 | ||
numpy == 1.16.4 | ||
pandas == 0.25.3 | ||
scipy == 1.3.2 | ||
seaborn == 0.9.0 | ||
scikit-image == 0.15.0 | ||
torch == 1.3.1 | ||
torchnet == 0.0.4 | ||
torchvision == 0.4.2 | ||
umap-learn == 0.1.1 | ||
``` | ||
|
||
## Downloads | ||
### MNIST-SVHN Dataset | ||
|
||
<p><img src="_imgs/mnist-svhn.png" width=150 align="right"></p> | ||
|
||
We construct a dataset of pairs of MNIST and SVHN such that each pair depicts the same digit class. Each instance of a digit class in either dataset is randomly paired with 20 instances of the same digit class from the other dataset. | ||
|
||
**Usage**: To prepare this dataset, run `bin/make-mnist-svhn-idx.py` -- this should automatically handle the download and pairing. | ||
|
||
### CUB Image-Caption | ||
|
||
<p><img src="_imgs/cub.png" width=200 align="right"></p> | ||
|
||
We use the Caltech-UCSD Birds (CUB) dataset, with the bird images and their captions serving as the two modalities. | ||
|
||
**Usage**: We offer a cleaned-up version of the CUB dataset. Download the dataset [here](https://www.robots.ox.ac.uk/~yshi/mmdgm/datasets/cub.zip). First, create a `data` folder under the project directory; then unzip the downloaded content into `data`. After finishing these steps, the structure of the `data/cub` folder should look like: | ||
|
||
``` | ||
data/cub | ||
│───text_testclasses.txt | ||
│───text_trainvalclasses.txt | ||
│───train | ||
│   │───002.Laysan_Albatross | ||
│   │   └───...jpg | ||
│   │───003.Sooty_Albatross | ||
│   │   └───...jpg | ||
│   │───... | ||
│   └───200.Common_Yellowthroat | ||
│       └───...jpg | ||
└───test | ||
    │───001.Black_footed_Albatross | ||
    │   └───...jpg | ||
    │───004.Groove_billed_Ani | ||
    │   └───...jpg | ||
    │───... | ||
    └───197.Marsh_Wren | ||
        └───...jpg | ||
``` | ||
|
||
|
||
### Pretrained network | ||
Pretrained models are also available if you want to play around with them. Download them from the following links: | ||
- [MNIST-SVHN](https://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/mnist-svhn.zip) | ||
- [CUB Image-Caption (feature)](https://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/cubISft.zip) | ||
- [CUB Image-Caption (raw images)](https://www.robots.ox.ac.uk/~yshi/mmdgm/pretrained_models/cubIS.zip) | ||
|
||
## Usage | ||
|
||
### Training | ||
|
||
Make sure the [requirements](#requirements) are satisfied in your environment, and relevant [datasets](#downloads) are downloaded. `cd` into `src`, and, for MNIST-SVHN experiments, run | ||
|
||
```bash | ||
python main.py --model mnist_svhn | ||
|
||
``` | ||
|
||
For CUB Image-Caption with image feature search (see Figure 7 in our [paper](https://arxiv.org/pdf/1911.03393.pdf)), run | ||
```bash | ||
python main.py --model cubISft | ||
|
||
``` | ||
|
||
For CUB Image-Caption with raw image generation, run | ||
```bash | ||
python main.py --model cubIS | ||
|
||
``` | ||
|
||
You can also adjust the hyperparameters through command-line arguments. Some of the more interesting ones are: | ||
- **`--obj`**: Objective function. There are three choices: importance-sampled ELBO (`elbo`), IWAE (`iwae`) and DReG (`dreg`, used in the paper). Adding the `--looser` flag when using IWAE or DReG removes the unbalanced weighting of modalities, which we find performs better empirically; | ||
- **`--K`**: Number of particles `K` used by the IWAE/DReG estimator, as specified in the following equation: | ||
|
||
<p align='center'><img src="_imgs/obj.png"></p> | ||
|
||
- **`--learn-prior`**: Prior variance learning. Passing this flag enables learning of the prior variance; omitting it disables the option. The results in our paper were produced with it enabled; | ||
- **`--llik_scaling`**: Likelihood scaling. Specifies the likelihood scaling of one of the two modalities, so that the likelihoods of the two modalities contribute similarly to the lower bound. The default values are: | ||
  - _MNIST-SVHN_: MNIST scaling factor (32*32*3)/(28*28*1) = 3.92 | ||
  - _CUB Image-Caption_: Image scaling factor 32/(64*64*3) = 0.0026 | ||
- **`--latent-dimension`**: Latent dimension | ||
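
The default `--llik_scaling` values listed above are ratios of data dimensionalities, and the arithmetic can be checked with a few lines of Python. This is only a sketch: the constant names are illustrative and not taken from the repository, and reading the `32` in the CUB ratio as the caption-side dimensionality is an assumption.

```python
# Number of scalar entries per datapoint, read off the two formulas above.
MNIST_DIMS = 28 * 28 * 1       # 28x28 greyscale image
SVHN_DIMS = 32 * 32 * 3        # 32x32 RGB image
CUB_IMAGE_DIMS = 64 * 64 * 3   # 64x64 RGB image
CUB_CAPTION_DIMS = 32          # assumption: caption-side dimensionality

# MNIST has fewer dimensions than SVHN, so its likelihood is scaled up.
mnist_scaling = SVHN_DIMS / MNIST_DIMS

# CUB images have far more dimensions than captions, so they are scaled down.
cub_image_scaling = CUB_CAPTION_DIMS / CUB_IMAGE_DIMS

print(round(mnist_scaling, 2))      # 3.92
print(round(cub_image_scaling, 4))  # 0.0026
```

In both cases the factor is applied to one modality so that the two log-likelihood terms end up on a comparable scale.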
|
||
You can also load a pre-trained model by specifying the path to the model folder, for example `python main.py --model mnist_svhn --pre-trained path/to/model/folder/`. | ||
|
||
### Analysing | ||
We offer tools to reproduce the quantitative results in our paper in `src/report`. To run any of the provided scripts, `cd` into `src`, and | ||
|
||
- for likelihood estimation of data using a trained model, run `python calculate_likelihoods.py --save-dir path/to/trained/model/folder/ --iwae-samples 1000`; | ||
- for coherence analysis and latent digit classification accuracy on the MNIST-SVHN dataset, run `python analyse_ms.py --save-dir path/to/trained/model/folder/`; | ||
- for coherence analysis on the CUB image-caption dataset, run `python analyse_cub.py --save-dir path/to/trained/model/folder/`. | ||
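
The `--iwae-samples` option in the first item suggests an importance-weighted likelihood estimate. As a generic illustration of that kind of estimator (not the code in `calculate_likelihoods.py`, whose internals are not shown here), the K-sample estimate is the log of the mean importance weight, computed stably from log-weights. The function name `log_mean_exp` and the toy inputs are assumptions made for this sketch.

```python
import math

def log_mean_exp(log_weights):
    """Return log((1/K) * sum_k exp(log_w_k)) without overflow or underflow.

    Each log_w_k is assumed to be log p(x, z_k) - log q(z_k | x) for a
    sample z_k drawn from the approximate posterior q.
    """
    m = max(log_weights)  # subtract the maximum before exponentiating
    total = sum(math.exp(lw - m) for lw in log_weights)
    return m + math.log(total) - math.log(len(log_weights))

# Toy log-weights (hypothetical values): a naive exp() would underflow to 0.
estimate = log_mean_exp([-1000.0, -1001.0, -1002.0])
print(estimate)
```

Using more samples generally gives a tighter estimate of the marginal log-likelihood, which is presumably why the command above uses 1000 of them.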
|
||
|
||
## Contact | ||
If you have any questions, feel free to create an issue or email Yuge Shi at [email protected]. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import torch | ||
from torchvision import datasets, transforms | ||
|
||
def rand_match_on_idx(l1, idx1, l2, idx2, max_d=10000, dm=10): | ||
    """ | ||
    l*: sorted labels | ||
    idx*: indices of sorted labels in original list | ||
    """ | ||
    _idx1, _idx2 = [], [] | ||
    for l in l1.unique():  # assuming both have same idxs | ||
        l_idx1, l_idx2 = idx1[l1 == l], idx2[l2 == l] | ||
        n = min(l_idx1.size(0), l_idx2.size(0), max_d) | ||
        l_idx1, l_idx2 = l_idx1[:n], l_idx2[:n] | ||
        for _ in range(dm): | ||
            _idx1.append(l_idx1[torch.randperm(n)]) | ||
            _idx2.append(l_idx2[torch.randperm(n)]) | ||
    return torch.cat(_idx1), torch.cat(_idx2) | ||
|
||
if __name__ == '__main__': | ||
    max_d = 10000  # maximum number of datapoints per class | ||
    dm = 30        # data multiplier: random permutations to match | ||
|
||
    # get the individual datasets | ||
    tx = transforms.ToTensor() | ||
    train_mnist = datasets.MNIST('../data', train=True, download=True, transform=tx) | ||
    test_mnist = datasets.MNIST('../data', train=False, download=True, transform=tx) | ||
    train_svhn = datasets.SVHN('../data', split='train', download=True, transform=tx) | ||
    test_svhn = datasets.SVHN('../data', split='test', download=True, transform=tx) | ||
    # svhn labels need extra work | ||
    train_svhn.labels = torch.LongTensor(train_svhn.labels.squeeze().astype(int)) % 10 | ||
    test_svhn.labels = torch.LongTensor(test_svhn.labels.squeeze().astype(int)) % 10 | ||
|
||
    mnist_l, mnist_li = train_mnist.targets.sort() | ||
    svhn_l, svhn_li = train_svhn.labels.sort() | ||
    idx1, idx2 = rand_match_on_idx(mnist_l, mnist_li, svhn_l, svhn_li, max_d=max_d, dm=dm) | ||
    print('len train idx:', len(idx1), len(idx2)) | ||
    torch.save(idx1, '../data/train-ms-mnist-idx.pt') | ||
    torch.save(idx2, '../data/train-ms-svhn-idx.pt') | ||
|
||
    mnist_l, mnist_li = test_mnist.targets.sort() | ||
    svhn_l, svhn_li = test_svhn.labels.sort() | ||
    idx1, idx2 = rand_match_on_idx(mnist_l, mnist_li, svhn_l, svhn_li, max_d=max_d, dm=dm) | ||
    print('len test idx:', len(idx1), len(idx2)) | ||
    torch.save(idx1, '../data/test-ms-mnist-idx.pt') | ||
    torch.save(idx2, '../data/test-ms-svhn-idx.pt') |
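
To make the pairing logic of `rand_match_on_idx` easier to follow, here is a plain-Python sketch of the same idea on toy labels. It is not the script's code: `pair_by_label`, its `seed` argument, and the toy label lists are hypothetical, and it uses the standard `random` module instead of `torch.randperm`.

```python
import random
from collections import defaultdict

def pair_by_label(labels_a, labels_b, max_d=10000, dm=10, seed=0):
    """Pair indices of two label lists so that paired items share a label.

    For each label, keep at most max_d indices per dataset, then append
    dm independent random shufflings of both index sets side by side.
    """
    rng = random.Random(seed)  # hypothetical: seeded for reproducibility
    by_label_a, by_label_b = defaultdict(list), defaultdict(list)
    for i, label in enumerate(labels_a):
        by_label_a[label].append(i)
    for i, label in enumerate(labels_b):
        by_label_b[label].append(i)

    idx_a, idx_b = [], []
    for label in sorted(by_label_a):
        # Truncate both sides to the same size n so they can be zipped.
        n = min(len(by_label_a[label]), len(by_label_b[label]), max_d)
        cand_a, cand_b = by_label_a[label][:n], by_label_b[label][:n]
        for _ in range(dm):
            idx_a.extend(rng.sample(cand_a, n))  # random permutation of cand_a
            idx_b.extend(rng.sample(cand_b, n))  # independent permutation of cand_b
    return idx_a, idx_b

# Toy labels standing in for the two datasets (hypothetical values).
labels_a = [0, 1, 0, 1, 1]
labels_b = [1, 0, 0, 1]
idx_a, idx_b = pair_by_label(labels_a, labels_b, dm=3)
print(len(idx_a), len(idx_b))
```

Every position `k` in the two returned lists refers to items with the same label, and the total length is `dm` times the sum of the per-class sizes `n`.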