This is the official Python implementation of the NeurIPS 2023 paper **Extremal Domain Translation with Neural Optimal Transport** (paper page on NeurIPS) by Milena Gazdieva, Alexander Korotin, Daniil Selikhanovych and Evgeny Burnaev.
The repository contains reproducible PyTorch source code for computing incomplete transport (IT) maps in high dimensions with neural networks. The algorithm can be used to partially align distributions or to approximate extremal transport (ET) maps. Examples are provided for toy problems (2D) and for the unpaired image-to-image translation task on various pairs of datasets.
- Talk by Milena Gazdieva at Fall into ML conference (28 October 2023, EN)
- Talk by Milena Gazdieva at AIRI meetup (21 December 2023, RU)
- Repository for Neural Optimal Transport paper (ICLR 2023).
@inproceedings{
gazdieva2023extremal,
title={Extremal Domain Translation with Neural Optimal Transport},
author={Gazdieva, Milena and Korotin, Alexander and Selikhanovych, Daniil and Burnaev, Evgeny},
booktitle={Advances in Neural Information Processing Systems},
year={2023}
}
The unpaired domain translation task can be posed as a classic OT problem. The corresponding OT maps (or plans) generally preserve certain image attributes during the translation due to the nature of the problem. However, the OT formulation fails in certain cases, e.g., when the attributes of objects from the source and target distributions are not balanced.
Our IT algorithm can be used to resolve this issue. It searches for a transport map with the minimal transport cost (e.g., $\ell_2^2$) that aligns the source distribution with only a part of the target distribution, i.e., it partially aligns the distributions.
In contrast to other popular image-to-image translation models based on GANs or diffusion models, our method has the following key advantages:
- it can be used to control the similarity between source and translated images (via the weight parameter $w$);
- it is theoretically justified.
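The role of the weight $w$ can be illustrated with a simplified PyTorch sketch of the adversarial max-min training on toy 2D data. Everything below (network sizes, optimizers, the shifted-Gaussian target, step counts) is illustrative rather than the repository's actual training code; the idea is that $w$ scales the $\ell_2^2$ transport cost, so larger values keep translated points closer to their sources.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def mlp(din, dout):
    return nn.Sequential(nn.Linear(din, 64), nn.ReLU(),
                         nn.Linear(64, 64), nn.ReLU(),
                         nn.Linear(64, dout))

T = mlp(2, 2)   # transport map x -> T(x)
f = mlp(2, 1)   # potential on the target space

w = 2.0  # similarity weight: larger w keeps T(x) closer to x
opt_T = torch.optim.Adam(T.parameters(), lr=1e-3)
opt_f = torch.optim.Adam(f.parameters(), lr=1e-3)

for step in range(200):
    x = torch.randn(128, 2)        # source samples
    y = torch.randn(128, 2) + 3.0  # target samples (shifted Gaussian)

    # inner loop: minimize w * cost - potential over the map T
    for _ in range(5):
        t_x = T(x)
        cost = w * ((x - t_x) ** 2).sum(dim=1).mean()  # weighted l2^2 cost
        loss_T = cost - f(t_x).mean()
        opt_T.zero_grad(); loss_T.backward(); opt_T.step()

    # outer step: maximize E_y[f(y)] - E_x[f(T(x))] over the potential f
    t_x = T(x).detach()
    loss_f = f(t_x).mean() - f(y).mean()
    opt_f.zero_grad(); loss_f.backward(); opt_f.step()
```

Note the asymmetric update schedule (several map steps per potential step), which is the usual stabilization trick for this kind of max-min objective.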
Qualitative examples are shown below for various pairs of datasets. We test unpaired translation with XNOT using the $\ell_2^2$ transport cost.
The implementation is GPU-based with multi-GPU support. It was tested with torch==1.9.0 and 1-4 Tesla V100 GPUs.
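The multi-GPU wiring in PyTorch is typically a thin wrapper around the model; a minimal sketch with `torch.nn.DataParallel` is shown below (the toy conv net is illustrative, not one of the repository's networks; the wrapper transparently falls back to single-device execution when fewer than two GPUs are visible):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 3, 3, padding=1))

# Replicate the model across all visible GPUs; with 0 or 1 GPUs the
# plain module is used directly.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

x = torch.randn(4, 3, 32, 32, device=device)  # dummy image batch
out = model(x)  # batch is split across GPUs and gathered back
```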
Toy experiments are provided as self-explanatory Jupyter notebooks (`notebooks/`). For convenience, the majority of the evaluation output is preserved. Auxiliary source code is moved to `.py` modules (`src/`).
- `XNOT_translation.py` - unpaired image-to-image translation ($\ell_2^2$ cost);
- `notebooks/XNOT_toy.ipynb` - toy experiments in 2D (Wi-Fi, Accept);
- `notebooks/XNOT_swiss2circle.ipynb` - toy experiment with an analytically known ET map in 2D;
- `notebooks/XNOT_limitations.ipynb` - toy illustrations of limitations in 2D (Diversity, Intersection);
- `notebooks/XNOT_kernel_cost.ipynb` - toy experiment with a weak kernel cost in 2D (one-to-many);
- `notebooks/XNOT_test.ipynb` - testing and plotting the translation results (pre-trained models are required);
- `stats/compute_stats.ipynb` - pre-computes InceptionV3 statistics to speed up test FID computation.
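Pre-computing statistics helps because FID only depends on the mean and covariance of InceptionV3 features for each dataset; once those are cached, the test-set forward passes are skipped entirely. A minimal sketch of the closed-form Fréchet distance between two feature Gaussians (random features stand in for real Inception activations; this is not the repository's evaluation code):

```python
import numpy as np
from scipy import linalg

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    covmean = covmean.real  # drop tiny imaginary parts from numerical error
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

# With cached (mu, sigma) the InceptionV3 forward passes are not needed:
rng = np.random.default_rng(0)
feats = rng.normal(size=(500, 8))  # stand-in for Inception features
mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
print(fid_from_stats(mu, sigma, mu, sigma))  # ~0 for identical statistics
```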
- Aligned anime faces (105GB) should be pre-processed with `datasets/preprocess.ipynb`;
- CelebA faces require `datasets/list_attr_celeba.ipynb`;
- Handbags, shoes datasets;
- Describable Textures Dataset (DTD);
- Flickr-Faces-HQ Dataset (FFHQ);
- Comic faces;
- Bonn Furniture Styles Dataset.
The dataloaders can be created by the `load_dataset` function from `src/tools.py`. The latter four datasets are loaded directly to RAM.
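Loading a dataset directly to RAM amounts to materializing all decoded images as one tensor up front and serving minibatches from memory, with no per-item disk reads during training. A hypothetical sketch is shown below; the actual `load_dataset` in `src/tools.py` handles real image folders and may differ in signature:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Illustrative only: a random tensor stands in for a decoded image folder.
images = torch.rand(1000, 3, 64, 64)
dataset = TensorDataset(images)  # everything stays resident in RAM
loader = DataLoader(dataset, batch_size=64, shuffle=True)

batch, = next(iter(loader))
print(batch.shape)  # torch.Size([64, 3, 64, 64])
```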
- Weights & Biases developer tools for machine learning;
- pytorch-fid repo to compute the FID score;
- UNet architecture for the transporter network;
- ResNet architectures for the generator and discriminator;
- Inkscape, the awesome vector graphics editor.