Continuous Wasserstein-2 Benchmark
This is the official Python implementation of the NeurIPS 2021 paper Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark (paper on arXiv) by Alexander Korotin, Lingxiao Li, Aude Genevay, Justin Solomon, Alexander Filippov and Evgeny Burnaev.
The repository contains a set of continuous benchmark measures for testing optimal transport solvers for the quadratic cost (Wasserstein-2 distance), together with the code for the solvers and their evaluation.
Citation
@article{korotin2021neural,
title={Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark},
author={Korotin, Alexander and Li, Lingxiao and Genevay, Aude and Solomon, Justin and Filippov, Alexander and Burnaev, Evgeny},
journal={arXiv preprint arXiv:2106.01954},
year={2021}
}
Pre-requisites
The implementation is GPU-based. A single GPU (roughly a GTX 1080 Ti) is enough to run each experiment. Tested with
torch==1.3.0 torchvision==0.4.1
The code might not run as intended with newer torch versions.
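Since the code is only tested with the pinned versions above, it can help to compare the installed packages against them before opening a notebook. The following is a minimal sketch using only the standard library; `TESTED` and `find_mismatches` are illustrative names that are not part of this repository, and stripping local build suffixes such as `+cu100` is an assumption about how the installed wheel reports its version.

```python
from importlib import metadata

# Versions listed above as tested.
TESTED = {"torch": "1.3.0", "torchvision": "0.4.1"}


def find_mismatches(tested=TESTED, lookup=metadata.version):
    """Return {package: installed version or None} for packages that differ.

    `lookup` maps a package name to its installed version string and is
    injectable so the function can be exercised without torch installed.
    """
    mismatches = {}
    for name, wanted in tested.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        # Assumption: ignore local build suffixes such as "+cu100".
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = installed
    return mismatches
```

Calling `find_mismatches()` with no arguments inspects the current environment; an empty dictionary means both packages match the tested versions.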
Related repositories
- Repository for Wasserstein-2 Generative Networks paper.
- Repository for Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization paper.
- Repository for Continuous Regularized Wasserstein Barycenters paper.
- Repository for Large-Scale Wasserstein Gradient Flows paper.
Loading Benchmark Pairs
from src import map_benchmark as mbm
# Load benchmark pair for dimension 16 (2, 4, ..., 256)
benchmark = mbm.Mix3ToMix10Benchmark(16)
# OR load 'Early' images benchmark pair ('Early', 'Mid', 'Late')
# benchmark = mbm.CelebA64Benchmark('Early')
# Sample 32 random points from the benchmark measures
X = benchmark.input_sampler.sample(32)
Y = benchmark.output_sampler.sample(32)
# Compute the true forward map for points X
X.requires_grad_(True)
Y_true = benchmark.map_fwd(X, nograd=True)
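Once the true map is available, a fitted map can be scored against it. The sketch below assumes the unexplained variance percentage used for the paper's quantitative results is defined as 100 * E||T_hat(x) - T_true(x)||^2 / Var(Q), where Var(Q) is the total variance of the target measure. It is written in dependency-free Python for clarity; `total_variance` and `l2_uvp` are illustrative names and are not part of `src`.

```python
def total_variance(points):
    """Sum of per-coordinate (population) variances of a list of vectors."""
    n, dim = len(points), len(points[0])
    means = [sum(p[k] for p in points) / n for k in range(dim)]
    return sum(
        sum((p[k] - means[k]) ** 2 for p in points) / n for k in range(dim)
    )


def l2_uvp(y_pred, y_true, y_target):
    """Unexplained variance percentage of a fitted map (assumed definition).

    y_pred   -- fitted map applied to input samples, as a list of vectors
    y_true   -- true map applied to the same input samples
    y_target -- samples from the target measure, used to estimate Var(Q)
    """
    n = len(y_pred)
    # Mean squared Euclidean distance between fitted and true images.
    mse = sum(
        sum((a - b) ** 2 for a, b in zip(p, t)) for p, t in zip(y_pred, y_true)
    ) / n
    return 100.0 * mse / total_variance(y_target)
```

With the high-dimensional pairs above and a hypothetical fitted map `T`, one could call `l2_uvp(T(X).detach().cpu().tolist(), Y_true.detach().cpu().tolist(), Y.detach().cpu().tolist())`. Under this definition a perfect map scores 0, and a constant map that outputs the target mean scores 100.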
Repository structure
All experiments are provided as self-explanatory Jupyter notebooks (notebooks/). Auxiliary source code is kept in .py modules (src/). Continuous benchmark pairs are stored as .pt checkpoints (benchmarks/).
Evaluation of Existing Solvers
We provide all the code to evaluate existing dual OT solvers on our benchmark pairs. The qualitative results are shown below. For quantitative results, see the paper.
Testing Existing Solvers On High-Dimensional Benchmarks
- notebooks/MM_test_hd_benchmark.ipynb -- testing [MM], [MMv2] solvers and their reversed versions
- notebooks/MMv1_test_hd_benchmark.ipynb -- testing [MMv1] solver
- notebooks/MM-B_test_hd_benchmark.ipynb -- testing [MM-B] solver
- notebooks/W2_test_hd_benchmark.ipynb -- testing [W2] solver and its reversed version
- notebooks/QC_test_hd_benchmark.ipynb -- testing [QC] solver
- notebooks/LS_test_hd_benchmark.ipynb -- testing [LS] solver
Testing Existing Solvers On Images Benchmark Pairs (CelebA 64x64 Aligned Faces)
- notebooks/MM_test_images_benchmark.ipynb -- testing [MM] solver and its reversed version
- notebooks/W2_test_images_benchmark.ipynb -- testing [W2] solver
- notebooks/MM-B_test_images_benchmark.ipynb -- testing [MM-B] solver
- notebooks/QC_test_images_benchmark.ipynb -- testing [QC] solver
The [LS], [MMv2] and [MMv1] solvers are not considered in this experiment.
Generative Modeling by Using Existing Solvers to Compute Loss
Warning: training may take several days before achieving reasonable FID scores!
- notebooks/MM_test_image_generation.ipynb -- generative modeling by [MM] solver or its reversed version
- notebooks/W2_test_image_generation.ipynb -- generative modeling by [W2] solver
For the [QC] solver, we used the code from the official WGAN-QC repo.
Training Benchmark Pairs From Scratch
This code is provided for completeness. It is not intended for retraining the existing benchmark pairs, but it can serve as a base for training new pairs on new datasets. High-dimensional benchmark pairs can be trained from scratch. Training image benchmark pairs requires generator network checkpoints; we used a WGAN-QC model to provide them.
- notebooks/W2_train_hd_benchmark.ipynb -- training high-dimensional benchmark pairs by [W2] solver
- notebooks/W2_train_images_benchmark.ipynb -- training images benchmark pairs by [W2] solver
Credits
- Weights & Biases developer tools for machine learning;
- CelebA page with faces dataset and this page with its aligned 64x64 version;
- pytorch-fid repo to compute FID score;
- UNet architecture for transporter network;
- ResNet architectures for generator and discriminator.