Distributional Sliced-Wasserstein distance code


Distributional Sliced Wasserstein distance

This is a PyTorch implementation of the paper "Distributional Sliced-Wasserstein and Applications to Generative Modeling". The work was done during a residency at VinAI Research, Hanoi, Vietnam.
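
All of the model types below build on the same Monte Carlo sliced-Wasserstein estimate: project both sample sets onto random directions and compare the sorted one-dimensional projections. A minimal PyTorch sketch of that core (illustrative names, not the exact API of the repo's utils.py):

import torch

def sliced_wasserstein(x, y, num_projections=1000, p=2):
    # x, y: (batch, dim) samples from the two distributions
    dim = x.size(1)
    # draw random unit directions on the sphere
    theta = torch.randn(num_projections, dim, device=x.device)
    theta = theta / theta.norm(dim=1, keepdim=True)
    # project both sample sets onto every direction: (batch, num_projections)
    x_proj, y_proj = x @ theta.t(), y @ theta.t()
    # the 1-D Wasserstein distance along a slice compares sorted projections
    diff = torch.sort(x_proj, dim=0)[0] - torch.sort(y_proj, dim=0)[0]
    return diff.abs().pow(p).mean().pow(1.0 / p)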

Requirements

  • Python 3.6
  • PyTorch 1.3
  • torchvision
  • numpy
  • tqdm
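
For reference, a commenter in the issues below reports a working environment created with Python 3.8.10 and:

pip install torch==1.4.0 torchvision==0.5.0 numpy tqdm POT scikit-image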

Train on MNIST and FMNIST

python mnist.py \
    --datadir='./' \
    --outdir='./result' \
    --batch-size=512 \
    --seed=16 \
    --p=2 \
    --lr=0.0005 \
    --dataset='MNIST' \
    --model-type='DSWD' \
    --latent-size=32
model-type is one of: SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD|MGSWNN|JMGSWNN|MGSWD|JMGSWD

Options for Sliced distances (number of projections used to approximate the distances):

--num-projection=1000

Options for Max Sliced-Wasserstein distance and Distributional distances (number of gradient steps to find the max slice or the optimal push-forward function):

--niter=10
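
Concretely, the max-slice search that --niter controls amounts to a few steps of gradient ascent on a single projection direction. A self-contained sketch of that idea (helper names are mine, not the repo's):

import torch

def wasserstein_1d(a, b, p=2):
    # 1-D Wasserstein-p between equal-size samples: compare sorted values
    return (torch.sort(a)[0] - torch.sort(b)[0]).abs().pow(p).mean().pow(1.0 / p)

def max_sliced_wasserstein(x, y, niter=10, p=2, lr=1e-2):
    theta = torch.randn(x.size(1), device=x.device, requires_grad=True)
    opt = torch.optim.Adam([theta], lr=lr)
    for _ in range(niter):
        direction = theta / theta.norm()  # stay on the unit sphere
        loss = -wasserstein_1d(x @ direction, y @ direction, p)
        opt.zero_grad()
        loss.backward()  # ascend on the sliced distance
        opt.step()
    with torch.no_grad():
        direction = theta / theta.norm()
        return wasserstein_1d(x @ direction, y @ direction, p)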

Options for Distributional Sliced-Wasserstein Distance and Distributional Generalized Sliced-Wasserstein Distance (regularization strength):

--lam=10
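
The distributional variants instead learn a distribution over slices: a small network maps noise to projection directions and is trained to maximize the expected sliced distance, while --lam penalizes directions that collapse onto one another. A rough sketch of one inner step under that reading (the repo's projection network and exact regularizer may differ):

import torch

def dsw_inner_step(x, y, f, opt, num_projections=1000, lam=10.0, p=2):
    # f: network mapping noise to candidate directions; opt: optimizer over
    # f's parameters. x and y are assumed detached; only f is updated here.
    noise = torch.randn(num_projections, x.size(1), device=x.device)
    theta = f(noise)
    theta = theta / theta.norm(dim=1, keepdim=True)  # unit directions
    # expected sliced distance under the learned slice distribution
    diff = torch.sort(x @ theta.t(), dim=0)[0] - torch.sort(y @ theta.t(), dim=0)[0]
    dist = diff.abs().pow(p).mean().pow(1.0 / p)
    # diversity term: mean |cosine similarity| between sampled directions
    cos = (theta @ theta.t()).abs().mean()
    loss = -dist + lam * cos  # maximize the distance, keep slices spread out
    opt.zero_grad()
    loss.backward()
    opt.step()
    return dist.item()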

Options for Generalized Sliced-Wasserstein Distance (using the circular function for the Generalized Radon Transform):

--r=1000 \
--g='circular'
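
Here 'circular' presumably refers to the circular defining function of the Generalized Radon Transform, g(x, theta) = ||x - r*theta|| with the radius r set by --r. A hedged sketch of that projection (my reading of the flags, not verified against this repo's code):

import torch

def circular_projection(x, theta, r=1000.0):
    # x: (batch, dim); theta: (num_projections, dim) unit directions
    # generalized 'projection' of each point w.r.t. each shifted direction
    return (x.unsqueeze(1) - r * theta.unsqueeze(0)).norm(dim=2)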

Train on CELEBA, CIFAR10, and LSUN

python main.py \
    --datadir='./' \
    --outdir='./result' \
    --batch-size=512 \
    --seed=16 \
    --p=2 \
    --lr=0.0005 \
    --dataset='CELEBA' \
    --model-type='DSWD' \
    --latent-size=100
model-type is one of: SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER

Options for Sliced distances (number of projections used to approximate the distances):

--num-projection=1000

Options for Max Sliced-Wasserstein distance and Distributional distances (number of gradient steps to find the max slice or the optimal push-forward function):

--niter=1

Options for Distributional Sliced-Wasserstein Distance and Distributional Generalized Sliced-Wasserstein Distance (regularization strength):

--lam=1

Options for Generalized Sliced-Wasserstein Distance (using the circular function for the Generalized Radon Transform):

--r=1000 \
--g='circular'

Some generated images

MNIST generated images

[image: MNIST samples]

CELEBA generated images

[image: CELEBA samples]

LSUN generated images

[image: LSUN samples]

You might also like...
Dynamical Wasserstein Barycenters for Time Series Modeling

Code for the Dynamical Wasserstein Barycenter model published in NeurIPS.

An implementation of the Hierarchical (Sig-Wasserstein) GAN algorithm for large-dimensional time series generation

Hierarchical GAN for large-dimensional financial market data.

The PyTorch code of "Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification", CVPR 2022 (Oral)

DeepBDC for few-shot learning.

Official implementation of "MetaSDF: Meta-learning Signed Distance Functions"

MetaSDF: Meta-learning Signed Distance Functions. Project Page | Paper | Data. Vincent Sitzmann*, Eric Ryan Chan*, Richard Tucker, Noah Snavely, Gordon Wetzstein.

Blender scripts for computing geodesic distance

GeoDoodle: geodesic distance computation for Blender meshes. This addon provides an operator for computing geodesic distance.

Distance Encoding for GNN Design

This repository is the official PyTorch implementation of the DEGNN and DEAGNN frameworks.

Official implementation of the paper "Distance-aware Quantization", accepted to ICCV 2021

PyTorch implementation of DAQ.

A virtual picture-dragging application

Users may virtually slide photos across the screen. The distance between the index and middle fingers determines the movement: smaller distances indicate click and motion, whereas bigger distances indicate only hand movement.

Implementation of Neural Distance Embeddings for Biological Sequences (NeuroSEED) in PyTorch

Official implementation of Neural Distance Embeddings for Biological Sequences (NeuroSEED) in PyTorch.

Comments
  • Pytorch Version

    Hi! Thank you for releasing the code for such an interesting project.

    I noticed your code expects an older version of pytorch. Would you consider a version that works with a newer pytorch (e.g., on a different branch)?

    opened by ttaa9 4
  • CUDA out of memory

    Thank you for sharing the code of this interesting tool. I get the error below when running the code in Jupyter; however, the error does not appear when running on Google Colab.
    I created a new environment with Python 3.8.10 and installed torch==1.4.0 torchvision==0.5.0 numpy tqdm POT scikit-image. Any tips to tackle this error, please? :)

    run mnist.py --datadir=./ --outdir=./result --batch-size=1 --seed=16 --p=2 --lr=0.05 --dataset=MNIST --model-type=DSWD --latent-size=32 --num-workers=0

    RuntimeError                              Traceback (most recent call last)
    ~\Sliced Gan\DSW-master\DSW-master\mnist.py in <module>
        327 
        328 if __name__ == "__main__":
    --> 329     main()
    
    ~\Sliced Gan\DSW-master\DSW-master\mnist.py in main()
        281                     fake = model.decoder(fixednoise_wd)
        282                     wd_list.append(compute_true_Wasserstein(data.to("cpu"), fake.to("cpu")))
    --> 283                     swd_list.append(sliced_wasserstein_distance(data, fake, 10000).item())
        284                     print("Iter:" + str(ite) + " WD: " + str(wd_list[-1]))
        285                     np.savetxt(model_dir + "/wd.csv", wd_list, delimiter=",")
    
    ~\Sliced Gan\DSW-master\DSW-master\utils.py in sliced_wasserstein_distance(first_samples, second_samples, num_projections, p, device)
         17     wasserstein_distance = torch.abs(
         18         (
    ---> 19             torch.sort(first_projections.transpose(0, 1), dim=1)[0]
         20             - torch.sort(second_projections.transpose(0, 1), dim=1)[0]
         21         )
    
    RuntimeError: CUDA out of memory. Tried to allocate 1.49 GiB (GPU 0; 4.00 GiB total capacity; 2.00 GiB already allocated; 935.90 MiB free; 2.02 GiB reserved in total by PyTorch)
    
    opened by SAH01-sys 2
  • With DSWD as a loss, should backprop go through the function f?

    Hi,

    Thank you for sharing your code. I am wondering if this line

    https://github.com/VinAIResearch/DSW/blob/d76ffb11f6da90c9a5aa74600dab76c17e82628f/utils.py#L126

    should be replaced by

    with torch.inference_mode():
        projections = f(pro)
    

    When updating a generative model that uses DSWD as a loss, do we need the computational graph of f that comes attached to projections? i.e., can projections be treated as a constant when backpropagating the DSWD loss through a generative model?

    opened by chmendoza 0
Owner

VinAI Research
Hand-distance-measurement-game - Hand Distance Measurement Game

This program is made to calculate hand distance.

Priyansh 2 Jan 12, 2022

A Distributional Approach To Controlled Text Generation

Repository code for the ICLR 2021 paper "A Distributional Approach to Controlled Text Generation".

NAVER 102 Jan 7, 2023

A working implementation of the Categorical DQN (Distributional RL)

Implementation of the Categorical DQN as described in "A Distributional Perspective on Reinforcement Learning".

Florin Gogianu 98 Sep 20, 2022

Quantile Regression DQN: a minimal working example of Distributional Reinforcement Learning with Quantile Regression

A minimal working example of Quantile Regression DQN.

Arsenii Senya Ashukha 80 Sep 17, 2022

Code accompanying the paper "Wasserstein GAN"

A few notes: the first time running on the LSUN dataset can take a long time.

null 3.1k Jan 1, 2023

(NeurIPS 2020) Wasserstein Distances for Stereo Disparity Estimation

Accepted at NeurIPS 2020 as a Spotlight. [Project Page]

Divyansh Garg 92 Dec 12, 2022

Implementation of Wasserstein adversarial attacks

Code for "Stronger and Faster Wasserstein Adversarial Attacks", which appeared in ICML 2020.

null 21 Oct 6, 2022

PyTorch implementation of VAGAN: Visual Feature Attribution Using Wasserstein GANs

This code aims to reproduce the results obtained in the paper "Visual Feature Attribution Using Wasserstein GANs".

Orobix 93 Aug 17, 2022

Official PyTorch implementation of the paper "Recycling Discriminator: Towards Opinion-Unaware Image Quality Assessment Using Wasserstein GAN", accepted to ACM MM 2021 BNI Track

RecycleD: official PyTorch implementation of the paper above.

Yunan Zhu 23 Nov 5, 2022

A PyTorch implementation of the paper "Improved Training of Wasserstein GANs"

WGAN-GP: a PyTorch implementation of "Improved Training of Wasserstein GANs". Prerequisites: Python, NumPy, SciPy, Matplotlib, and a recent NVIDIA GPU.

Marvin Cao 1.4k Dec 14, 2022