A set of tests for evaluating large-scale algorithms for Wasserstein-2 transport maps computation.

Overview

Continuous Wasserstein-2 Benchmark

This is the official Python implementation of the NeurIPS 2021 paper Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark (paper on arxiv) by Alexander Korotin, Lingxiao Li, Aude Genevay, Justin Solomon, Alexander Filippov and Evgeny Burnaev.

The repository contains a set of continuous benchmark measures for testing optimal transport solvers for quadratic cost (Wasserstein-2 distance), the code for optimal transport solvers and their evaluation.

Citation

@article{korotin2021neural,
  title={Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark},
  author={Korotin, Alexander and Li, Lingxiao and Genevay, Aude and Solomon, Justin and Filippov, Alexander and Burnaev, Evgeny},
  journal={arXiv preprint arXiv:2106.01954},
  year={2021}
}

Pre-requisites

The implementation is GPU-based. Single GPU (~GTX 1080 ti) is enough to run each particular experiment. Tested with

torch==1.3.0 torchvision==0.4.1

The code might not run as intended in newer torch versions.

Related repositories

Loading Benchmark Pairs

from src import map_benchmark as mbm

# Load benchmark pair for dimension 16 (2, 4, ..., 256)
benchmark = mbm.Mix3ToMix10Benchmark(16)
# OR load 'Early' images benchmark pair ('Early', 'Mid', 'Late')
# benchmark = mbm.CelebA64Benchmark('Early')

# Sample 32 random points from the benchmark measures
X = benchmark.input_sampler.sample(32)
Y = benchmark.output_sampler.sample(32)

# Compute the true forward map for points X
X.requires_grad_(True)
Y_true = benchmark.map_fwd(X, nograd=True)

Repository structure

All the experiments are issued in the form of pretty self-explanatory jupyter notebooks (notebooks/). Auxilary source code is moved to .py modules (src/). Continuous benchmark pairs are stored as .pt checkpoints (benchmarks/).

Evaluation of Existing Solvers

We provide all the code to evaluate existing dual OT solvers on our benchmark pairs. The qualitative results are shown below. For quantitative results, see the paper.

Testing Existing Solvers On High-Dimensional Benchmarks

  • notebooks/MM_test_hd_benchmark.ipynb -- testing [MM], [MMv2] solvers and their reversed versions
  • notebooks/MMv1_test_hd_benchmark.ipynb -- testing [MMv1] solver
  • notebooks/MM-B_test_hd_benchmark.ipynb -- testing [MM-B] solver
  • notebooks/W2_test_hd_benchmark.ipynb -- testing [W2] solver and its reversed version
  • notebooks/QC_test_hd_benchmark.ipynb -- testing [QC] solver
  • notebooks/LS_test_hd_benchmark.ipynb -- testing [LS] solver

Testing Existing Solvers On Images Benchmark Pairs (CelebA 64x64 Aligned Faces)

  • notebooks/MM_test_images_benchmark.ipynb -- testing [MM] solver and its reversed version
  • notebooks/W2_test_images_benchmark.ipynb -- testing [W2]
  • notebooks/MM-B_test_images_benchmark.ipynb -- testing [MM-B] solver
  • notebooks/QC_test_images_benchmark.ipynb -- testing [QC] solver

[LS], [MMv2], [MMv1] solvers are not considered in this experiment.

Generative Modeling by Using Existing Solvers to Compute Loss

Warning: training may take several days before achieving reasonable FID scores!

  • notebooks/MM_test_image_generation.ipynb -- generative modeling by [MM] solver or its reversed version
  • notebooks/W2_test_image_generation.ipynb -- generative modeling by [W2] solver

For [QC] solver we used the code from the official WGAN-QC repo.

Training Benchmark Pairs From Scratch

This code is provided for completeness and is not intended to be used to retrain existing benchmark pairs, but might be used as the base to train new pairs on new datasets. High-dimensional benchmak pairs can be trained from scratch. Training images benchmark pairs requires generator network checkpoints. We used WGAN-QC model to provide such checkpoints.

  • notebooks/W2_train_hd_benchmark.ipynb -- training high-dimensional benchmark bairs by [W2] solver
  • notebooks/W2_train_images_benchmark.ipynb -- training images benchmark bairs by [W2] solver

Credits

You might also like...
A pytorch implementation of Paper
A pytorch implementation of Paper "Improved Training of Wasserstein GANs"

WGAN-GP An pytorch implementation of Paper "Improved Training of Wasserstein GANs". Prerequisites Python, NumPy, SciPy, Matplotlib A recent NVIDIA GPU

Code accompanying the paper
Code accompanying the paper "Wasserstein GAN"

Wasserstein GAN Code accompanying the paper "Wasserstein GAN" A few notes The first time running on the LSUN dataset it can take a long time (up to an

Dynamical Wasserstein Barycenters for Time Series Modeling

Dynamical Wasserstein Barycenters for Time Series Modeling This is the code related for the Dynamical Wasserstein Barycenter model published in Neurip

This is our ARTS test set, an enriched test set to probe Aspect Robustness of ABSA.
This is our ARTS test set, an enriched test set to probe Aspect Robustness of ABSA.

This is the repository for our 2020 paper "Tasty Burgers, Soggy Fries: Probing Aspect Robustness in Aspect-Based Sentiment Analysis". Data We provide

 Open-Set Recognition: A Good Closed-Set Classifier is All You Need
Open-Set Recognition: A Good Closed-Set Classifier is All You Need

Open-Set Recognition: A Good Closed-Set Classifier is All You Need Code for our paper: "Open-Set Recognition: A Good Closed-Set Classifier is All You

Script that receives an Image (original) and a set of images to be used as
Script that receives an Image (original) and a set of images to be used as "pixels" in reconstruction of the Original image using the set of images as "pixels"

picinpics Script that receives an Image (original) and a set of images to be used as "pixels" in reconstruction of the Original image using the set of

Open-AI's DALL-E for large scale training in mesh-tensorflow.

DALL-E in Mesh-Tensorflow [WIP] Open-AI's DALL-E in Mesh-Tensorflow. If this is similarly efficient to GPT-Neo, this repo should be able to train mode

Apache Spark - A unified analytics engine for large-scale data processing

Apache Spark Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Scala, Java, Python, and R, and an op

This is a Pytorch implementation of the paper: Self-Supervised Graph Transformer on Large-Scale Molecular Data.

This is a Pytorch implementation of the paper: Self-Supervised Graph Transformer on Large-Scale Molecular Data.

Comments
  • Could you tell me how to test these models after training?

    Could you tell me how to test these models after training?

    Hello iamalexkorotin :-) I'm an undergraduate student majored in computer science. Before mentioning what I need, firstly your code was very helpful for me. I'm working on a problem which is how to solve optimal transport efficiently. I find that in your notebooks, both train_hd and test_hd containing the training process. I wonder how can I test the model after training, which is to make the model generate a transport map directly. And could you show me the test code? Thanks very much!

    opened by benmagnifico 1
Owner
Alexander
PhD Student (Computer Science) at Skolkovo University of Science and Technology (Moscow, Russia)
Alexander
Automatically download the cwru data set, and then divide it into training data set and test data set

Automatically download the cwru data set, and then divide it into training data set and test data set.自动下载cwru数据集,然后分训练数据集和测试数据集

null 6 Jun 27, 2022
Voxel Set Transformer: A Set-to-Set Approach to 3D Object Detection from Point Clouds (CVPR 2022)

Voxel Set Transformer: A Set-to-Set Approach to 3D Object Detection from Point Clouds (CVPR2022)[paper] Authors: Chenhang He, Ruihuang Li, Shuai Li, L

Billy HE 141 Dec 30, 2022
SLIDE : In Defense of Smart Algorithms over Hardware Acceleration for Large-Scale Deep Learning Systems

The SLIDE package contains the source code for reproducing the main experiments in this paper. Dataset The Datasets can be downloaded in Amazon-

Intel Labs 72 Dec 16, 2022
A large dataset of 100k Google Satellite and matching Map images, resembling pix2pix's Google Maps dataset.

Larger Google Sat2Map dataset This dataset extends the aerial ⟷ Maps dataset used in pix2pix (Isola et al., CVPR17). The provide script download_sat2m

null 34 Dec 28, 2022
GeneDisco is a benchmark suite for evaluating active learning algorithms for experimental design in drug discovery.

GeneDisco is a benchmark suite for evaluating active learning algorithms for experimental design in drug discovery.

null 22 Dec 12, 2022
Distributional Sliced-Wasserstein distance code

Distributional Sliced Wasserstein distance This is a pytorch implementation of the paper "Distributional Sliced-Wasserstein and Applications to Genera

VinAI Research 39 Jan 1, 2023
(NeurIPS 2020) Wasserstein Distances for Stereo Disparity Estimation

Wasserstein Distances for Stereo Disparity Estimation Accepted in NeurIPS 2020 as Spotlight. [Project Page] Wasserstein Distances for Stereo Disparity

Divyansh Garg 92 Dec 12, 2022
Implementation of Wasserstein adversarial attacks.

Stronger and Faster Wasserstein Adversarial Attacks Code for Stronger and Faster Wasserstein Adversarial Attacks, appeared in ICML 2020. This reposito

null 21 Oct 6, 2022
PyTorch implementation of VAGAN: Visual Feature Attribution Using Wasserstein GANs

PyTorch implementation of VAGAN: Visual Feature Attribution Using Wasserstein GANs This code aims to reproduce results obtained in the paper "Visual F

Orobix 93 Aug 17, 2022
Official PyTorch implementation of the paper "Recycling Discriminator: Towards Opinion-Unaware Image Quality Assessment Using Wasserstein GAN", accepted to ACM MM 2021 BNI Track.

RecycleD Official PyTorch implementation of the paper "Recycling Discriminator: Towards Opinion-Unaware Image Quality Assessment Using Wasserstein GAN

Yunan Zhu 23 Nov 5, 2022